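"""RAGBench experiment runner.

Loads a RAGBench subset, retrieves the top-k supporting documents for
each question (FAISS vector DB first, hybrid BM25 + dense fallback),
generates an answer, has an LLM judge annotate it, derives TRACe-style
predictions from the annotations, and scores them against the dataset's
gold relevance/utilization/completeness/adherence labels.
"""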
from typing import Dict, Any, List, Tuple, Optional
from datasets import load_dataset

from .config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS
from .retriever import ExampleRetriever
from .generator import RAGGenerator
from .judge import RAGJudge
from .metrics import trace_from_attributes, compute_rmse_auc
from .vector_db import SubsetVectorDB


class RagBenchExperiment:
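    """Runs the full evaluation loop (retrieve, generate, judge, score)
    over a single RAGBench subset or every subset in a domain."""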
    def __init__(
        self,
        k: int = 3,
        max_examples: Optional[int] = None,
        split: str = "test",
    ):
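        """
        Args:
            k: Number of documents to retrieve per question.
            max_examples: Optional cap on the number of rows evaluated
                per subset; ``None`` evaluates the full split.
            split: Dataset split to load (defaults to ``"test"``).
        """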
        self.k = k
        self.max_examples = max_examples
        self.split = split

        self.retriever = ExampleRetriever()
        self.generator = RAGGenerator()
        self.judge = RAGJudge()

    def _load_subset(self, subset: str):
        ds = load_dataset(
            RAGBENCH_DATASET, subset, split=self.split
        )
        return ds

    def _to_docs_sentences(self, row: Dict[str, Any]) -> List[List[Tuple[str, str]]]:
        """Convert a row's ``documents_sentences`` field into a list of
        documents, each a list of ``(sentence_key, sentence)`` tuples."""
        docs: List[List[Tuple[str, str]]] = []
        for doc in row["documents_sentences"]:
            docs.append([(key, sentence) for key, sentence in doc])
        return docs

    def run_subset(self, subset: str) -> Dict[str, Any]:
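        """Evaluate one subset end to end and return its aggregate metrics."""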
        ds = self._load_subset(subset)

        # Build or load the FAISS-based vector database for this subset.
        # This writes index files under ``vector_store/<subset>/<split>/``
        # the first time it is called and reuses them thereafter.
        vector_db = SubsetVectorDB(subset=subset, split=self.split)
        vector_db.build_or_load(ds)

        # Gold labels (y_true_*) and predictions (y_pred_*) for the four
        # TRACe dimensions. Adherence is a binary gold label, while its
        # predictions stay as float scores for AUC computation.
        y_true_rel: List[float] = []
        y_pred_rel: List[float] = []
        y_true_util: List[float] = []
        y_pred_util: List[float] = []
        y_true_comp: List[float] = []
        y_pred_comp: List[float] = []
        y_true_adh: List[int] = []
        y_pred_adh: List[float] = []

        for i, row in enumerate(ds):
            if self.max_examples is not None and i >= self.max_examples:
                break

            question = row["question"]
            docs_sentences_full = self._to_docs_sentences(row)

            # Try vector DB first: restrict retrieval to documents that
            # belong to this particular example row (same ``i``).
            hits = vector_db.search(
                question,
                k=self.k,
                restrict_row_index=i,
            )

            if hits:
                doc_indices = [doc_idx for _, doc_idx, _ in hits]
            else:
                # Fallback to the original hybrid (BM25 + dense) retriever
                # operating only over this example's documents.
                doc_indices = self.retriever.rank_docs(
                    question, docs_sentences_full, k=self.k
                )

            selected_docs = [docs_sentences_full[j] for j in doc_indices]

            answer = self.generator.generate(question, selected_docs)

            attrs = self.judge.annotate(question, answer, selected_docs)

            pred = trace_from_attributes(attrs, selected_docs)

            y_true_rel.append(float(row["relevance_score"]))
            y_true_util.append(float(row["utilization_score"]))
            y_true_comp.append(float(row["completeness_score"]))
            y_true_adh.append(int(row["adherence_score"]))

            y_pred_rel.append(pred["relevance"])
            y_pred_util.append(pred["utilization"])
            y_pred_comp.append(pred["completeness"])
            y_pred_adh.append(pred["adherence"])

        metrics = compute_rmse_auc(
            y_true_rel,
            y_pred_rel,
            y_true_util,
            y_pred_util,
            y_true_comp,
            y_pred_comp,
            y_true_adh,
            y_pred_adh,
        )

        return {
            "subset": subset,
            "n_examples": len(y_true_rel),
            **metrics,
        }

    def run_domain(self, domain: str) -> Dict[str, Any]:
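        """Run every subset mapped to ``domain`` in the config and
        collect the per-subset results."""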
        subsets = DOMAIN_TO_SUBSETS[domain]
        results = []
        for subset in subsets:
            res = self.run_subset(subset)
            results.append(res)
        return {
            "domain": domain,
            "subsets": results,
        }
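

if __name__ == "__main__":
    # Minimal usage sketch, not part of the evaluated pipeline. Assumes
    # the generator/judge model credentials are configured in your
    # environment, and that "covidqa" is one of the RAGBench subset
    # names in DOMAIN_TO_SUBSETS; substitute a subset relevant to you.
    experiment = RagBenchExperiment(k=3, max_examples=5, split="test")
    print(experiment.run_subset("covidqa"))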