from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer

from .config import RAGBENCH_DATASET, EMBEDDING_MODEL


class SubsetVectorDB:
    """
    Simple FAISS-based vector database for a single RAGBench subset + split.

    This class is intentionally lightweight and file-based:

    - Each (subset, split) pair gets its own folder under ``vector_store/``.
    - We build a single FAISS index over all documents' concatenated text.
    - We also persist a small ``meta.json`` mapping index -> (row_index, doc_index).

    At evaluation time we can:

    - Lazily build the index once (or load it if it already exists).
    - Retrieve the top-k most similar documents for a given question.
    - Optionally restrict results to a particular example row.
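
    Typical usage (a sketch; ``"hotpotqa"`` is one of RAGBench's subset
    names, used here purely as an example)::

        db = SubsetVectorDB("hotpotqa", split="test")
        db.build_or_load()
        hits = db.search("Who wrote the report?", k=5)
        # hits -> [(row_index, doc_index, score), ...]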
"""
def __init__(
self,
subset: str,
split: str = "test",
root_dir: Optional[Path] = None,
) -> None:
self.subset = subset
self.split = split
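
        # On-disk layout:
        #   <project_root>/vector_store/<subset>/<split>/{index.faiss, meta.json}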
        project_root = Path(__file__).resolve().parents[1]
        self.root_dir = (root_dir or project_root / "vector_store").resolve()
        self.index_dir = self.root_dir / subset / split
        self.index_dir.mkdir(parents=True, exist_ok=True)
        self.index_path = self.index_dir / "index.faiss"
        self.meta_path = self.index_dir / "meta.json"

        # Will be populated by ``build_or_load``.
        self.embedder: Optional[SentenceTransformer] = None
        self.index: Optional[faiss.Index] = None
        self.meta: List[Dict[str, Any]] = []

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _load_embedder(self) -> SentenceTransformer:
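        # Lazily construct the model so that importing this module (or just
        # loading an existing index) does not pay the model startup cost.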
        if self.embedder is None:
            self.embedder = SentenceTransformer(EMBEDDING_MODEL)
        return self.embedder

    def _load_index_files(self) -> bool:
        """
        Try to load index + meta files from disk.

        Returns True if successful, False if anything is missing.
        """
        if not self.index_path.exists() or not self.meta_path.exists():
            return False
        self.index = faiss.read_index(str(self.index_path))
        with self.meta_path.open("r", encoding="utf-8") as f:
            self.meta = json.load(f)
        return True

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def build_or_load(self, ds: Optional[Dataset] = None) -> None:
        """
        Ensure the FAISS index exists for (subset, split).

        If the index files are already on disk we simply load them.
        Otherwise we:

        - iterate over the dataset
        - concatenate each document's sentences into a single string
        - build dense embeddings with SentenceTransformers
        - create a cosine-similarity FAISS index and persist it
        """
        if self._load_index_files():
            return

        if ds is None:
            ds = load_dataset(RAGBENCH_DATASET, self.subset, split=self.split)
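
        # texts[i] and meta[i] stay parallel, so FAISS row i can be mapped
        # back to (row_index, doc_index) in the dataset via meta[i].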
        texts: List[str] = []
        meta: List[Dict[str, Any]] = []
        for row_idx, row in enumerate(ds):
            # ``documents_sentences`` is a list of docs;
            # each doc is a list of (sentence_key, sentence_text) pairs.
            for doc_idx, doc in enumerate(row["documents_sentences"]):
                doc_text = " ".join(sentence_text for _, sentence_text in doc)
                texts.append(doc_text)
                meta.append({"row_index": int(row_idx), "doc_index": int(doc_idx)})

        if not texts:
            raise ValueError(
                f"No documents found while building vector DB for subset={self.subset}, split={self.split}"
            )

        embedder = self._load_embedder()
        embeddings = embedder.encode(
            texts,
            batch_size=32,
            show_progress_bar=True,
            convert_to_numpy=True,
        )

        # FAISS expects float32.
        embeddings = np.asarray(embeddings, dtype="float32")
        # Use cosine similarity via inner product on L2-normalized vectors.
        faiss.normalize_L2(embeddings)

        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)

        # Persist to disk so subsequent runs are cheap.
        faiss.write_index(index, str(self.index_path))
        with self.meta_path.open("w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)

        self.index = index
        self.meta = meta

    def search(
        self,
        query: str,
        k: int = 10,
        restrict_row_index: Optional[int] = None,
    ) -> List[Tuple[int, int, float]]:
        """
        Search the vector DB for the top-k documents relevant to ``query``.

        Returns a list of (row_index, doc_index, score) tuples.

        If ``restrict_row_index`` is provided, we will over-sample and then
        filter to only documents that belong to that example row.
        """
        if self.index is None or not self.meta:
            if not self._load_index_files():
                raise RuntimeError(
                    "Vector DB has not been built yet. Call build_or_load() first."
                )
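
        # Embed the query with the same model and L2 normalization as the
        # corpus, so inner-product scores are cosine similarities.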
        embedder = self._load_embedder()
        q_emb = embedder.encode([query], convert_to_numpy=True)
        q_emb = np.asarray(q_emb, dtype="float32")
        faiss.normalize_L2(q_emb)

        # For restricted searches we over-sample so that filtering still leaves
        # enough candidates; for unrestricted searches we just use k.
        search_k = k * 10 if restrict_row_index is not None else k
        search_k = max(search_k, k)
        scores, indices = self.index.search(q_emb, search_k)
        scores = scores[0]
        indices = indices[0]

        results: List[Tuple[int, int, float]] = []
        for idx, score in zip(indices, scores):
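            # FAISS pads its results with -1 when the index holds fewer than
            # search_k vectors, so skip any out-of-range ids.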
            if idx < 0 or idx >= len(self.meta):
                continue
            entry = self.meta[int(idx)]
            row_index = entry["row_index"]
            doc_index = entry["doc_index"]
            if restrict_row_index is not None and row_index != restrict_row_index:
                continue
            results.append((row_index, doc_index, float(score)))
            if len(results) >= k:
                break
        return results
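

# Minimal usage sketch, not part of the class's public API. Assumptions:
# "hotpotqa" is one of RAGBench's subset names, chosen here purely for
# illustration, and the first run downloads the dataset and embedding model.
# Because of the relative import above, run this as a module, e.g.
# ``python -m <your_package>.vector_db`` (package and module names are
# hypothetical placeholders for wherever this file lives).
if __name__ == "__main__":
    db = SubsetVectorDB("hotpotqa", split="test")
    db.build_or_load()  # builds on the first run, loads from disk afterwards
    for row_index, doc_index, score in db.search("What is the capital of France?", k=3):
        print(f"row={row_index} doc={doc_index} score={score:.3f}")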