feat(retrieval): add part-number discriminator and improve part disambiguation.
This commit is contained in:
+28
-2
@@ -56,6 +56,7 @@ _SHORT_CHUNK_PENALTY = 0.15
|
|||||||
|
|
||||||
_GRADE_RE = re.compile(r"\b(33|43|53)\b")
|
_GRADE_RE = re.compile(r"\b(33|43|53)\b")
|
||||||
_PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE)
|
_PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE)
|
||||||
|
_PART_NUM_RE = re.compile(r"\bpart[\s\-–—]*(\d+)\b", re.IGNORECASE)
|
||||||
_STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}
|
_STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}
|
||||||
|
|
||||||
|
|
||||||
@@ -212,6 +213,9 @@ class Retriever:
|
|||||||
grade_in_query = set(_GRADE_RE.findall(query))
|
grade_in_query = set(_GRADE_RE.findall(query))
|
||||||
sig_query = query_words - _STOP_WORDS
|
sig_query = query_words - _STOP_WORDS
|
||||||
|
|
||||||
|
# Detect explicit part number in query (e.g. "Part 2", "Part 1")
|
||||||
|
part_nums_in_query = set(_PART_NUM_RE.findall(query))
|
||||||
|
|
||||||
for cid, base in list(chunk_scores.items()):
|
for cid, base in list(chunk_scores.items()):
|
||||||
chunk = self.idx.chunks[cid]
|
chunk = self.idx.chunks[cid]
|
||||||
sid_norm = _norm_std_id(chunk["standard_id"])
|
sid_norm = _norm_std_id(chunk["standard_id"])
|
||||||
@@ -257,6 +261,25 @@ class Retriever:
|
|||||||
elif grade_in_title and grade_in_title.isdisjoint(grade_in_query):
|
elif grade_in_title and grade_in_title.isdisjoint(grade_in_query):
|
||||||
bonus -= 0.40
|
bonus -= 0.40
|
||||||
|
|
||||||
|
# Part-number discriminator: when query explicitly names a part number,
|
||||||
|
# boost matching parts and penalise non-matching siblings.
|
||||||
|
if part_nums_in_query:
|
||||||
|
sid_raw = chunk["standard_id"]
|
||||||
|
m_part = re.search(r"\(Part\s*(\d+)\)", sid_raw, re.IGNORECASE)
|
||||||
|
if m_part:
|
||||||
|
part_num = m_part.group(1)
|
||||||
|
if part_num in part_nums_in_query:
|
||||||
|
bonus += 0.30
|
||||||
|
else:
|
||||||
|
bonus -= 0.20
|
||||||
|
else:
|
||||||
|
# Base standard (no Part N) — if query says "Part 1 General"
|
||||||
|
# treat as matching (base IS is often the Part 1/General Req)
|
||||||
|
if "1" in part_nums_in_query and (
|
||||||
|
"general" in query_lower or "requirement" in query_lower
|
||||||
|
):
|
||||||
|
bonus += 0.20
|
||||||
|
|
||||||
chunk_scores[cid] = base + bonus
|
chunk_scores[cid] = base + bonus
|
||||||
|
|
||||||
# --- Group by standard_id, keep best chunk score ---
|
# --- Group by standard_id, keep best chunk score ---
|
||||||
@@ -272,11 +295,14 @@ class Retriever:
|
|||||||
# --- Same-family Part disambiguation ---
|
# --- Same-family Part disambiguation ---
|
||||||
base_to_sids: dict[str, list[str]] = {}
|
base_to_sids: dict[str, list[str]] = {}
|
||||||
for sid in std_best:
|
for sid in std_best:
|
||||||
if "(" not in sid:
|
|
||||||
continue
|
|
||||||
m = _PART_BASE_RE.match(sid)
|
m = _PART_BASE_RE.match(sid)
|
||||||
if m:
|
if m:
|
||||||
base_to_sids.setdefault(m.group(1), []).append(sid)
|
base_to_sids.setdefault(m.group(1), []).append(sid)
|
||||||
|
elif re.match(r"IS\s+\d+\s*:", sid, re.IGNORECASE):
|
||||||
|
# Base standard (no Part N) — register under its IS number
|
||||||
|
m2 = re.match(r"IS\s+(\d+)", sid, re.IGNORECASE)
|
||||||
|
if m2:
|
||||||
|
base_to_sids.setdefault(m2.group(1), []).append(sid)
|
||||||
for sids in base_to_sids.values():
|
for sids in base_to_sids.values():
|
||||||
if len(sids) < 2:
|
if len(sids) < 2:
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user