Renangi committed
Commit c8dfbc0 · 0 Parent(s)

Initial commit without secrets

.gitignore ADDED
.env
.venv/
__pycache__/
*.pyc
Dockerfile ADDED
FROM python:3.11-slim

WORKDIR /app

RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential git && \
    rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY ragbench_eval ./ragbench_eval
COPY app ./app
COPY scripts ./scripts
COPY prompts ./prompts

ENV PYTHONUNBUFFERED=1

# Hugging Face Spaces expect 7860
EXPOSE 7860

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
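
A minimal sketch for building and running this image locally without Compose (the image tag `ragbench-eval` is an arbitrary choice, and `--env-file .env` assumes you have created the `.env` file described in the README):

```bash
docker build -t ragbench-eval .
docker run --rm -p 7860:7860 --env-file .env ragbench-eval
# then open http://localhost:7860/docs
```
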
README.md ADDED
---
title: ragbench-rag-eval
emoji: "📊"
colorFrom: blue
colorTo: indigo
sdk: docker
pinned: false
---

# RAGBench RAG Evaluation Project

This project evaluates a RAG system on the RAGBench dataset across 5 domains:
Biomedical, General Knowledge, Legal, Customer Support, and Finance.

## 1. Setup (local, no Docker)

```bash
python -m venv .venv
source .venv/bin/activate   # Windows: .venv\Scripts\activate
pip install --upgrade pip
pip install -r requirements.txt
```

Copy `.env.example` to `.env` and fill in the following (see the example sketch below):

- HF_TOKEN (if using Hugging Face models)
- GROQ_API_KEY (if using Groq)
- RAGBENCH_LLM_PROVIDER = groq or hf
- RAGBENCH_GEN_MODEL
- RAGBENCH_JUDGE_MODEL
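
A minimal `.env` sketch (the keys match the list above; the token values are placeholders and the model names are simply the defaults from `ragbench_eval/config.py`):

```
# Only one of the two API keys is required, depending on RAGBENCH_LLM_PROVIDER
HF_TOKEN=hf_xxxxxxxxxxxxxxxx
GROQ_API_KEY=gsk_xxxxxxxxxxxxxxxx
RAGBENCH_LLM_PROVIDER=groq
RAGBENCH_GEN_MODEL=llama3-8b-8192
RAGBENCH_JUDGE_MODEL=llama3-70b-8192
```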

Also open `prompts/ragbench_judge_prompt.txt` and paste the official JSON
annotation prompt from the RAGBench paper (Appendix 9.4), with placeholders:
`{documents}`, `{question}`, `{answer}`.

### Run an experiment from CLI

```bash
python -m scripts.run_experiment --domain biomedical --k 3 --max_examples 10
```

## 2. Run FastAPI locally (no Docker)

```bash
uvicorn app.main:app --host 0.0.0.0 --port 7860
```

Then open:

- `http://localhost:7860/health`
- `http://localhost:7860/docs` (Swagger UI)
- POST `/run_domain` with JSON:

```json
{
  "domain": "biomedical",
  "k": 3,
  "max_examples": 10,
  "split": "test"
}
```
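
For example, with the server running you can send that payload from the command line (a sketch; adjust the values to your run):

```bash
curl -X POST http://localhost:7860/run_domain \
  -H "Content-Type: application/json" \
  -d '{"domain": "biomedical", "k": 3, "max_examples": 10, "split": "test"}'
```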

## 3. Run with Docker (local laptop)

Build and run:

```bash
docker compose build
docker compose up
```

The API will be available at `http://localhost:8000`.

## 4. Deploy to Hugging Face Space (Docker)

1. Create a new Space with SDK = Docker.
2. Push this repo to the Space Git URL.
3. In the Space settings, add variables/secrets:
   - HF_TOKEN
   - GROQ_API_KEY
   - RAGBENCH_LLM_PROVIDER
   - RAGBENCH_GEN_MODEL
   - RAGBENCH_JUDGE_MODEL
4. Once the Space builds successfully, open `/docs` on the Space URL to run
   `/run_domain` for each domain via Swagger UI.
app/1111-main - Copy.py ADDED
from fastapi import FastAPI
from pydantic import BaseModel

from ragbench_eval.pipeline import RagBenchExperiment

app = FastAPI(title="RAGBench RAG Evaluation API")


class RunRequest(BaseModel):
    domain: str
    k: int = 3
    max_examples: int = 20
    split: str = "test"


@app.post("/run_domain")
def run_domain(req: RunRequest):
    exp = RagBenchExperiment(
        k=req.k,
        max_examples=req.max_examples,
        split=req.split,
    )
    result = exp.run_domain(req.domain)
    return result


@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/")
def root():
    return {
        "message": "RAGBench RAG Evaluation API is running.",
        "endpoints": {
            "health": "/health",
            "docs": "/docs",
            "run_domain": "/run_domain (POST)",
        },
    }
app/222222-main - Copy.py ADDED
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from ragbench_eval.pipeline import RagBenchExperiment

app = FastAPI(title="RAGBench RAG Evaluation API")


class RunRequest(BaseModel):
    domain: str
    k: int = 3
    max_examples: int = 20
    split: str = "test"


@app.post("/run_domain")
def run_domain(req: RunRequest):
    exp = RagBenchExperiment(
        k=req.k,
        max_examples=req.max_examples,
        split=req.split,
    )
    result = exp.run_domain(req.domain)
    return result


@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/")
def root():
    return {
        "message": "RAGBench RAG Evaluation API is running.",
        "endpoints": {
            "health": "/health",
            "docs": "/docs",
            "ui": "/ui",
            "run_domain": "/run_domain (POST)",
        },
    }


# ------------- NEW: simple frontend at /ui -----------------

@app.get("/ui", response_class=HTMLResponse)
def ui():
    html = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
      <meta charset="UTF-8" />
      <title>RAGBench RAG Evaluation UI</title>
      <meta name="viewport" content="width=device-width, initial-scale=1" />
      <style>
        body {
          font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
          margin: 0;
          padding: 0;
          background: #f5f7fa;
          color: #111827;
        }
        .wrapper {
          max-width: 960px;
          margin: 2rem auto;
          padding: 1.5rem;
          background: #ffffff;
          border-radius: 0.75rem;
          box-shadow: 0 10px 25px rgba(0, 0, 0, 0.06);
        }
        h1 {
          margin-top: 0;
          font-size: 1.6rem;
        }
        .row {
          display: flex;
          flex-wrap: wrap;
          gap: 1rem;
          margin-bottom: 1rem;
        }
        .field {
          flex: 1 1 180px;
          min-width: 160px;
        }
        label {
          display: block;
          font-size: 0.85rem;
          font-weight: 600;
          margin-bottom: 0.25rem;
        }
        select, input {
          width: 100%;
          padding: 0.45rem 0.55rem;
          border-radius: 0.375rem;
          border: 1px solid #d1d5db;
          font-size: 0.9rem;
          box-sizing: border-box;
        }
        button {
          padding: 0.55rem 1.2rem;
          border-radius: 999px;
          border: none;
          background: #2563eb;
          color: #ffffff;
          font-weight: 600;
          font-size: 0.95rem;
          cursor: pointer;
        }
        button:disabled {
          opacity: 0.6;
          cursor: default;
        }
        .actions {
          margin-top: 0.5rem;
          margin-bottom: 1rem;
        }
        .status {
          font-size: 0.85rem;
          margin-bottom: 0.5rem;
          color: #4b5563;
        }
        pre {
          background: #0b1020;
          color: #e5e7eb;
          padding: 1rem;
          border-radius: 0.75rem;
          overflow: auto;
          max-height: 480px;
          font-size: 0.8rem;
        }
        code {
          font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
        }
        @media (max-width: 640px) {
          .wrapper {
            margin: 0.5rem;
            border-radius: 0.5rem;
          }
        }
      </style>
    </head>
    <body>
      <div class="wrapper">
        <h1>RAGBench RAG Evaluation</h1>
        <p style="font-size:0.9rem; color:#4b5563;">
          Use this UI to call <code>POST /run_domain</code> and inspect the metrics
          for a given domain. The backend uses the RAGBench dataset and your configured LLMs.
        </p>

        <div class="row">
          <div class="field">
            <label for="domain">Domain</label>
            <select id="domain">
              <option value="biomedical">Biomedical</option>
              <option value="general_knowledge">General Knowledge</option>
              <option value="legal">Legal</option>
              <option value="customer_support">Customer Support</option>
              <option value="finance">Finance</option>
            </select>
          </div>

          <div class="field">
            <label for="k">Top-k documents</label>
            <input id="k" type="number" value="3" min="1" />
          </div>

          <div class="field">
            <label for="max_examples">Max examples</label>
            <input id="max_examples" type="number" value="5" min="1" />
          </div>

          <div class="field">
            <label for="split">Dataset split</label>
            <input id="split" type="text" value="test" />
          </div>
        </div>

        <div class="actions">
          <button id="runBtn" onclick="runDomain()">Run Domain Evaluation</button>
        </div>

        <div class="status" id="status"></div>

        <pre><code id="output">{}</code></pre>
      </div>

      <script>
        async function runDomain() {
          const domainEl = document.getElementById("domain");
          const kEl = document.getElementById("k");
          const maxExamplesEl = document.getElementById("max_examples");
          const splitEl = document.getElementById("split");
          const statusEl = document.getElementById("status");
          const outputEl = document.getElementById("output");
          const btn = document.getElementById("runBtn");

          const domain = domainEl.value;
          const k = parseInt(kEl.value || "3", 10);
          const maxExamples = parseInt(maxExamplesEl.value || "5", 10);
          const split = splitEl.value || "test";

          const payload = {
            domain: domain,
            k: k,
            max_examples: maxExamples,
            split: split
          };

          statusEl.textContent = "Running evaluation...";
          btn.disabled = true;
          outputEl.textContent = "{}";

          try {
            const res = await fetch("/run_domain", {
              method: "POST",
              headers: {
                "Content-Type": "application/json"
              },
              body: JSON.stringify(payload)
            });

            const data = await res.json();

            if (!res.ok) {
              statusEl.textContent = "Error " + res.status;
            } else {
              statusEl.textContent = "Done.";
            }

            outputEl.textContent = JSON.stringify(data, null, 2);
          } catch (err) {
            statusEl.textContent = "Request failed: " + err;
            outputEl.textContent = "{}";
          } finally {
            btn.disabled = false;
          }
        }
      </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html)
app/__init__.py ADDED
File without changes
app/main.py ADDED
from typing import List, Tuple

from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from datasets import load_dataset

from ragbench_eval.pipeline import RagBenchExperiment
from ragbench_eval.retriever import ExampleRetriever
from ragbench_eval.generator import RAGGenerator
from ragbench_eval.judge import RAGJudge
from ragbench_eval.metrics import trace_from_attributes
from ragbench_eval.config import RAGBENCH_DATASET

app = FastAPI(title="RAGBench RAG Evaluation API")


class RunRequest(BaseModel):
    domain: str
    k: int = 3
    max_examples: int = 20
    split: str = "test"


class QAExampleRequest(BaseModel):
    subset: str           # e.g. "covidqa", "pubmedqa", "finqa"
    index: int = 0        # which example in that subset
    k: int = 3            # top-k docs
    split: str = "test"   # usually "test"


@app.post("/run_domain")
def run_domain(req: RunRequest):
    exp = RagBenchExperiment(
        k=req.k,
        max_examples=req.max_examples,
        split=req.split,
    )
    result = exp.run_domain(req.domain)
    return result


@app.post("/qa_example")
def qa_example(req: QAExampleRequest):
    """
    Run RAG on a single RAGBench example and return:
    - question
    - generated answer
    - retrieved docs with sentence keys
    - judge attributes
    - predicted TRACe metrics
    - ground-truth scores from dataset
    """
    ds = load_dataset(RAGBENCH_DATASET, req.subset, split=req.split)

    if req.index < 0 or req.index >= len(ds):
        return {"error": f"index {req.index} out of range (0..{len(ds)-1})"}

    row = ds[req.index]

    docs_sentences_full: List[List[Tuple[str, str]]] = []
    for doc in row["documents_sentences"]:
        docs_sentences_full.append([(k, s) for k, s in doc])

    question = row["question"]

    retriever = ExampleRetriever()
    doc_indices = retriever.rank_docs(question, docs_sentences_full, k=req.k)
    selected_docs = [docs_sentences_full[j] for j in doc_indices]

    generator = RAGGenerator()
    answer = generator.generate(question, selected_docs)

    judge = RAGJudge()
    attrs = judge.annotate(question, answer, selected_docs)

    pred_metrics = trace_from_attributes(attrs, selected_docs)

    docs_view = []
    for doc_i, doc in enumerate(selected_docs):
        docs_view.append({
            "doc_index": doc_indices[doc_i],
            "sentences": [{"key": k, "text": s} for k, s in doc],
        })

    return {
        "subset": req.subset,
        "index": req.index,
        "question": question,
        "answer": answer,
        "retrieved_docs": docs_view,
        "judge_attributes": attrs,
        "predicted_trace_metrics": pred_metrics,
        "ground_truth": {
            "relevance_score": row.get("relevance_score"),
            "utilization_score": row.get("utilization_score"),
            "completeness_score": row.get("completeness_score"),
            "adherence_score": row.get("adherence_score"),
        },
    }


@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/")
def root():
    return {
        "message": "RAGBench RAG Evaluation API is running.",
        "endpoints": {
            "health": "/health",
            "docs": "/docs",
            "ui": "/ui",
            "run_domain": "/run_domain (POST)",
            "qa_example": "/qa_example (POST)",
        },
    }


@app.get("/ui", response_class=HTMLResponse)
def ui():
    html = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
      <meta charset="UTF-8" />
      <title>RAGBench RAG Evaluation UI</title>
      <meta name="viewport" content="width=device-width, initial-scale=1" />
      <style>
        body {
          font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
          margin: 0;
          padding: 0;
          background: #f3f4f6;
          color: #111827;
        }
        .wrapper {
          max-width: 1080px;
          margin: 2rem auto;
          padding: 1.5rem;
        }
        .card {
          background: #ffffff;
          border-radius: 0.75rem;
          box-shadow: 0 10px 25px rgba(0, 0, 0, 0.06);
          padding: 1.25rem 1.5rem;
          margin-bottom: 1.5rem;
        }
        h1 {
          margin-top: 0;
          font-size: 1.6rem;
        }
        h2 {
          margin-top: 0;
          font-size: 1.2rem;
        }
        p {
          font-size: 0.9rem;
          color: #4b5563;
        }
        .row {
          display: flex;
          flex-wrap: wrap;
          gap: 1rem;
          margin-bottom: 1rem;
        }
        .field {
          flex: 1 1 180px;
          min-width: 160px;
        }
        label {
          display: block;
          font-size: 0.85rem;
          font-weight: 600;
          margin-bottom: 0.25rem;
        }
        select, input {
          width: 100%;
          padding: 0.45rem 0.55rem;
          border-radius: 0.375rem;
          border: 1px solid #d1d5db;
          font-size: 0.9rem;
          box-sizing: border-box;
        }
        button {
          padding: 0.55rem 1.2rem;
          border-radius: 999px;
          border: none;
          background: #2563eb;
          color: #ffffff;
          font-weight: 600;
          font-size: 0.95rem;
          cursor: pointer;
        }
        button:disabled {
          opacity: 0.6;
          cursor: default;
        }
        .actions {
          margin-top: 0.5rem;
          margin-bottom: 0.75rem;
        }
        .status {
          font-size: 0.85rem;
          margin-bottom: 0.5rem;
          color: #4b5563;
        }
        pre {
          background: #0b1020;
          color: #e5e7eb;
          padding: 1rem;
          border-radius: 0.75rem;
          overflow: auto;
          max-height: 420px;
          font-size: 0.8rem;
        }
        code {
          font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
        }
        @media (max-width: 640px) {
          .wrapper {
            margin: 0.5rem;
            padding: 0.75rem;
          }
          .card {
            padding: 0.9rem 1rem;
          }
        }
      </style>
    </head>
    <body>
      <div class="wrapper">
        <div class="card">
          <h1>RAGBench RAG Evaluation</h1>
          <p>
            This UI lets you:
            (1) run domain-level evaluation on RAGBench, and
            (2) inspect a single example (question, retrieved docs, answer, and metrics).
          </p>
        </div>

        <!-- Domain evaluation card -->
        <div class="card">
          <h2>1. Domain Evaluation (POST /run_domain)</h2>
          <p>
            Evaluate all subsets in a domain using the configured LLM and retriever.
          </p>

          <div class="row">
            <div class="field">
              <label for="domain">Domain</label>
              <select id="domain">
                <option value="biomedical">Biomedical</option>
                <option value="general_knowledge">General Knowledge</option>
                <option value="legal">Legal</option>
                <option value="customer_support">Customer Support</option>
                <option value="finance">Finance</option>
              </select>
            </div>

            <div class="field">
              <label for="k">Top-k documents</label>
              <input id="k" type="number" value="3" min="1" />
            </div>

            <div class="field">
              <label for="max_examples">Max examples</label>
              <input id="max_examples" type="number" value="5" min="1" />
            </div>

            <div class="field">
              <label for="split">Dataset split</label>
              <input id="split" type="text" value="test" />
            </div>
          </div>

          <div class="actions">
            <button id="runBtn" onclick="runDomain()">Run Domain Evaluation</button>
          </div>

          <div class="status" id="status"></div>
          <pre><code id="output">{}</code></pre>
        </div>

        <!-- Single example viewer card -->
        <div class="card">
          <h2>2. Single Example Viewer (POST /qa_example)</h2>
          <p>
            Inspect one RAGBench example: question, retrieved documents, answer,
            judge attributes, and TRACe metrics.
          </p>

          <div class="row">
            <div class="field">
              <label for="subset">Subset</label>
              <input list="subset-list" id="subset" value="covidqa" />
              <datalist id="subset-list">
                <option value="pubmedqa">
                <option value="covidqa">
                <option value="hotpotqa">
                <option value="msmarco">
                <option value="hagrid">
                <option value="expertqa">
                <option value="cuad">
                <option value="delucionqa">
                <option value="emanual">
                <option value="techqa">
                <option value="finqa">
                <option value="tatqa">
              </datalist>
            </div>

            <div class="field">
              <label for="example_index">Example index</label>
              <input id="example_index" type="number" value="0" min="0" />
            </div>

            <div class="field">
              <label for="k_example">Top-k documents</label>
              <input id="k_example" type="number" value="3" min="1" />
            </div>

            <div class="field">
              <label for="split_example">Dataset split</label>
              <input id="split_example" type="text" value="test" />
            </div>
          </div>

          <div class="actions">
            <button id="qaBtn" onclick="runExample()">Run Single Example</button>
          </div>

          <div class="status" id="qa_status"></div>
          <pre><code id="qa_output">{}</code></pre>
        </div>
      </div>

      <script>
        async function runDomain() {
          const domainEl = document.getElementById("domain");
          const kEl = document.getElementById("k");
          const maxExamplesEl = document.getElementById("max_examples");
          const splitEl = document.getElementById("split");
          const statusEl = document.getElementById("status");
          const outputEl = document.getElementById("output");
          const btn = document.getElementById("runBtn");

          const payload = {
            domain: domainEl.value,
            k: parseInt(kEl.value || "3", 10),
            max_examples: parseInt(maxExamplesEl.value || "5", 10),
            split: splitEl.value || "test"
          };

          statusEl.textContent = "Running domain evaluation...";
          btn.disabled = true;
          outputEl.textContent = "{}";

          try {
            const res = await fetch("/run_domain", {
              method: "POST",
              headers: { "Content-Type": "application/json" },
              body: JSON.stringify(payload)
            });
            const data = await res.json();
            if (!res.ok) {
              statusEl.textContent = "Error " + res.status;
            } else {
              statusEl.textContent = "Done.";
            }
            outputEl.textContent = JSON.stringify(data, null, 2);
          } catch (err) {
            statusEl.textContent = "Request failed: " + err;
            outputEl.textContent = "{}";
          } finally {
            btn.disabled = false;
          }
        }

        async function runExample() {
          const subsetEl = document.getElementById("subset");
          const indexEl = document.getElementById("example_index");
          const kEl = document.getElementById("k_example");
          const splitEl = document.getElementById("split_example");
          const statusEl = document.getElementById("qa_status");
          const outputEl = document.getElementById("qa_output");
          const btn = document.getElementById("qaBtn");

          const payload = {
            subset: subsetEl.value,
            index: parseInt(indexEl.value || "0", 10),
            k: parseInt(kEl.value || "3", 10),
            split: splitEl.value || "test"
          };

          statusEl.textContent = "Running single example...";
          btn.disabled = true;
          outputEl.textContent = "{}";

          try {
            const res = await fetch("/qa_example", {
              method: "POST",
              headers: { "Content-Type": "application/json" },
              body: JSON.stringify(payload)
            });
            const data = await res.json();
            if (!res.ok) {
              statusEl.textContent = "Error " + res.status;
            } else if (data.error) {
              statusEl.textContent = "Backend error: " + data.error;
            } else {
              statusEl.textContent = "Done.";
            }
            outputEl.textContent = JSON.stringify(data, null, 2);
          } catch (err) {
            statusEl.textContent = "Request failed: " + err;
            outputEl.textContent = "{}";
          } finally {
            btn.disabled = false;
          }
        }
      </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html)
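
The `/qa_example` endpoint is not covered in the README; a quick way to exercise it once the server is up (a sketch, using a subset name from the UI's datalist):

```bash
curl -X POST http://localhost:7860/qa_example \
  -H "Content-Type: application/json" \
  -d '{"subset": "covidqa", "index": 0, "k": 3, "split": "test"}'
```
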
docker-compose.yml ADDED
version: "3.9"
services:
  ragbench-api:
    build: .
    ports:
      # the container listens on 7860 (see the Dockerfile EXPOSE / CMD)
      - "8000:7860"
    environment:
      HF_TOKEN: "${HF_TOKEN}"
      GROQ_API_KEY: "${GROQ_API_KEY}"
      RAGBENCH_LLM_PROVIDER: "${RAGBENCH_LLM_PROVIDER:-groq}"
      RAGBENCH_GEN_MODEL: "${RAGBENCH_GEN_MODEL:-llama3-8b-8192}"
      RAGBENCH_JUDGE_MODEL: "${RAGBENCH_JUDGE_MODEL:-llama3-70b-8192}"
prompts/ragbench_judge_prompt.txt ADDED
IMPORTANT: Replace this file content with the official JSON-format judge prompt
from the RAGBench paper (Appendix 9.4). Keep the placeholders:
{documents}
{question}
{answer}
exactly as they are used in their template.
ragbench_eval/__init__.py ADDED
__all__ = []
ragbench_eval/config.py ADDED
import os
from dotenv import load_dotenv

load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

LLM_PROVIDER = os.getenv("RAGBENCH_LLM_PROVIDER", "groq")  # "groq" or "hf"
GEN_MODEL = os.getenv("RAGBENCH_GEN_MODEL", "llama3-8b-8192")
JUDGE_MODEL = os.getenv("RAGBENCH_JUDGE_MODEL", "llama3-70b-8192")

EMBEDDING_MODEL = os.getenv(
    "RAGBENCH_EMBEDDING_MODEL",
    "sentence-transformers/all-MiniLM-L6-v2",
)

RAGBENCH_DATASET = os.getenv("RAGBENCH_DATASET", "galileo-ai/ragbench")

DOMAIN_TO_SUBSETS = {
    "biomedical": ["pubmedqa", "covidqa"],
    "general_knowledge": ["hotpotqa", "msmarco", "hagrid", "expertqa"],
    "legal": ["cuad"],
    "customer_support": ["delucionqa", "emanual", "techqa"],
    "finance": ["finqa", "tatqa"],
}
ragbench_eval/generator.py ADDED
from typing import List, Tuple
from .llm import LLMClient
from .config import GEN_MODEL


def build_context_from_docs(
    docs_sentences: List[List[Tuple[str, str]]]
) -> str:
    chunks = []
    for doc in docs_sentences:
        text = " ".join(sent for _, sent in doc)
        chunks.append(text)
    return "\n\n".join(chunks)


class RAGGenerator:
    def __init__(self):
        self.client = LLMClient(GEN_MODEL)

    def generate(self, question: str, docs_sentences: List[List[Tuple[str, str]]]) -> str:  # noqa: E501
        context = build_context_from_docs(docs_sentences)
        prompt = (
            "Use the following pieces of context to answer the question.\n\n"
            f"{context}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )
        messages = [
            {"role": "system", "content": "You are a precise, grounded QA assistant."},  # noqa: E501
            {"role": "user", "content": prompt},
        ]
        return self.client.chat(messages)
ragbench_eval/judge.py ADDED
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

from .llm import LLMClient
from .config import JUDGE_MODEL


def format_docs_with_keys(
    documents_sentences: List[List[Tuple[str, str]]]
) -> str:
    blocks = []
    for doc in documents_sentences:
        for key, sent in doc:
            blocks.append(f"{key}: {sent}")
        blocks.append("")  # blank line
    return "\n".join(blocks).strip()


class RAGJudge:
    def __init__(self, prompt_path: str = "prompts/ragbench_judge_prompt.txt"):
        self.client = LLMClient(JUDGE_MODEL)
        self.prompt_template = Path(prompt_path).read_text(encoding="utf-8")

    def annotate(
        self,
        question: str,
        answer: str,
        docs_sentences: List[List[Tuple[str, str]]],
    ) -> Dict[str, Any]:
        docs_block = format_docs_with_keys(docs_sentences)
        prompt = self.prompt_template.format(
            documents=docs_block,
            question=question,
            answer=answer,
        )
        messages = [
            {
                "role": "system",
                "content": "You are an evaluator that outputs STRICT JSON only.",
            },
            {"role": "user", "content": prompt},
        ]
        raw = self.client.chat(messages, max_tokens=2048)

        try:
            data = json.loads(raw)
        except json.JSONDecodeError as e:
            raise ValueError(f"Judge JSON parse error: {e}\nRaw: {raw[:500]}")
        for key in [
            "relevance_explanation",
            "all_relevant_sentence_keys",
            "overall_supported_explanation",
            "overall_supported",
            "sentence_support_information",
            "all_utilized_sentence_keys",
        ]:
            if key not in data:
                raise ValueError(f"Missing key in judge output: {key}")
        return data
ragbench_eval/llm.py ADDED
from typing import List, Dict
from .config import LLM_PROVIDER, HF_TOKEN, GROQ_API_KEY

from huggingface_hub import InferenceClient
from groq import Groq


class LLMClient:
    def __init__(self, model: str, is_chat: bool = True):
        self.provider = LLM_PROVIDER
        self.model = model
        self.is_chat = is_chat

        if self.provider == "hf":
            if not HF_TOKEN:
                raise RuntimeError("HF_TOKEN is required for HF provider")
            self.client = InferenceClient(token=HF_TOKEN)
        elif self.provider == "groq":
            if not GROQ_API_KEY:
                raise RuntimeError("GROQ_API_KEY is required for Groq provider")
            self.client = Groq(api_key=GROQ_API_KEY)
        else:
            raise ValueError(f"Unsupported provider {self.provider}")

    def chat(self, messages: List[Dict[str, str]], max_tokens: int = 1024) -> str:
        if self.provider == "hf":
            prompt = ""
            for m in messages:
                role = m.get("role", "user")
                content = m.get("content", "")
                prompt += f"[{role.upper()}]\n{content}\n"
            out = self.client.text_generation(
                prompt,
                model=self.model,
                max_new_tokens=max_tokens,
                temperature=0.2,
                do_sample=False,
            )
            return out
        else:
            resp = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=0.2,
            )
            return resp.choices[0].message.content
ragbench_eval/metrics.py ADDED
from typing import Any, Dict, List, Tuple
import numpy as np
from sklearn.metrics import mean_squared_error, roc_auc_score


def _all_sentence_keys(
    docs_sentences: List[List[Tuple[str, str]]]
) -> List[str]:
    keys: List[str] = []
    for doc in docs_sentences:
        for key, _ in doc:
            keys.append(key)
    return keys


def trace_from_attributes(
    attrs: Dict[str, Any],
    docs_sentences: List[List[Tuple[str, str]]],
) -> Dict[str, float]:
    all_keys = _all_sentence_keys(docs_sentences)
    total = len(all_keys)
    if total == 0:
        return {
            "relevance": 0.0,
            "utilization": 0.0,
            "completeness": 0.0,
            "adherence": 0.0,
        }

    relevant = set(attrs.get("all_relevant_sentence_keys", [])) & set(all_keys)
    utilized = set(attrs.get("all_utilized_sentence_keys", [])) & set(all_keys)

    relevance = len(relevant) / total if total > 0 else 0.0
    utilization = len(utilized) / total if total > 0 else 0.0
    completeness = (
        len(relevant & utilized) / len(relevant) if relevant else 0.0
    )
    adherence = 1.0 if attrs.get("overall_supported", False) else 0.0

    return {
        "relevance": float(relevance),
        "utilization": float(utilization),
        "completeness": float(completeness),
        "adherence": float(adherence),
    }


def compute_rmse_auc(
    y_true_rel: List[float],
    y_pred_rel: List[float],
    y_true_util: List[float],
    y_pred_util: List[float],
    y_true_comp: List[float],
    y_pred_comp: List[float],
    y_true_adh: List[int],
    y_pred_adh: List[float],
) -> Dict[str, float]:
    metrics = {
        "rmse_relevance": float(
            mean_squared_error(y_true_rel, y_pred_rel, squared=False)
        ),
        "rmse_utilization": float(
            mean_squared_error(y_true_util, y_pred_util, squared=False)
        ),
        "rmse_completeness": float(
            mean_squared_error(y_true_comp, y_pred_comp, squared=False)
        ),
    }

    if len(set(y_true_adh)) > 1:
        metrics["auroc_adherence"] = float(
            roc_auc_score(y_true_adh, y_pred_adh)
        )
    else:
        metrics["auroc_adherence"] = float("nan")

    return metrics
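
A small worked example of `trace_from_attributes` (the documents, sentence keys, and judge attributes below are invented for illustration; in the pipeline they come from the dataset and the judge LLM):

```python
from ragbench_eval.metrics import trace_from_attributes

# Two documents, three sentences total, with RAGBench-style sentence keys.
docs = [
    [("0a", "Aspirin reduces fever."), ("0b", "It can upset the stomach.")],
    [("1a", "Paris is the capital of France.")],
]

# Hypothetical judge output: 0a and 0b judged relevant, only 0a actually used,
# and the answer judged fully supported.
attrs = {
    "all_relevant_sentence_keys": ["0a", "0b"],
    "all_utilized_sentence_keys": ["0a"],
    "overall_supported": True,
}

print(trace_from_attributes(attrs, docs))
# -> relevance ~= 0.67, utilization ~= 0.33, completeness = 0.5, adherence = 1.0
```
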
ragbench_eval/pipeline.py ADDED
from typing import Dict, Any, List, Tuple, Optional
from datasets import load_dataset

from .config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS
from .retriever import ExampleRetriever
from .generator import RAGGenerator
from .judge import RAGJudge
from .metrics import trace_from_attributes, compute_rmse_auc


class RagBenchExperiment:
    def __init__(
        self,
        k: int = 3,
        max_examples: Optional[int] = None,
        split: str = "test",
    ):
        self.k = k
        self.max_examples = max_examples
        self.split = split

        self.retriever = ExampleRetriever()
        self.generator = RAGGenerator()
        self.judge = RAGJudge()

    def _load_subset(self, subset: str):
        ds = load_dataset(
            RAGBENCH_DATASET, subset, split=self.split
        )
        return ds

    def _to_docs_sentences(self, row) -> List[List[Tuple[str, str]]]:
        docs: List[List[Tuple[str, str]]] = []
        for doc in row["documents_sentences"]:
            docs.append([(k, s) for k, s in doc])
        return docs

    def run_subset(self, subset: str) -> Dict[str, Any]:
        ds = self._load_subset(subset)

        y_true_rel: List[float] = []
        y_pred_rel: List[float] = []
        y_true_util: List[float] = []
        y_pred_util: List[float] = []
        y_true_comp: List[float] = []
        y_pred_comp: List[float] = []
        y_true_adh: List[int] = []
        y_pred_adh: List[float] = []

        for i, row in enumerate(ds):
            if self.max_examples is not None and i >= self.max_examples:
                break

            question = row["question"]
            docs_sentences_full = self._to_docs_sentences(row)

            doc_indices = self.retriever.rank_docs(
                question, docs_sentences_full, k=self.k
            )
            selected_docs = [docs_sentences_full[j] for j in doc_indices]

            answer = self.generator.generate(question, selected_docs)

            attrs = self.judge.annotate(question, answer, selected_docs)

            pred = trace_from_attributes(attrs, selected_docs)

            y_true_rel.append(float(row["relevance_score"]))
            y_true_util.append(float(row["utilization_score"]))
            y_true_comp.append(float(row["completeness_score"]))
            y_true_adh.append(int(row["adherence_score"]))

            y_pred_rel.append(pred["relevance"])
            y_pred_util.append(pred["utilization"])
            y_pred_comp.append(pred["completeness"])
            y_pred_adh.append(pred["adherence"])

        metrics = compute_rmse_auc(
            y_true_rel,
            y_pred_rel,
            y_true_util,
            y_pred_util,
            y_true_comp,
            y_pred_comp,
            y_true_adh,
            y_pred_adh,
        )

        return {
            "subset": subset,
            "n_examples": len(y_true_rel),
            **metrics,
        }

    def run_domain(self, domain: str) -> Dict[str, Any]:
        subsets = DOMAIN_TO_SUBSETS[domain]
        results = []
        for subset in subsets:
            res = self.run_subset(subset)
            results.append(res)
        return {
            "domain": domain,
            "subsets": results,
        }
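
For a programmatic run outside the CLI or FastAPI app, a minimal sketch (it downloads the RAGBench subset and calls the configured LLMs, so the keys from `.env` must be set; the subset name comes from `DOMAIN_TO_SUBSETS`):

```python
from ragbench_eval.pipeline import RagBenchExperiment

# Small smoke test: 5 examples from one biomedical subset.
exp = RagBenchExperiment(k=3, max_examples=5, split="test")
print(exp.run_subset("covidqa"))      # metrics for a single subset
# exp.run_domain("biomedical")        # loops over all subsets mapped to the domain
```
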
ragbench_eval/retriever.py ADDED
from typing import List, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from .config import EMBEDDING_MODEL


class ExampleRetriever:
    """Ranks the per-example documents in RAGBench by similarity to the question."""  # noqa: E501

    def __init__(self):
        self.embedder = SentenceTransformer(EMBEDDING_MODEL)

    def _encode(self, texts: List[str]) -> np.ndarray:
        return self.embedder.encode(texts, show_progress_bar=False)

    def rank_docs(
        self,
        question: str,
        documents_sentences: List[List[Tuple[str, str]]],
        k: int = 4,
    ) -> List[int]:
        doc_texts = [
            " ".join(sent for _, sent in doc) for doc in documents_sentences
        ]
        q_emb = self._encode([question])
        d_emb = self._encode(doc_texts)

        sims = cosine_similarity(q_emb, d_emb)[0]
        topk_idx = np.argsort(sims)[::-1][:k]
        return topk_idx.tolist()
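
A toy usage example for the retriever (the documents and sentence keys are invented; the first call downloads the configured sentence-transformers model):

```python
from ragbench_eval.retriever import ExampleRetriever

docs = [
    [("0a", "The capital of France is Paris.")],
    [("1a", "Mitochondria are the powerhouse of the cell.")],
    [("2a", "The Eiffel Tower stands in Paris.")],
]

retriever = ExampleRetriever()
# Indices of the k documents most similar to the question, best first.
print(retriever.rank_docs("Which city is the capital of France?", docs, k=2))
# e.g. [0, 2]
```
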
requirements.txt ADDED
datasets==2.21.0
sentence-transformers==3.0.1
scikit-learn==1.5.2
numpy==1.26.4
pydantic==2.9.2
fastapi==0.115.5
uvicorn[standard]==0.32.0
python-dotenv==1.0.1
huggingface_hub[inference]==0.26.2
groq==0.9.0
scripts/run_experiment.py ADDED
import argparse
import json
from ragbench_eval.pipeline import RagBenchExperiment


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--domain",
        type=str,
        required=True,
        choices=[
            "biomedical",
            "general_knowledge",
            "legal",
            "customer_support",
            "finance",
        ],
    )
    parser.add_argument("--k", type=int, default=3)
    parser.add_argument("--max_examples", type=int, default=50)
    parser.add_argument("--split", type=str, default="test")
    args = parser.parse_args()

    exp = RagBenchExperiment(
        k=args.k,
        max_examples=args.max_examples,
        split=args.split,
    )
    results = exp.run_domain(args.domain)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()