Renangi's picture
Initial commit without secrets
c8dfbc0
raw
history blame
1.01 kB
from typing import List, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from .config import EMBEDDING_MODEL
class ExampleRetriever:
    """Rank RAGBench per-example documents by similarity to a question.

    The question and the documents are embedded with the SentenceTransformer
    model named by ``EMBEDDING_MODEL`` and compared by cosine similarity.
    """

    def __init__(self):
        # The model is created once here and reused by every ranking call.
        self.embedder = SentenceTransformer(EMBEDDING_MODEL)

    def _encode(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` with the configured model (progress bar off)."""
        return self.embedder.encode(texts, show_progress_bar=False)

    def rank_docs(
        self,
        question: str,
        documents_sentences: List[List[Tuple[str, str]]],
        k: int = 4,
    ) -> List[int]:
        """Return indices of the ``k`` documents most similar to ``question``.

        Args:
            question: Query text the documents are compared against.
            documents_sentences: One entry per document. Each document is a
                list of 2-tuples whose second element is the sentence text;
                the first element (presumably a sentence key) is ignored.
            k: Maximum number of document indices to return.

        Returns:
            Indices into ``documents_sentences`` ordered from most to least
            similar. Empty when there are no documents or ``k`` is not
            positive.
        """
        # Guard: an empty document list would reach cosine_similarity with a
        # zero-sample array, and a negative k would turn ``[:k]`` into
        # "everything but the last |k|" instead of a top-k selection.
        if not documents_sentences or k <= 0:
            return []

        # Flatten each document into a single string of its sentences.
        doc_texts = [
            " ".join(sent for _, sent in doc) for doc in documents_sentences
        ]

        q_emb = self._encode([question])
        d_emb = self._encode(doc_texts)

        # Only one question is encoded, so row 0 holds its similarity to
        # every document.
        sims = cosine_similarity(q_emb, d_emb)[0]

        # argsort is ascending; reverse for best-first, then keep the top k.
        topk_idx = np.argsort(sims)[::-1][:k]
        return topk_idx.tolist()