From 458bd934349e0bf1c502dfc11ab6d6175eca61bf Mon Sep 17 00:00:00 2001 From: Kshitij <160704796+kshitij-ka@users.noreply.github.com> Date: Mon, 4 May 2026 15:45:02 +0530 Subject: [PATCH] feat(retrieval): add part-number discriminator and improve part disambiguation. --- inference.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/inference.py b/inference.py index 768fcd4..0b575fe 100644 --- a/inference.py +++ b/inference.py @@ -56,6 +56,7 @@ _SHORT_CHUNK_PENALTY = 0.15 _GRADE_RE = re.compile(r"\b(33|43|53)\b") _PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE) +_PART_NUM_RE = re.compile(r"\bpart[\s\-–—]*(\d+)\b", re.IGNORECASE) _STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"} @@ -212,6 +213,9 @@ class Retriever: grade_in_query = set(_GRADE_RE.findall(query)) sig_query = query_words - _STOP_WORDS + # Detect explicit part number in query (e.g. "Part 2", "Part 1") + part_nums_in_query = set(_PART_NUM_RE.findall(query)) + for cid, base in list(chunk_scores.items()): chunk = self.idx.chunks[cid] sid_norm = _norm_std_id(chunk["standard_id"]) @@ -257,6 +261,25 @@ class Retriever: elif grade_in_title and grade_in_title.isdisjoint(grade_in_query): bonus -= 0.40 + # Part-number discriminator: when query explicitly names a part number, + # boost matching parts and penalise non-matching siblings. + if part_nums_in_query: + sid_raw = chunk["standard_id"] + m_part = re.search(r"\(Part\s*(\d+)\)", sid_raw, re.IGNORECASE) + if m_part: + part_num = m_part.group(1) + if part_num in part_nums_in_query: + bonus += 0.30 + else: + bonus -= 0.20 + else: + # Base standard (no Part N) — if query says "Part 1 General" + # treat as matching (base IS is often the Part 1/General Req) + if "1" in part_nums_in_query and ( + "general" in query_lower or "requirement" in query_lower + ): + bonus += 0.20 + chunk_scores[cid] = base + bonus # --- Group by standard_id, keep best chunk score --- @@ -272,11 +295,14 @@ class Retriever: # --- Same-family Part disambiguation --- base_to_sids: dict[str, list[str]] = {} for sid in std_best: - if "(" not in sid: - continue m = _PART_BASE_RE.match(sid) if m: base_to_sids.setdefault(m.group(1), []).append(sid) + elif re.match(r"IS\s+\d+\s*:", sid, re.IGNORECASE): + # Base standard (no Part N) — register under its IS number + m2 = re.match(r"IS\s+(\d+)", sid, re.IGNORECASE) + if m2: + base_to_sids.setdefault(m2.group(1), []).append(sid) for sids in base_to_sids.values(): if len(sids) < 2: continue