feat(retrieval): add part-number discriminator and improve part disambiguation.
This commit is contained in:
+28
-2
@@ -56,6 +56,7 @@ _SHORT_CHUNK_PENALTY = 0.15
|
||||
|
||||
# Finds a standalone 33, 43 or 53. The word boundaries on both sides mean
# digits embedded in a longer token (e.g. "433" or "53a") do not match.
# Used to collect the grade numbers named in a query.
_GRADE_RE = re.compile(r"\b(33|43|53)\b")
|
||||
# Matches "IS <digits> (" case-insensitively and captures the digits, i.e. the
# numeric base of a standard id that is followed by an opening parenthesis.
# At least one whitespace character is required between "IS" and the number.
_PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE)
|
||||
# Captures the digits of an explicit "part N" mention (case-insensitive).
# "part" and the number may be separated by any mix of whitespace, hyphens,
# en dashes or em dashes, or by nothing at all (e.g. "part2").
_PART_NUM_RE = re.compile(r"\bpart[\s\-–—]*(\d+)\b", re.IGNORECASE)
|
||||
# Common function words that are subtracted from the query's word set
# (presumably so only the significant query words remain — see sig_query).
_STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}
|
||||
|
||||
|
||||
@@ -212,6 +213,9 @@ class Retriever:
|
||||
grade_in_query = set(_GRADE_RE.findall(query))
|
||||
sig_query = query_words - _STOP_WORDS
|
||||
|
||||
# Detect explicit part number in query (e.g. "Part 2", "Part 1")
|
||||
part_nums_in_query = set(_PART_NUM_RE.findall(query))
|
||||
|
||||
for cid, base in list(chunk_scores.items()):
|
||||
chunk = self.idx.chunks[cid]
|
||||
sid_norm = _norm_std_id(chunk["standard_id"])
|
||||
@@ -257,6 +261,25 @@ class Retriever:
|
||||
elif grade_in_title and grade_in_title.isdisjoint(grade_in_query):
|
||||
bonus -= 0.40
|
||||
|
||||
# Part-number discriminator: when query explicitly names a part number,
|
||||
# boost matching parts and penalise non-matching siblings.
|
||||
if part_nums_in_query:
|
||||
sid_raw = chunk["standard_id"]
|
||||
m_part = re.search(r"\(Part\s*(\d+)\)", sid_raw, re.IGNORECASE)
|
||||
if m_part:
|
||||
part_num = m_part.group(1)
|
||||
if part_num in part_nums_in_query:
|
||||
bonus += 0.30
|
||||
else:
|
||||
bonus -= 0.20
|
||||
else:
|
||||
# Base standard (no Part N) — if query says "Part 1 General"
|
||||
# treat as matching (the base IS standard is often the Part 1 / General Requirements document)
|
||||
if "1" in part_nums_in_query and (
|
||||
"general" in query_lower or "requirement" in query_lower
|
||||
):
|
||||
bonus += 0.20
|
||||
|
||||
chunk_scores[cid] = base + bonus
|
||||
|
||||
# --- Group by standard_id, keep best chunk score ---
|
||||
@@ -272,11 +295,14 @@ class Retriever:
|
||||
# --- Same-family Part disambiguation ---
|
||||
base_to_sids: dict[str, list[str]] = {}
|
||||
for sid in std_best:
|
||||
if "(" not in sid:
|
||||
continue
|
||||
m = _PART_BASE_RE.match(sid)
|
||||
if m:
|
||||
base_to_sids.setdefault(m.group(1), []).append(sid)
|
||||
elif re.match(r"IS\s+\d+\s*:", sid, re.IGNORECASE):
|
||||
# Base standard (no Part N) — register under its IS number
|
||||
m2 = re.match(r"IS\s+(\d+)", sid, re.IGNORECASE)
|
||||
if m2:
|
||||
base_to_sids.setdefault(m2.group(1), []).append(sid)
|
||||
for sids in base_to_sids.values():
|
||||
if len(sids) < 2:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user