feat(retrieval): add part-number discriminator and improve part disambiguation.

This commit is contained in:
K
2026-05-04 15:45:02 +05:30
parent fdae5d2318
commit 458bd93434
+28 -2
View File
@@ -56,6 +56,7 @@ _SHORT_CHUNK_PENALTY = 0.15
_GRADE_RE = re.compile(r"\b(33|43|53)\b")
_PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE)
_PART_NUM_RE = re.compile(r"\bpart[\s\-–—]*(\d+)\b", re.IGNORECASE)
_STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}
@@ -212,6 +213,9 @@ class Retriever:
grade_in_query = set(_GRADE_RE.findall(query))
sig_query = query_words - _STOP_WORDS
# Detect explicit part number in query (e.g. "Part 2", "Part 1")
part_nums_in_query = set(_PART_NUM_RE.findall(query))
for cid, base in list(chunk_scores.items()):
chunk = self.idx.chunks[cid]
sid_norm = _norm_std_id(chunk["standard_id"])
@@ -257,6 +261,25 @@ class Retriever:
elif grade_in_title and grade_in_title.isdisjoint(grade_in_query):
bonus -= 0.40
# Part-number discriminator: when query explicitly names a part number,
# boost matching parts and penalise non-matching siblings.
if part_nums_in_query:
sid_raw = chunk["standard_id"]
m_part = re.search(r"\(Part\s*(\d+)\)", sid_raw, re.IGNORECASE)
if m_part:
part_num = m_part.group(1)
if part_num in part_nums_in_query:
bonus += 0.30
else:
bonus -= 0.20
else:
# Base standard (no Part N) — if query says "Part 1 General"
# treat as matching (base IS is often the Part 1/General Req)
if "1" in part_nums_in_query and (
"general" in query_lower or "requirement" in query_lower
):
bonus += 0.20
chunk_scores[cid] = base + bonus
# --- Group by standard_id, keep best chunk score ---
@@ -272,11 +295,14 @@ class Retriever:
# --- Same-family Part disambiguation ---
base_to_sids: dict[str, list[str]] = {}
for sid in std_best:
if "(" not in sid:
continue
m = _PART_BASE_RE.match(sid)
if m:
base_to_sids.setdefault(m.group(1), []).append(sid)
elif re.match(r"IS\s+\d+\s*:", sid, re.IGNORECASE):
# Base standard (no Part N) — register under its IS number
m2 = re.match(r"IS\s+(\d+)", sid, re.IGNORECASE)
if m2:
base_to_sids.setdefault(m2.group(1), []).append(sid)
for sids in base_to_sids.values():
if len(sids) < 2:
continue