from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

import faiss
import numpy as np
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer

from .config import RAGBENCH_DATASET, EMBEDDING_MODEL

class SubsetVectorDB:
    """
    Simple FAISS-based vector database for a single RAGBench subset + split.

    This class is intentionally lightweight and file-based:

    - Each (subset, split) pair gets its own folder under ``vector_store/``.
    - We build a single FAISS index over all documents' concatenated text.
    - We also persist a small ``meta.json`` mapping index -> (row_index, doc_index).

    At evaluation time we can:

    - Lazily build the index once (or load it if it already exists).
    - Retrieve the top-k most similar documents for a given question.
    - Optionally restrict results to a particular example row.
    """

    def __init__(
        self,
        subset: str,
        split: str = "test",
        root_dir: Optional[Path] = None,
    ) -> None:
        self.subset = subset
        self.split = split

        project_root = Path(__file__).resolve().parents[1]
        self.root_dir = (root_dir or project_root / "vector_store").resolve()
        self.index_dir = self.root_dir / subset / split
        self.index_dir.mkdir(parents=True, exist_ok=True)

        self.index_path = self.index_dir / "index.faiss"
        self.meta_path = self.index_dir / "meta.json"

        # Will be populated by ``build_or_load``
        self.embedder: Optional[SentenceTransformer] = None
        self.index: Optional[faiss.Index] = None
        self.meta: List[Dict[str, Any]] = []
    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _load_embedder(self) -> SentenceTransformer:
        if self.embedder is None:
            self.embedder = SentenceTransformer(EMBEDDING_MODEL)
        return self.embedder

    def _load_index_files(self) -> bool:
        """
        Try to load index + meta files from disk.

        Returns True if successful, False if anything is missing.
        """
        if not self.index_path.exists() or not self.meta_path.exists():
            return False

        self.index = faiss.read_index(str(self.index_path))
        with self.meta_path.open("r", encoding="utf-8") as f:
            self.meta = json.load(f)
        return True
    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def build_or_load(self, ds: Optional[Dataset] = None) -> None:
        """
        Ensure the FAISS index exists for (subset, split).

        If the index files are already on disk we simply load them.
        Otherwise we:

        - iterate over the dataset
        - concatenate each document's sentences into a single string
        - build a dense embedding using SentenceTransformers
        - create a cosine-similarity FAISS index and persist it
        """
        if self._load_index_files():
            return

        if ds is None:
            ds = load_dataset(RAGBENCH_DATASET, self.subset, split=self.split)

        texts: List[str] = []
        meta: List[Dict[str, Any]] = []
        for row_idx, row in enumerate(ds):
            # ``documents_sentences`` is a list of docs;
            # each doc is a list of (sentence_key, sentence_text) pairs.
            for doc_idx, doc in enumerate(row["documents_sentences"]):
                doc_text = " ".join(sentence_text for _, sentence_text in doc)
                texts.append(doc_text)
                meta.append({"row_index": int(row_idx), "doc_index": int(doc_idx)})

        if not texts:
            raise ValueError(
                f"No documents found while building vector DB for subset={self.subset}, split={self.split}"
            )

        embedder = self._load_embedder()
        embeddings = embedder.encode(
            texts,
            batch_size=32,
            show_progress_bar=True,
            convert_to_numpy=True,
        )
        # FAISS expects float32
        embeddings = np.asarray(embeddings, dtype="float32")
        # Use cosine similarity via inner product on L2-normalized vectors
        faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)

        # Persist to disk so subsequent runs are cheap
        faiss.write_index(index, str(self.index_path))
        with self.meta_path.open("w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

        self.index = index
        self.meta = meta
    def search(
        self,
        query: str,
        k: int = 10,
        restrict_row_index: Optional[int] = None,
    ) -> List[Tuple[int, int, float]]:
        """
        Search the vector DB for the top-k documents relevant to ``query``.

        Returns a list of (row_index, doc_index, score) tuples.

        If ``restrict_row_index`` is provided, we will over-sample and then
        filter to only documents that belong to that example row.
        """
        if self.index is None or not self.meta:
            if not self._load_index_files():
                raise RuntimeError(
                    "Vector DB has not been built yet. Call build_or_load() first."
                )

        embedder = self._load_embedder()
        q_emb = embedder.encode([query], convert_to_numpy=True)
        q_emb = np.asarray(q_emb, dtype="float32")
        faiss.normalize_L2(q_emb)

        # For restricted searches we over-sample so that filtering still leaves
        # enough candidates. For unrestricted searches we just use k.
        search_k = k * 10 if restrict_row_index is not None else k
        search_k = max(search_k, k)

        scores, indices = self.index.search(q_emb, search_k)
        scores = scores[0]
        indices = indices[0]

        results: List[Tuple[int, int, float]] = []
        for idx, score in zip(indices, scores):
            if idx < 0 or idx >= len(self.meta):
                continue
            meta = self.meta[int(idx)]
            row_index = meta["row_index"]
            doc_index = meta["doc_index"]
            if restrict_row_index is not None and row_index != restrict_row_index:
                continue
            results.append((row_index, doc_index, float(score)))
            if len(results) >= k:
                break

        return results
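

# ----------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows the
# build-then-search flow described in the docstrings above. The subset name
# "hotpotqa" and the query string are assumptions for demonstration only;
# any valid RAGBench subset works. Because of the relative ``.config``
# import, run this as a module (``python -m <package>.<this_module>``)
# rather than as a plain script. The first run downloads the dataset split
# and the embedding model before building and persisting the index.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    db = SubsetVectorDB(subset="hotpotqa", split="test")
    db.build_or_load()  # loads index.faiss/meta.json if present, else builds them

    # Unrestricted top-5 search over every document in the split.
    for row_idx, doc_idx, score in db.search("Who wrote the novel?", k=5):
        print(f"row={row_idx} doc={doc_idx} score={score:.3f}")

    # Restrict retrieval to the documents attached to example row 0.
    hits = db.search("Who wrote the novel?", k=3, restrict_row_index=0)
    print(hits)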