from typing import List, Optional, Tuple

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from .config import EMBEDDING_MODEL


class ExampleRetriever:
    """Ranks the per-example documents in RAGBench by similarity to the question."""  # noqa: E501

    def __init__(self, model_name: Optional[str] = None):
        """Load the sentence-embedding model used for ranking.

        Args:
            model_name: Model name/path passed to ``SentenceTransformer``.
                When ``None`` (the default), ``EMBEDDING_MODEL`` from the
                package config is used, matching the previous behavior.
        """
        # Resolve the default at call time so the configured constant is
        # read when the retriever is built, as before.
        self.embedder = SentenceTransformer(
            EMBEDDING_MODEL if model_name is None else model_name
        )

    def _encode(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` with the loaded model, without a progress bar."""
        return self.embedder.encode(texts, show_progress_bar=False)

    def rank_docs(
        self,
        question: str,
        documents_sentences: List[List[Tuple[str, str]]],
        k: int = 4,
    ) -> List[int]:
        """Return indices of the ``k`` documents most similar to ``question``.

        Each document is a list of 2-tuples. Only the second element (the
        sentence text) is used; the sentences are joined with single spaces
        to form the document text. The first element is ignored (presumably
        a sentence key/ID -- confirm against the dataset loader).

        Args:
            question: Query text to compare documents against.
            documents_sentences: One entry per document, each a list of
                ``(_, sentence)`` pairs.
            k: Maximum number of indices to return.

        Returns:
            Up to ``k`` indices into ``documents_sentences``, ordered by
            descending cosine similarity between the question embedding and
            each document embedding. Tied scores keep the original document
            order. Returns ``[]`` when there are no documents or ``k <= 0``.
        """
        # Nothing to rank: avoid sending an empty batch through the encoder
        # and cosine_similarity (which rejects a 0-length embedding array).
        # A negative k would otherwise slice "all but the last |k|" items.
        if not documents_sentences or k <= 0:
            return []

        doc_texts = [
            " ".join(sent for _, sent in doc) for doc in documents_sentences
        ]
        q_emb = self._encode([question])
        d_emb = self._encode(doc_texts)
        sims = cosine_similarity(q_emb, d_emb)[0]

        # Stable sort on negated scores gives descending similarity with ties
        # broken by original index; argsort(...)[::-1] used the non-stable
        # default sort and reversed the order of tied documents.
        topk_idx = np.argsort(-sims, kind="stable")[:k]
        return topk_idx.tolist()