feat(retrieval): add grade matching and same-family part disambiguation.
Boost a standard's score when the grade in the query matches the grade in the standard's title, and penalize it when the grades mismatch. Add same-family part disambiguation so that queries are routed to the correct part of a multi-part standard (e.g., IS 12269 (Part 1) vs. IS 12269 (Part 2)). Regenerate retrieval results with the improved ranking.
This commit is contained in:
+46
-3
@@ -54,6 +54,10 @@ _TOP_N_FINAL = 5
|
||||
_SHORT_CHUNK_THRESHOLD = 40 # body words below this get a penalty
|
||||
_SHORT_CHUNK_PENALTY = 0.15
|
||||
|
||||
_GRADE_RE = re.compile(r"\b(33|43|53)\b")
|
||||
_PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE)
|
||||
_STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text helpers
|
||||
@@ -205,6 +209,9 @@ class Retriever:
|
||||
query_lower = query.lower()
|
||||
query_words = set(query_tokens)
|
||||
|
||||
grade_in_query = set(_GRADE_RE.findall(query))
|
||||
sig_query = query_words - _STOP_WORDS
|
||||
|
||||
for cid, base in list(chunk_scores.items()):
|
||||
chunk = self.idx.chunks[cid]
|
||||
sid_norm = _norm_std_id(chunk["standard_id"])
|
||||
@@ -227,9 +234,7 @@ class Retriever:
|
||||
|
||||
# Strong boost: majority of title words present in query — likely
|
||||
# the most on-point standard even if its chunk body is polluted.
|
||||
stop = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}
|
||||
sig_title = full_title_tokens - stop
|
||||
sig_query = query_words - stop
|
||||
sig_title = full_title_tokens - _STOP_WORDS
|
||||
if sig_title and len(sig_title & sig_query) / len(sig_title) >= 0.6:
|
||||
bonus += 0.25
|
||||
|
||||
@@ -245,6 +250,13 @@ class Retriever:
|
||||
if body_wc < _SHORT_CHUNK_THRESHOLD:
|
||||
bonus -= _SHORT_CHUNK_PENALTY
|
||||
|
||||
if grade_in_query:
|
||||
grade_in_title = set(_GRADE_RE.findall(full_title))
|
||||
if grade_in_title and grade_in_title == grade_in_query:
|
||||
bonus += 0.35
|
||||
elif grade_in_title and grade_in_title.isdisjoint(grade_in_query):
|
||||
bonus -= 0.40
|
||||
|
||||
chunk_scores[cid] = base + bonus
|
||||
|
||||
# --- Group by standard_id, keep best chunk score ---
|
||||
@@ -257,6 +269,37 @@ class Retriever:
|
||||
std_best[sid] = score
|
||||
std_chunk_repr[sid] = chunk
|
||||
|
||||
# --- Same-family Part disambiguation ---
|
||||
base_to_sids: dict[str, list[str]] = {}
|
||||
for sid in std_best:
|
||||
if "(" not in sid:
|
||||
continue
|
||||
m = _PART_BASE_RE.match(sid)
|
||||
if m:
|
||||
base_to_sids.setdefault(m.group(1), []).append(sid)
|
||||
for sids in base_to_sids.values():
|
||||
if len(sids) < 2:
|
||||
continue
|
||||
# Only act when all siblings share the same title
|
||||
titles = {_norm_std_id(std_chunk_repr[s].get("full_title") or
|
||||
std_chunk_repr[s].get("title", ""))
|
||||
for s in sids}
|
||||
if len(titles) > 1:
|
||||
continue
|
||||
sib_kws = {s: self.idx.std_keywords.get(_norm_std_id(s), set()) for s in sids}
|
||||
disc_weight: dict[str, float] = {s: 0.0 for s in sids}
|
||||
for token in query_words:
|
||||
owners = [s for s in sids if token in sib_kws[s]]
|
||||
if len(owners) == 1:
|
||||
# Clamp IDF at 0 (it can go negative for very common terms); the +1 keeps every uniquely-owned token's weight positive
|
||||
idf_weight = max(self.idx.bm25.idf.get(token, 0.0), 0.0) + 1.0
|
||||
disc_weight[owners[0]] += idf_weight
|
||||
max_w = max(disc_weight.values())
|
||||
if max_w == 0.0:
|
||||
continue
|
||||
for s, w in disc_weight.items():
|
||||
std_best[s] += w * 0.22 - (max_w - w) * 0.14
|
||||
|
||||
# --- Sort and take top N ---
|
||||
ranked = sorted(std_best.items(), key=lambda x: x[1], reverse=True)[:top_n]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user