from typing import Dict, Any, List, Tuple, Optional

from datasets import load_dataset

from .config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS
from .generator import RAGGenerator
from .judge import RAGJudge
from .metrics import trace_from_attributes, compute_rmse_auc
from .retriever import ExampleRetriever
from .vector_db import SubsetVectorDB


class RagBenchExperiment:
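    """Run the RAGBench evaluation loop over a subset or a whole domain.

    For each example, documents are retrieved (vector DB first, hybrid
    retriever as fallback), an answer is generated, a judge annotates it,
    and the predicted TRACe scores (relevance, utilization, completeness,
    adherence) are compared against the dataset's ground-truth labels.
    """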
    def __init__(
        self,
        k: int = 3,
        max_examples: Optional[int] = None,
        split: str = "test",
    ):
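        """
        Args:
            k: number of retrieved documents per question.
            max_examples: optional cap on the number of examples evaluated
                per subset (``None`` evaluates the full split).
            split: dataset split to load (default ``"test"``).
        """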
        self.k = k
        self.max_examples = max_examples
        self.split = split
        self.retriever = ExampleRetriever()
        self.generator = RAGGenerator()
        self.judge = RAGJudge()

    def _load_subset(self, subset: str):
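        """Load one RAGBench subset for the configured split."""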
        ds = load_dataset(
            RAGBENCH_DATASET, subset, split=self.split
        )
        return ds

    def _to_docs_sentences(self, row) -> List[List[Tuple[str, str]]]:
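        """Convert a row's ``documents_sentences`` field into a list of
        documents, each a list of ``(sentence_key, sentence)`` pairs."""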
        docs: List[List[Tuple[str, str]]] = []
        for doc in row["documents_sentences"]:
            docs.append([(key, sent) for key, sent in doc])
        return docs

    def run_subset(self, subset: str) -> Dict[str, Any]:
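        """Evaluate a single subset and return its aggregate metrics."""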
        ds = self._load_subset(subset)
        # Build or load the FAISS-based vector database for this subset.
        # This writes index files under ``vector_store/<subset>/<split>/``
        # the first time it is called and reuses them thereafter.
        vector_db = SubsetVectorDB(subset=subset, split=self.split)
        vector_db.build_or_load(ds)
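
        # Ground-truth (dataset) and predicted (judge) scores for the four
        # TRACe dimensions.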
        y_true_rel: List[float] = []
        y_pred_rel: List[float] = []
        y_true_util: List[float] = []
        y_pred_util: List[float] = []
        y_true_comp: List[float] = []
        y_pred_comp: List[float] = []
        y_true_adh: List[int] = []
        y_pred_adh: List[float] = []

        for i, row in enumerate(ds):
            if self.max_examples is not None and i >= self.max_examples:
                break
            question = row["question"]
            docs_sentences_full = self._to_docs_sentences(row)
            # Try the vector DB first, restricting retrieval to documents
            # that belong to this particular example row (same ``i``).
            hits = vector_db.search(
                question,
                k=self.k,
                restrict_row_index=i,
            )
            if hits:
                doc_indices = [doc_idx for _, doc_idx, _ in hits]
            else:
                # Fall back to the original hybrid (BM25 + dense) retriever
                # operating only over this example's documents.
                doc_indices = self.retriever.rank_docs(
                    question, docs_sentences_full, k=self.k
                )
            selected_docs = [docs_sentences_full[j] for j in doc_indices]
            answer = self.generator.generate(question, selected_docs)
            attrs = self.judge.annotate(question, answer, selected_docs)
            pred = trace_from_attributes(attrs, selected_docs)

            y_true_rel.append(float(row["relevance_score"]))
            y_true_util.append(float(row["utilization_score"]))
            y_true_comp.append(float(row["completeness_score"]))
            y_true_adh.append(int(row["adherence_score"]))
            y_pred_rel.append(pred["relevance"])
            y_pred_util.append(pred["utilization"])
            y_pred_comp.append(pred["completeness"])
            y_pred_adh.append(pred["adherence"])
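
        # Aggregate per-example scores into subset-level RMSE/AUC metrics.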
        metrics = compute_rmse_auc(
            y_true_rel,
            y_pred_rel,
            y_true_util,
            y_pred_util,
            y_true_comp,
            y_pred_comp,
            y_true_adh,
            y_pred_adh,
        )
        return {
            "subset": subset,
            "n_examples": len(y_true_rel),
            **metrics,
        }

    def run_domain(self, domain: str) -> Dict[str, Any]:
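        """Run every subset mapped to ``domain`` and collect the results."""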
        subsets = DOMAIN_TO_SUBSETS[domain]
        results = [self.run_subset(subset) for subset in subsets]
        return {
            "domain": domain,
            "subsets": results,
        }
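

if __name__ == "__main__":
    # Minimal usage sketch. Because this module uses relative imports it
    # must be run as part of its package (``python -m <package>.<module>``).
    # The subset name below is illustrative; valid names are the values in
    # ``DOMAIN_TO_SUBSETS`` (see .config).
    experiment = RagBenchExperiment(k=3, max_examples=5)
    print(experiment.run_subset("hotpotqa"))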