# Snapshot of commit 3fbf91c706 (448 lines, 17 KiB, Python).
# Commit message: Boost scores when query grade matches standard title grade,
# penalize mismatches. Add part disambiguation to correctly route queries to
# specific standard parts (e.g., IS 12269 (Part 1) vs (Part 2)). Regenerate
# retrieval results with improved ranking.
"""
BIS SP-21 Hybrid Retrieval System
----------------------------------
Combines dense (FAISS + sentence-transformers) and sparse (BM25) search,
then re-ranks and deduplicates to return the top-5 unique IS standards.

Usage
-----
    # Index build (one-time, caches to data/processed/):
    python inference.py --build

    # Single query:
    python inference.py --query "Which standard covers 33 grade OPC cement?"

    # Batch from JSON file:
    python inference.py --input data/processed/public_test_set.json

    # Batch + write results JSON:
    python inference.py --input data/processed/public_test_set.json \
        --output data/processed/retrieval_results.json
"""

from __future__ import annotations

import argparse
import json
import math
import re
import time
from pathlib import Path
from typing import Any

import faiss
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# All data files are resolved relative to the directory containing this file.
_ROOT = Path(__file__).resolve().parent
_CHUNKS_PATH = _ROOT / "data/processed/standards_chunks.json"
_STANDARDS_PATH = _ROOT / "data/processed/standards.json"
_EMBED_CACHE = _ROOT / "data/processed/embeddings.npy"
_INDEX_CACHE = _ROOT / "data/processed/faiss.index"

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers model name
_TOP_K_DENSE = 10   # chunks requested from the FAISS search
_TOP_K_SPARSE = 10  # highest-scoring BM25 chunks kept as candidates
_TOP_N_FINAL = 5    # default number of standards returned per query
_SHORT_CHUNK_THRESHOLD = 40  # body words below this get a penalty
_SHORT_CHUNK_PENALTY = 0.15

# Stand-alone grade numbers (33 / 43 / 53) compared between query and title.
_GRADE_RE = re.compile(r"\b(33|43|53)\b")
# Base number of a multi-part ID such as "IS 12269 (Part 1)": captures the
# digits that precede the opening parenthesis.
_PART_BASE_RE = re.compile(r"IS\s+(\d+)\s*\(", re.IGNORECASE)
# Words ignored when measuring how much of a title the query covers.
_STOP_WORDS = {"and", "or", "for", "the", "of", "in", "a", "an", "to"}


# ---------------------------------------------------------------------------
# Text helpers
# ---------------------------------------------------------------------------
def _body_text(chunk_text: str) -> str:
    """Return the chunk body with its leading header line removed.

    Chunk text is expected to start with an 'IS XXXX: YYYY Title [Section]'
    prefix line followed by the body. Surrounding whitespace is stripped.
    When the stripped text contains no newline there is no separate body, so
    the whole stripped text is returned unchanged.
    """
    header, newline, body = chunk_text.strip().partition("\n")
    return body.strip() if newline else header


def _tokenize(text: str) -> list[str]:
    """Lowercase *text* and return its runs of ASCII letters/digits (BM25 tokens)."""
    return re.findall(r"[a-z0-9]+", text.lower())


def _bm25_doc(chunk: dict) -> list[str]:
    """Build the weighted BM25 token list for one chunk.

    The title comes from ``chunk["full_title"]`` (injected by ``load_or_build``
    from standards.json) and falls back to ``chunk["title"]`` when that is
    missing or empty. Fields are weighted by repetition: title x4, keywords
    x3, section x2, body text x1, so an exact title match dominates over
    body-text noise.

    Args:
        chunk: Chunk record; reads the optional keys "full_title", "title",
            "keywords", "section" and "text".

    Returns:
        The concatenated token list, in title / keywords / section / body order.
    """
    # full_title is injected by load_or_build; fall back to chunk title
    title = chunk.get("full_title") or chunk.get("title", "")
    title_tokens = _tokenize(title)
    kw_tokens = _tokenize(" ".join(chunk.get("keywords", [])))
    section_tokens = _tokenize(chunk.get("section", ""))
    # Only the body is indexed here: the header line repeats the title.
    text_tokens = _tokenize(_body_text(chunk.get("text", "")))
    return title_tokens * 4 + kw_tokens * 3 + section_tokens * 2 + text_tokens


def _norm_std_id(sid: str) -> str:
    """Canonicalise a standard ID: collapse whitespace runs, trim, upper-case."""
    return re.sub(r"\s+", " ", sid).strip().upper()


# ---------------------------------------------------------------------------
# Index builder
# ---------------------------------------------------------------------------
class RetrievalIndex:
    """Corpus container that owns the dense (FAISS) and sparse (BM25) indexes.

    ``__init__`` stores the corpus and derives per-standard lookup tables;
    ``build`` must be called afterwards to create (or load from cache) the
    FAISS index and to build the BM25 index.
    """

    def __init__(
        self,
        chunks: list[dict],
        standards: list[dict],
        model: SentenceTransformer,
    ) -> None:
        """Store the corpus and precompute per-standard lookups.

        Args:
            chunks: Chunk records. Each needs "text" and "standard_id";
                "title", "full_title", "keywords" and "section" are optional.
            standards: Standard records. Each needs "standard_id";
                "keywords" is optional.
            model: Encoder used for chunk and query embeddings.
        """
        self.chunks = chunks
        self.standards = standards
        self.model = model

        # Build lookup: normalised standard_id -> standard record
        self.std_lookup: dict[str, dict] = {
            _norm_std_id(s["standard_id"]): s for s in standards
        }

        # Build per-standard keyword token set for boosting
        self.std_keywords: dict[str, set[str]] = {
            _norm_std_id(s["standard_id"]): set(_tokenize(" ".join(s.get("keywords", []))))
            for s in standards
        }

        # Dense index (FAISS); populated by build()
        self.faiss_index: faiss.IndexFlatIP | None = None
        self.embeddings: np.ndarray | None = None

        # Sparse index (BM25); populated by build()
        self.bm25: BM25Okapi | None = None
        self._bm25_docs: list[list[str]] = []

    # ------------------------------------------------------------------
    def build(self, use_cache: bool = True) -> None:
        """Build both indexes.

        Args:
            use_cache: When True, reuse the cached embeddings / FAISS index on
                disk if they exist and match the current chunk list.

        Raises:
            ValueError: If there are no chunks to index.
        """
        if not self.chunks:
            raise ValueError("Cannot build a retrieval index from an empty chunk list.")
        self._build_dense(use_cache)
        self._build_sparse()

    def _load_dense_cache(self) -> bool:
        """Try to load the cached embeddings and FAISS index; return success.

        The cache is rejected when its vector count differs from the number
        of chunks: FAISS result ids are used directly as positions in
        ``self.chunks``, so a stale cache would map hits to the wrong chunks.
        """
        if not (_EMBED_CACHE.exists() and _INDEX_CACHE.exists()):
            return False

        print("Loading cached embeddings and FAISS index…")
        embeddings = np.load(str(_EMBED_CACHE))
        faiss_index = faiss.read_index(str(_INDEX_CACHE))

        expected = len(self.chunks)
        if faiss_index.ntotal != expected or embeddings.shape[0] != expected:
            print(
                f"Cached index holds {faiss_index.ntotal} vectors but there are "
                f"{expected} chunks; rebuilding."
            )
            return False

        self.embeddings = embeddings
        self.faiss_index = faiss_index
        return True

    def _build_dense(self, use_cache: bool) -> None:
        """Load the dense index from cache, or encode all chunks and cache the result."""
        if use_cache and self._load_dense_cache():
            return

        print(f"Encoding {len(self.chunks)} chunks with {_MODEL_NAME}…")
        texts = [c["text"] for c in self.chunks]
        emb = self.model.encode(
            texts,
            batch_size=64,
            show_progress_bar=True,
            normalize_embeddings=True,  # cosine via inner product
        )
        self.embeddings = emb.astype(np.float32)

        dim = self.embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(dim)
        self.faiss_index.add(self.embeddings)

        # Both cache files live in the same directory; make sure it exists.
        _EMBED_CACHE.parent.mkdir(parents=True, exist_ok=True)
        np.save(str(_EMBED_CACHE), self.embeddings)
        faiss.write_index(self.faiss_index, str(_INDEX_CACHE))
        print(f"FAISS index built: {self.faiss_index.ntotal} vectors, dim={dim}")

    def _build_sparse(self) -> None:
        """Tokenise every chunk with ``_bm25_doc`` and build the BM25 index."""
        print("Building BM25 index…")
        self._bm25_docs = [_bm25_doc(c) for c in self.chunks]
        self.bm25 = BM25Okapi(self._bm25_docs)
        print("BM25 index built.")


# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------
class Retriever:
    """Hybrid dense + sparse retriever over a built ``RetrievalIndex``."""

    # A full standard ID written in the query, e.g. "IS 269: 2015" or
    # "IS 12269 (Part 1): 2013". The middle part is matched lazily so that two
    # IDs in one query are found separately instead of being merged into one
    # span that ends at the last ": YYYY".
    _QUERY_ID_RE = re.compile(r"\bIS\s*\d+[\s:()A-Za-z\d]*?:\s*\d{4}", re.IGNORECASE)

    def __init__(self, index: RetrievalIndex) -> None:
        """Keep a reference to the index that ``retrieve`` searches."""
        self.idx = index

    def retrieve(self, query: str, top_n: int = _TOP_N_FINAL) -> tuple[list[dict], float]:
        """Return the best-matching standards for *query* and the time taken.

        Pipeline: dense and sparse candidate search, weighted score fusion,
        per-chunk re-ranking bonuses, grouping by standard (best chunk wins),
        disambiguation between same-titled Parts of one standard, then
        sorting and truncation to *top_n*.

        Args:
            query: Free-text user query.
            top_n: Maximum number of standards to return.

        Returns:
            ``(results, latency)``. ``results`` is a list of dicts with keys
            "standard_id", "title", "category", "score" and
            "matched_section", best first; ``latency`` is the elapsed time in
            seconds.

        Raises:
            RuntimeError: If the index has not been built yet.
        """
        t0 = time.perf_counter()

        if self.idx.faiss_index is None or self.idx.bm25 is None:
            raise RuntimeError("RetrievalIndex.build() must be called before retrieve().")

        query_tokens = _tokenize(query)
        query_words = set(query_tokens)

        dense_norm = self._dense_scores(query)
        sparse_norm = self._sparse_scores(query_tokens)

        # --- Merge candidates (weighted fusion) ---
        chunk_scores: dict[int, float] = {}
        for cid in set(dense_norm) | set(sparse_norm):
            d = dense_norm.get(cid, 0.0)
            s = sparse_norm.get(cid, 0.0)
            chunk_scores[cid] = 0.6 * d + 0.4 * s

        # --- Re-ranking ---
        # Everything that depends only on the query is computed once here.
        grade_in_query = set(_GRADE_RE.findall(query))
        sig_query = query_words - _STOP_WORDS
        ids_in_query = {_norm_std_id(m.group()) for m in self._QUERY_ID_RE.finditer(query)}

        for cid, base in list(chunk_scores.items()):
            bonus = self._rerank_bonus(
                self.idx.chunks[cid], query_words, sig_query, grade_in_query, ids_in_query
            )
            chunk_scores[cid] = base + bonus

        # --- Group by standard_id, keep best chunk score ---
        std_best: dict[str, float] = {}
        std_chunk_repr: dict[str, dict] = {}
        for cid, score in chunk_scores.items():
            chunk = self.idx.chunks[cid]
            sid = chunk["standard_id"]
            if sid not in std_best or score > std_best[sid]:
                std_best[sid] = score
                std_chunk_repr[sid] = chunk

        self._disambiguate_parts(std_best, std_chunk_repr, query_words)

        # --- Sort and take top N ---
        ranked = sorted(std_best.items(), key=lambda x: x[1], reverse=True)[:top_n]

        results = []
        for sid, score in ranked:
            std_rec = self.idx.std_lookup.get(_norm_std_id(sid), {})
            repr_chunk = std_chunk_repr[sid]
            results.append({
                "standard_id": sid,
                "title": std_rec.get("title", repr_chunk.get("title", "")),
                "category": std_rec.get("category", repr_chunk.get("category", "")),
                "score": round(float(score), 4),
                "matched_section": repr_chunk.get("section", ""),
            })

        latency = time.perf_counter() - t0
        return results, latency

    def _dense_scores(self, query: str) -> dict[int, float]:
        """Return min-max-normalised FAISS scores keyed by chunk index."""
        q_emb = self.idx.model.encode(
            [query], normalize_embeddings=True
        ).astype(np.float32)
        dense_scores, dense_ids = self.idx.faiss_index.search(q_emb, _TOP_K_DENSE)
        dense_scores = dense_scores[0]
        dense_ids = dense_ids[0]

        # FAISS pads the result with label -1 when the index holds fewer than
        # k vectors; those slots are not chunks and must not be scored.
        valid = dense_ids >= 0
        dense_scores = dense_scores[valid]
        dense_ids = dense_ids[valid]
        if dense_ids.size == 0:
            return {}

        # Min-max normalise within the returned hits: best -> 1, worst -> 0.
        d_min, d_max = dense_scores.min(), dense_scores.max()
        d_range = d_max - d_min if d_max > d_min else 1.0
        return {int(i): (s - d_min) / d_range for i, s in zip(dense_ids, dense_scores)}

    def _sparse_scores(self, query_tokens: list[str]) -> dict[int, float]:
        """Return the top BM25 scores, divided by the best score, keyed by chunk index."""
        bm25_raw = self.idx.bm25.get_scores(query_tokens)
        top_sparse_ids = np.argsort(bm25_raw)[::-1][:_TOP_K_SPARSE]
        top_sparse_scores = bm25_raw[top_sparse_ids]

        best = top_sparse_scores.max()
        s_max = best if best > 0 else 1.0  # avoid dividing by zero / a negative
        return {int(i): s / s_max for i, s in zip(top_sparse_ids, top_sparse_scores)}

    def _rerank_bonus(
        self,
        chunk: dict,
        query_words: set[str],
        sig_query: set[str],
        grade_in_query: set[str],
        ids_in_query: set[str],
    ) -> float:
        """Return the additive re-ranking adjustment for one candidate chunk."""
        sid_norm = _norm_std_id(chunk["standard_id"])
        bonus = 0.0

        # Use the authoritative full title for all title-based signals
        full_title = chunk.get("full_title") or chunk.get("title", "")
        full_title_tokens = set(_tokenize(full_title))

        # Boost: keyword overlap with query (capped at 4 keywords)
        kw_set = self.idx.std_keywords.get(sid_norm, set())
        kw_overlap = len(kw_set & query_words)
        if kw_overlap:
            bonus += 0.05 * min(kw_overlap, 4)

        # Boost: title word overlap with query (capped at 5 words)
        title_overlap = len(full_title_tokens & query_words)
        if title_overlap:
            bonus += 0.05 * min(title_overlap, 5)

        # Strong boost: majority of title words present in query — likely
        # the most on-point standard even if its chunk body is polluted.
        sig_title = full_title_tokens - _STOP_WORDS
        if sig_title and len(sig_title & sig_query) / len(sig_title) >= 0.6:
            bonus += 0.25

        # Boost: exact IS ID in query (user specifies a standard directly)
        if sid_norm in ids_in_query:
            bonus += 0.20

        # Penalize very short chunks
        body_wc = len(_body_text(chunk.get("text", "")).split())
        if body_wc < _SHORT_CHUNK_THRESHOLD:
            bonus -= _SHORT_CHUNK_PENALTY

        # Grade match: reward titles naming exactly the queried grade(s),
        # penalise titles that only name other grades.
        if grade_in_query:
            grade_in_title = set(_GRADE_RE.findall(full_title))
            if grade_in_title and grade_in_title == grade_in_query:
                bonus += 0.35
            elif grade_in_title and grade_in_title.isdisjoint(grade_in_query):
                bonus -= 0.40

        return bonus

    def _disambiguate_parts(
        self,
        std_best: dict[str, float],
        std_chunk_repr: dict[str, dict],
        query_words: set[str],
    ) -> None:
        """Adjust scores in place between same-titled Parts of one standard.

        Sibling Parts (IDs sharing the base number before "(") that also
        share a title cannot be separated by title signals, so query tokens
        that appear in the keywords of exactly one sibling are used instead.
        """
        base_to_sids: dict[str, list[str]] = {}
        for sid in std_best:
            if "(" not in sid:
                continue
            m = _PART_BASE_RE.match(sid)
            if m:
                base_to_sids.setdefault(m.group(1), []).append(sid)

        for sids in base_to_sids.values():
            if len(sids) < 2:
                continue
            # Only act when all siblings share the same title
            titles = {
                _norm_std_id(
                    std_chunk_repr[s].get("full_title") or std_chunk_repr[s].get("title", "")
                )
                for s in sids
            }
            if len(titles) > 1:
                continue

            sib_kws = {s: self.idx.std_keywords.get(_norm_std_id(s), set()) for s in sids}
            disc_weight: dict[str, float] = {s: 0.0 for s in sids}
            for token in query_words:
                owners = [s for s in sids if token in sib_kws[s]]
                if len(owners) == 1:
                    # IDF floors at 0 for very common terms; +1 keeps weight positive
                    idf_weight = max(self.idx.bm25.idf.get(token, 0.0), 0.0) + 1.0
                    disc_weight[owners[0]] += idf_weight

            max_w = max(disc_weight.values())
            if max_w == 0.0:
                continue  # no discriminating token: leave the scores alone
            for s, w in disc_weight.items():
                # Reward own evidence; penalise the gap to the best sibling.
                std_best[s] += w * 0.22 - (max_w - w) * 0.14


# ---------------------------------------------------------------------------
# Index load/build helper
# ---------------------------------------------------------------------------
def load_or_build(force_rebuild: bool = False) -> tuple[RetrievalIndex, Retriever]:
    """Load the corpus, build (or load cached) indexes and return a retriever.

    Args:
        force_rebuild: When True, ignore the on-disk embedding / FAISS cache
            and re-encode every chunk.

    Returns:
        ``(index, retriever)`` where ``retriever`` searches ``index``.
    """
    with open(_CHUNKS_PATH, encoding="utf-8") as f:
        chunks = json.load(f)
    with open(_STANDARDS_PATH, encoding="utf-8") as f:
        standards = json.load(f)

    # Attach full title + keywords from standards.json to each chunk.
    # full_title ensures the BM25 document uses the authoritative (untruncated)
    # title from the structured record, not whatever ended up in the chunk prefix.
    # IDs are normalised on both sides, as RetrievalIndex does for its own
    # lookups, so spacing/case differences between the two files still join.
    std_map = {_norm_std_id(s["standard_id"]): s for s in standards}
    for c in chunks:
        rec = std_map.get(_norm_std_id(c["standard_id"]), {})
        c["full_title"] = rec.get("title", c.get("title", ""))
        c["keywords"] = rec.get("keywords", [])

    print(f"Loaded {len(chunks)} chunks, {len(standards)} standards.")
    model = SentenceTransformer(_MODEL_NAME)
    index = RetrievalIndex(chunks, standards, model)
    index.build(use_cache=not force_rebuild)
    return index, Retriever(index)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _format_result(
    query_id: str,
    query: str,
    results: list[dict],
    latency: float,
    expected_standards: list[str] | None = None,
) -> dict:
    """Package one query's retrieval output as a JSON-serialisable dict.

    Args:
        query_id: Identifier echoed back under "id".
        query: The query text.
        results: Result dicts; each must contain "standard_id".
        latency: Retrieval time in seconds (rounded to 4 decimals in the output).
        expected_standards: Ground-truth IDs; the "expected_standards" key is
            only added when this is not None.

    Returns:
        Dict with "id", "query", "retrieved_standards", "details",
        "latency_seconds" and, optionally, "expected_standards".
    """
    out: dict[str, Any] = {
        "id": query_id,
        "query": query,
        "retrieved_standards": [r["standard_id"] for r in results],
        "details": results,
        "latency_seconds": round(latency, 4),
    }
    if expected_standards is not None:
        out["expected_standards"] = expected_standards
    return out


def main() -> None:
    """CLI entry point.

    Modes, checked in this order:
      * ``--query``: run one query and print the ranked standards.
      * ``--input``: run every query in a JSON list, print per-query hits,
        average/max latency and Hit@5.
      * neither: run a built-in demo query.

    ``--output`` writes the results as JSON in the first two modes;
    ``--build`` forces the dense index to be re-encoded.
    """
    parser = argparse.ArgumentParser(description="BIS SP-21 Hybrid Retrieval")
    parser.add_argument("--build", action="store_true", help="Force rebuild of FAISS index")
    parser.add_argument("--query", type=str, help="Single query string")
    parser.add_argument("--input", type=str, help="JSON file with list of {id, query} objects")
    parser.add_argument("--output", type=str, help="Write JSON results to this file")
    args = parser.parse_args()

    _, retriever = load_or_build(force_rebuild=args.build)

    if args.query:
        results, latency = retriever.retrieve(args.query)
        out = _format_result("Q0", args.query, results, latency)
        print("\n" + "=" * 60)
        print(f"Query : {args.query}")
        print(f"Latency: {latency:.3f}s")
        print("\nTop results:")
        for i, r in enumerate(results, 1):
            print(f" {i}. {r['standard_id']} — {r['title']}")
            print(f" Category: {r['category']} | Section: {r['matched_section']} | Score: {r['score']}")
        if args.output:
            Path(args.output).write_text(json.dumps([out], indent=2, ensure_ascii=False), encoding="utf-8")
        return

    if args.input:
        with open(args.input, encoding="utf-8") as f:
            queries = json.load(f)

        all_results = []
        latencies = []
        hits = 0
        for q in queries:
            qid = q.get("id", "?")
            qtext = q.get("query", "")
            results, latency = retriever.retrieve(qtext)
            latencies.append(latency)
            # A missing or null "expected_standards" both mean "no ground truth".
            expected = q.get("expected_standards") or []
            out = _format_result(qid, qtext, results, latency, expected_standards=expected or None)
            all_results.append(out)
            hit = any(r["standard_id"] in expected for r in results)
            hits += hit
            print(f"[{qid}] latency={latency:.3f}s hit={hit} retrieved={[r['standard_id'] for r in results]}")

        if queries:
            print(f"\nAvg latency: {sum(latencies)/len(latencies):.3f}s | Max: {max(latencies):.3f}s")
            # Simple Hit@5 eval: a query counts as a hit when any expected
            # standard is among its retrieved standards.
            print(f"Hit@5: {hits}/{len(queries)} = {hits/len(queries):.1%}")
        else:
            # Averages and ratios are undefined for an empty batch.
            print(f"No queries found in {args.input}; nothing to evaluate.")

        if args.output:
            Path(args.output).write_text(
                json.dumps(all_results, indent=2, ensure_ascii=False), encoding="utf-8"
            )
            print(f"Results written to {args.output}")
        return

    # Default: demo with one example query
    demo_query = (
        "Which standard specifies chemical and physical requirements "
        "for 33 grade Ordinary Portland Cement?"
    )
    results, latency = retriever.retrieve(demo_query)

    print("\n" + "=" * 60)
    print(f"Demo query : {demo_query}")
    print(f"Latency : {latency:.3f}s")
    print("\nTop-5 retrieved standards:")
    for i, r in enumerate(results, 1):
        print(f" {i}. {r['standard_id']} — {r['title']}")
        print(f" Category : {r['category']}")
        print(f" Section : {r['matched_section']}")
        print(f" Score : {r['score']}")
    print("=" * 60)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()