Spaces:
Running
Running
| from typing import Dict, Any, List, Tuple, Optional | |
| from datasets import load_dataset | |
| from .config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS | |
| from .retriever import ExampleRetriever | |
| from .generator import RAGGenerator | |
| from .judge import RAGJudge | |
| from .metrics import trace_from_attributes, compute_rmse_auc | |
| from .vector_db import SubsetVectorDB | |
class RagBenchExperiment:
    """End-to-end RAGBench evaluation harness.

    For each example in a subset the pipeline (1) retrieves the top-``k``
    documents for the question, (2) generates an answer over them,
    (3) has a judge annotate the answer, and (4) derives TRACe-style
    predictions that are scored (RMSE / AUC) against the dataset's
    ground-truth relevance, utilization, completeness and adherence labels.
    """

    def __init__(
        self,
        k: int = 3,
        max_examples: Optional[int] = None,
        split: str = "test",
    ):
        # Number of documents fed to the generator per question.
        self.k = k
        # Optional cap on evaluated rows; ``None`` means the full split.
        self.max_examples = max_examples
        self.split = split
        # Pipeline components: hybrid fallback retriever, answer
        # generator, and the LLM judge that annotates answers.
        self.retriever = ExampleRetriever()
        self.generator = RAGGenerator()
        self.judge = RAGJudge()

    def _load_subset(self, subset: str):
        """Load one RAGBench subset for the configured split."""
        return load_dataset(RAGBENCH_DATASET, subset, split=self.split)

    def _to_docs_sentences(self, row) -> List[List[Tuple[str, str]]]:
        """Convert a dataset row's ``documents_sentences`` field into a
        list of documents, each a list of ``(sentence_key, sentence)``
        tuples."""
        return [
            [(sent_key, sent_text) for sent_key, sent_text in document]
            for document in row["documents_sentences"]
        ]

    def run_subset(self, subset: str) -> Dict[str, Any]:
        """Evaluate the full pipeline on a single subset.

        Returns a dict with the subset name, the number of examples
        processed, and the metrics produced by ``compute_rmse_auc``.
        """
        dataset = self._load_subset(subset)
        # Build or load the FAISS-based vector database for this subset.
        # This writes index files under ``vector_store/<subset>/<split>/``
        # the first time it is called and reuses them thereafter.
        vector_db = SubsetVectorDB(subset=subset, split=self.split)
        vector_db.build_or_load(dataset)

        # Parallel lists of ground-truth labels and model predictions,
        # one entry per evaluated example, keyed by TRACe dimension.
        truths: Dict[str, list] = {"rel": [], "util": [], "comp": [], "adh": []}
        preds: Dict[str, list] = {"rel": [], "util": [], "comp": [], "adh": []}

        for row_idx, row in enumerate(dataset):
            if self.max_examples is not None and row_idx >= self.max_examples:
                break

            question = row["question"]
            all_docs = self._to_docs_sentences(row)

            # Try vector DB first: restrict retrieval to documents that
            # belong to this particular example row (same ``row_idx``).
            db_hits = vector_db.search(
                question,
                k=self.k,
                restrict_row_index=row_idx,
            )
            if db_hits:
                chosen_indices = [doc_pos for _score, doc_pos, _sent in db_hits]
            else:
                # Fallback to the original hybrid (BM25 + dense) retriever
                # operating only over this example's documents.
                chosen_indices = self.retriever.rank_docs(
                    question, all_docs, k=self.k
                )

            picked_docs = [all_docs[idx] for idx in chosen_indices]
            answer = self.generator.generate(question, picked_docs)
            attrs = self.judge.annotate(question, answer, picked_docs)
            pred = trace_from_attributes(attrs, picked_docs)

            truths["rel"].append(float(row["relevance_score"]))
            truths["util"].append(float(row["utilization_score"]))
            truths["comp"].append(float(row["completeness_score"]))
            truths["adh"].append(int(row["adherence_score"]))
            preds["rel"].append(pred["relevance"])
            preds["util"].append(pred["utilization"])
            preds["comp"].append(pred["completeness"])
            preds["adh"].append(pred["adherence"])

        metrics = compute_rmse_auc(
            truths["rel"],
            preds["rel"],
            truths["util"],
            preds["util"],
            truths["comp"],
            preds["comp"],
            truths["adh"],
            preds["adh"],
        )
        return {
            "subset": subset,
            "n_examples": len(truths["rel"]),
            **metrics,
        }

    def run_domain(self, domain: str) -> Dict[str, Any]:
        """Run ``run_subset`` over every subset mapped to ``domain`` and
        collect the per-subset results."""
        per_subset = [
            self.run_subset(subset_name)
            for subset_name in DOMAIN_TO_SUBSETS[domain]
        ]
        return {
            "domain": domain,
            "subsets": per_subset,
        }