Spaces:
Running
Running
# app/main.py
import os
from typing import List, Tuple, Optional

import requests
from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field

from ragbench_eval.pipeline import RagBenchExperiment
from ragbench_eval.retriever import ExampleRetriever
from ragbench_eval.generator import RAGGenerator
from ragbench_eval.judge import RAGJudge
from ragbench_eval.metrics import trace_from_attributes
from ragbench_eval.config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS
# ---------------------------------------------------------------------
# FastAPI application
# ---------------------------------------------------------------------
# API title handed to FastAPI (used in its OpenAPI metadata).
_APP_TITLE = "RAGBench Chat + RAG Evaluation API"
app = FastAPI(title=_APP_TITLE)
# ---------------------------------------------------------------------
# Config for Hugging Face router (LLM chat)
# ---------------------------------------------------------------------
# Surrounding whitespace is stripped from the token: a secret pasted with a
# trailing newline would otherwise make `requests` reject the Authorization
# header. A blank / whitespace-only value is normalised to None, so every
# "is the token configured?" check in the file sees the same answer.
HF_TOKEN = (os.getenv("HF_TOKEN") or "").strip() or None
# Trailing slashes are dropped so f"{HF_CHAT_BASE_URL}/chat/completions"
# never contains "//".
HF_CHAT_BASE_URL = os.getenv(
    "HF_CHAT_BASE_URL", "https://router.huggingface.co/v1"
).rstrip("/")
HF_CHAT_MODEL = os.getenv(
    "HF_CHAT_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct"
)
if HF_TOKEN is None:
    # We don't crash the app, but /chat will raise a clear error
    print("[RAGBench Chat] WARNING: HF_TOKEN is not set. /chat will fail.")
# ---------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------
class RunRequest(BaseModel):
    """Request body for the ``run_domain`` endpoint (whole-domain evaluation).

    Attributes:
        domain: Domain name, e.g. ``"biomedical"`` or ``"general_knowledge"``.
        k: Number of top-ranked documents kept per example; must be >= 1.
        max_examples: Cap on how many examples are evaluated; must be >= 1
            when given. ``None`` is passed straight to ``RagBenchExperiment``
            (presumably meaning "no cap" — confirm there).
        split: Dataset split to read.
    """

    domain: str  # "biomedical", "general_knowledge", ...
    # ge=1: zero/negative values are meaningless here and would otherwise
    # reach the experiment unchecked.
    k: int = Field(3, ge=1)
    max_examples: Optional[int] = Field(20, ge=1)
    split: str = "test"  # "test", "train" or "validation"
class QAExampleRequest(BaseModel):
    """Request body for the ``qa_example`` endpoint (RAG on a single row).

    Attributes:
        subset: RAGBench subset, passed to ``load_dataset`` as the dataset
            config name (e.g. ``"covidqa"``).
        index: Row number within the chosen split of that subset.
        k: Number of top-ranked documents to keep; must be >= 1.
        split: Dataset split to read.
    """

    subset: str  # e.g. "covidqa", "pubmedqa"
    index: int = 0  # which row from that subset
    # Same constraint as RunRequest.k: retrieving zero documents is meaningless.
    k: int = Field(3, ge=1)
    split: str = "test"
class ChatRequest(BaseModel):
    """Request body for the ``chat`` endpoint.

    Attributes:
        domain: Domain name; ``chat`` rejects values that are not keys of
            ``DOMAIN_TO_SUBSETS`` with HTTP 400.
        question: The user's question, forwarded verbatim as the user message.
    """

    domain: str  # must be one of DOMAIN_TO_SUBSETS keys
    question: str
# ---------------------------------------------------------------------
# LLM Chat endpoint (using HF router OpenAI-compatible API)
# ---------------------------------------------------------------------
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a free-form question using a domain-aware system prompt.

    Forwards the question to the Hugging Face router's OpenAI-compatible
    Chat Completions API: ``POST {HF_CHAT_BASE_URL}/chat/completions``.

    Args:
        req: Target domain and the user's question.

    Returns:
        ``{"answer": <content of the first choice's message>}``.

    Raises:
        HTTPException: 400 for an unknown domain; 500 when ``HF_TOKEN`` is
            missing or blank, when the router request fails, or when the
            router's response is not JSON of the expected shape.
    """
    # Strip locally so a whitespace-only token counts as unset and a token
    # with a trailing newline still forms a valid Authorization header.
    token = (HF_TOKEN or "").strip()
    if not token:
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN environment variable is not set in the backend.",
        )
    if req.domain not in DOMAIN_TO_SUBSETS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown domain '{req.domain}'. "
            f"Valid domains: {', '.join(DOMAIN_TO_SUBSETS.keys())}",
        )

    system_prompt = (
        "You are an assistant answering questions in the domain: "
        f"{req.domain}. "
        "Answer using correct, verifiable information. "
        "If you are not sure, clearly say that you are not sure instead of "
        "guessing. Be concise and avoid fabricating facts."
    )
    payload = {
        "model": HF_CHAT_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": req.question},
        ],
        "temperature": 0.2,
        "max_tokens": 512,
    }
    # Tolerate a trailing slash in a user-supplied HF_CHAT_BASE_URL.
    url = f"{HF_CHAT_BASE_URL.rstrip('/')}/chat/completions"

    try:
        resp = requests.post(
            url,
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Surface a clear error to the frontend (HTTP 500 with this detail).
        raise HTTPException(
            status_code=500,
            detail=f"Hugging Face router request failed: {e}",
        ) from e

    # Decode inside the try: a non-JSON body raises ValueError (requests'
    # JSONDecodeError subclasses it), and a missing or mistyped field raises
    # KeyError/IndexError/TypeError. All become a descriptive 500 instead of
    # an unhandled exception.
    try:
        data = resp.json()
        answer = data["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected response format from HF router: {e}",
        ) from e
    return {"answer": answer}
# ---------------------------------------------------------------------
# RAGBench evaluation endpoints
# ---------------------------------------------------------------------
@app.post("/run_domain")
def run_domain(req: RunRequest):
    """Run a full RAGBench evaluation over one domain (biomedical, finance, ...).

    Args:
        req: Domain name plus ``k``, ``max_examples`` and ``split`` settings,
            forwarded to ``RagBenchExperiment``.

    Returns:
        Whatever ``RagBenchExperiment.run_domain`` returns, serialised by
        FastAPI.

    Raises:
        HTTPException: 400 when ``req.domain`` is not a key of
            ``DOMAIN_TO_SUBSETS``.
    """
    # Same check as the chat endpoint. The eval page sends the same domain
    # names, and an early 400 gives the UI a readable `detail` message
    # instead of whatever error the experiment would raise.
    if req.domain not in DOMAIN_TO_SUBSETS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown domain '{req.domain}'. "
            f"Valid domains: {', '.join(DOMAIN_TO_SUBSETS.keys())}",
        )
    experiment = RagBenchExperiment(
        k=req.k,
        max_examples=req.max_examples,
        split=req.split,
    )
    return experiment.run_domain(req.domain)
@app.post("/qa_example")
def qa_example(req: QAExampleRequest):
    """Run RAG (retriever + generator + judge) on a single RAGBench example.

    Loads ``req.split`` of subset ``req.subset``, takes row ``req.index``,
    keeps the ``req.k`` documents ranked highest by ``ExampleRetriever``,
    generates an answer from them, and has ``RAGJudge`` annotate that answer.

    Args:
        req: Subset, row index, ``k`` and split to use.

    Returns:
        A dict containing the question, the generated answer, the retrieved
        documents, the judge attributes, the predicted trace metrics and the
        row's ground-truth scores. When ``req.index`` is out of range, the
        return value is instead ``{"error": ...}`` (with a normal 200 status).
    """
    ds = load_dataset(RAGBENCH_DATASET, req.subset, split=req.split)
    if req.index < 0 or req.index >= len(ds):
        return {"error": f"index {req.index} out of range (0..{len(ds) - 1})"}
    row = ds[req.index]

    # Per-document list of (sentence_key, sentence_text) pairs.
    docs_sentences_full: List[List[Tuple[str, str]]] = [
        [(key, sentence) for key, sentence in doc]
        for doc in row["documents_sentences"]
    ]
    question = row["question"]

    # NOTE(review): the retriever, generator and judge are constructed on
    # every request; consider caching them if construction is expensive.
    # 1) Retrieve top-k docs
    retriever = ExampleRetriever()
    doc_indices = retriever.rank_docs(question, docs_sentences_full, k=req.k)
    selected_docs = [docs_sentences_full[j] for j in doc_indices]

    # 2) Generate answer from retrieved docs
    generator = RAGGenerator()
    answer = generator.generate(question, selected_docs)

    # 3) Judge + metrics
    judge = RAGJudge()
    attrs = judge.annotate(question, answer, selected_docs)
    pred_metrics = trace_from_attributes(attrs, selected_docs)

    # selected_docs was built from doc_indices, so the two line up one-to-one.
    docs_view = [
        {
            "doc_index": doc_index,
            "sentences": [{"key": key, "text": text} for key, text in doc],
        }
        for doc_index, doc in zip(doc_indices, selected_docs)
    ]
    return {
        "subset": req.subset,
        "index": req.index,
        "question": question,
        "answer": answer,
        "retrieved_docs": docs_view,
        "judge_attributes": attrs,
        "predicted_trace_metrics": pred_metrics,
        "ground_truth": {
            "relevance_score": row.get("relevance_score"),
            "utilization_score": row.get("utilization_score"),
            "completeness_score": row.get("completeness_score"),
            "adherence_score": row.get("adherence_score"),
        },
    }
@app.get("/health")
def health():
    """Return a constant ``{"status": "ok"}`` payload without touching any backend."""
    return {"status": "ok"}
# ---------------------------------------------------------------------
# HTML Chat UI at root "/"
# ---------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat client.

    The page POSTs ``{"domain", "question"}`` as JSON to ``/chat`` and shows
    the returned ``answer``. On failure, the status line shows the HTTP
    status plus the server's ``detail`` message when the body carries one.
    """
    html = """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>RAGBench Chat</title>
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <style>
    * { box-sizing: border-box; }
    body {
      margin: 0;
      font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI",
        sans-serif;
      background: #f3f4f6;
      color: #111827;
    }
    .app {
      max-width: 960px;
      margin: 0 auto;
      height: 100vh;
      display: flex;
      flex-direction: column;
      padding: 0.75rem;
    }
    header {
      display: flex;
      align-items: center;
      justify-content: space-between;
      margin-bottom: 0.75rem;
    }
    header h1 {
      font-size: 1.1rem;
      margin: 0;
    }
    header small {
      color: #6b7280;
      font-size: 0.75rem;
    }
    select {
      padding: 0.3rem 0.5rem;
      border-radius: 999px;
      border: 1px solid #d1d5db;
      background: white;
      font-size: 0.85rem;
    }
    .chat-window {
      flex: 1;
      background: white;
      border-radius: 0.75rem;
      padding: 0.75rem;
      overflow-y: auto;
      border: 1px solid #e5e7eb;
    }
    .message-row {
      display: flex;
      margin-bottom: 0.5rem;
    }
    .message-row.user {
      justify-content: flex-end;
    }
    .bubble {
      max-width: 70%;
      padding: 0.5rem 0.75rem;
      border-radius: 0.75rem;
      font-size: 0.9rem;
      line-height: 1.35;
      white-space: pre-wrap;
      word-wrap: break-word;
    }
    .bubble.user {
      background: #2563eb;
      color: white;
      border-bottom-right-radius: 0.15rem;
    }
    .bubble.assistant {
      background: #e5e7eb;
      color: #111827;
      border-bottom-left-radius: 0.15rem;
    }
    .status {
      margin-top: 0.25rem;
      font-size: 0.8rem;
      color: #6b7280;
      min-height: 1.1rem;
    }
    form {
      margin-top: 0.5rem;
      display: flex;
      gap: 0.5rem;
    }
    textarea {
      flex: 1;
      resize: none;
      min-height: 48px;
      max-height: 96px;
      padding: 0.5rem 0.6rem;
      border-radius: 999px;
      border: 1px solid #d1d5db;
      font-size: 0.9rem;
    }
    button {
      border: none;
      border-radius: 999px;
      padding: 0 1.2rem;
      background: #2563eb;
      color: white;
      font-size: 0.9rem;
      font-weight: 500;
      cursor: pointer;
    }
    button:disabled {
      opacity: 0.6;
      cursor: default;
    }
    @media (max-width: 640px) {
      .app { padding: 0.5rem; }
      .chat-window { padding: 0.5rem; }
      .bubble { max-width: 82%; }
    }
  </style>
</head>
<body>
  <div class="app">
    <header>
      <div>
        <h1>RAGBench Chat</h1>
        <small>Select a domain, then start chatting.</small>
      </div>
      <div>
        <select id="domain">
          <option value="biomedical">Biomedical</option>
          <option value="general_knowledge">General knowledge</option>
          <option value="customer_support">Customer support</option>
          <option value="finance">Finance</option>
          <option value="legal">Legal</option>
        </select>
      </div>
    </header>
    <div id="chat" class="chat-window"></div>
    <div id="status" class="status"></div>
    <form id="chat-form">
      <textarea id="question" placeholder="Ask a question..."></textarea>
      <button type="submit">Send</button>
    </form>
  </div>
  <script>
    const form = document.getElementById("chat-form");
    const questionEl = document.getElementById("question");
    const domainEl = document.getElementById("domain");
    const chatEl = document.getElementById("chat");
    const statusEl = document.getElementById("status");
    const sendBtn = form.querySelector("button");

    function addMessage(role, text) {
      const row = document.createElement("div");
      row.className = "message-row " + role;
      const bubble = document.createElement("div");
      bubble.className = "bubble " + role;
      bubble.textContent = text;
      row.appendChild(bubble);
      chatEl.appendChild(row);
      chatEl.scrollTop = chatEl.scrollHeight;
    }

    // Build "HTTP <status> - <detail>" for a failed response. `detail` is a
    // string for the backend's HTTPException errors, but may be a list of
    // objects (e.g. request-validation errors), so non-strings are JSON-encoded.
    async function describeHttpError(resp) {
      let msg = "HTTP " + resp.status;
      try {
        const errData = await resp.json();
        if (errData && errData.detail) {
          msg += " - " + (typeof errData.detail === "string"
            ? errData.detail
            : JSON.stringify(errData.detail));
        }
      } catch (_) {
        // Body was not JSON; the status code alone will have to do.
      }
      return msg;
    }

    form.addEventListener("submit", async (e) => {
      e.preventDefault();
      const question = questionEl.value.trim();
      if (!question) return;
      const domain = domainEl.value;
      addMessage("user", question);
      questionEl.value = "";
      statusEl.textContent = "Thinking...";
      sendBtn.disabled = true;
      try {
        const resp = await fetch("/chat", {
          method: "POST",
          headers: { "Content-Type": "application/json" },
          body: JSON.stringify({ domain, question }),
        });
        if (!resp.ok) {
          throw new Error(await describeHttpError(resp));
        }
        const data = await resp.json();
        addMessage("assistant", data.answer || "[No answer returned]");
        statusEl.textContent = "";
      } catch (err) {
        // Use err.message: String(err) would print "Error: Error: ...".
        statusEl.textContent = "Error: " + (err && err.message ? err.message : err);
      } finally {
        sendBtn.disabled = false;
      }
    });
  </script>
</body>
</html>
"""
    return HTMLResponse(content=html)
# ---------------------------------------------------------------------
# HTML RAGBench Evaluation UI at "/eval"
# ---------------------------------------------------------------------
@app.get("/eval", response_class=HTMLResponse)
def eval_ui():
    """Serve a simple page for running ``/run_domain`` evaluations from the browser.

    The form POSTs ``{domain, k, split[, max_examples]}`` as JSON to
    ``/run_domain`` and pretty-prints the JSON result. On failure, the status
    line shows the HTTP status plus the server's ``detail`` message when the
    body carries one.
    """
    html = """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8" />
  <title>RAGBench RAG Evaluation</title>
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <style>
    * { box-sizing: border-box; }
    body {
      margin: 0;
      font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
      background: #f3f4f6;
      color: #111827;
    }
    .app {
      max-width: 960px;
      margin: 0 auto;
      padding: 24px 16px 40px;
    }
    header {
      display: flex;
      justify-content: space-between;
      align-items: baseline;
      margin-bottom: 20px;
    }
    header h1 {
      font-size: 1.6rem;
      margin: 0;
    }
    header nav a {
      font-size: 0.9rem;
      color: #2563eb;
      text-decoration: none;
      margin-left: 12px;
    }
    header nav a:hover {
      text-decoration: underline;
    }
    .card {
      background: #ffffff;
      border-radius: 12px;
      box-shadow: 0 10px 25px rgba(15, 23, 42, 0.08);
      padding: 20px 20px 24px;
      margin-bottom: 24px;
    }
    .card-title {
      font-size: 1.1rem;
      font-weight: 600;
      margin-bottom: 12px;
    }
    .form-row {
      display: grid;
      grid-template-columns: repeat(4, minmax(0, 1fr));
      gap: 12px;
      margin-bottom: 12px;
    }
    label {
      display: block;
      font-size: 0.8rem;
      font-weight: 500;
      color: #4b5563;
      margin-bottom: 4px;
    }
    select, input {
      width: 100%;
      padding: 6px 8px;
      border-radius: 8px;
      border: 1px solid #d1d5db;
      font-size: 0.9rem;
      outline: none;
    }
    select:focus, input:focus {
      border-color: #2563eb;
      box-shadow: 0 0 0 1px #2563eb33;
    }
    button {
      padding: 8px 16px;
      border-radius: 999px;
      border: none;
      background: #2563eb;
      color: #ffffff;
      font-weight: 500;
      font-size: 0.95rem;
      cursor: pointer;
    }
    button:disabled {
      opacity: 0.6;
      cursor: default;
    }
    #status {
      font-size: 0.85rem;
      color: #6b7280;
      margin-left: 12px;
    }
    pre {
      background: #111827;
      color: #e5e7eb;
      padding: 16px;
      border-radius: 10px;
      font-size: 0.8rem;
      overflow: auto;
      max-height: 60vh;
    }
    @media (max-width: 768px) {
      .form-row {
        grid-template-columns: repeat(2, minmax(0, 1fr));
      }
    }
    @media (max-width: 480px) {
      .form-row {
        grid-template-columns: 1fr;
      }
    }
  </style>
</head>
<body>
  <div class="app">
    <header>
      <h1>RAGBench RAG Evaluation</h1>
      <nav>
        <a href="/">Chat</a>
        <a href="/eval"><strong>Evaluation</strong></a>
      </nav>
    </header>
    <div class="card">
      <div class="card-title">Run Domain Evaluation</div>
      <form id="eval-form">
        <div class="form-row">
          <div>
            <label for="domain">Domain</label>
            <select id="domain" name="domain">
              <option value="biomedical">Biomedical</option>
              <option value="general_knowledge">General Knowledge</option>
              <option value="legal">Legal</option>
              <option value="customer_support">Customer Support</option>
              <option value="finance">Finance</option>
            </select>
          </div>
          <div>
            <label for="k">Top-k documents</label>
            <input id="k" name="k" type="number" min="1" value="3" />
          </div>
          <div>
            <label for="max_examples">Max examples</label>
            <input id="max_examples" name="max_examples" type="number" min="1" value="5" />
          </div>
          <div>
            <label for="split">Dataset split</label>
            <select id="split" name="split">
              <option value="test">test</option>
              <option value="train">train</option>
              <option value="validation">validation</option>
            </select>
          </div>
        </div>
        <button type="submit">Run Domain Evaluation</button>
        <span id="status"></span>
      </form>
    </div>
    <div class="card">
      <div class="card-title">Results</div>
      <pre id="output">{}</pre>
    </div>
  </div>
  <script>
    const form = document.getElementById("eval-form");
    const statusEl = document.getElementById("status");
    const outputEl = document.getElementById("output");
    const runBtn = form.querySelector("button");

    // Build "HTTP <status> - <detail>" for a failed response. `detail` is a
    // string for the backend's HTTPException errors, but may be a list of
    // objects (e.g. request-validation errors), so non-strings are JSON-encoded
    // rather than rendered as "[object Object]".
    async function describeHttpError(resp) {
      let msg = "HTTP " + resp.status;
      try {
        const errData = await resp.json();
        if (errData && errData.detail) {
          msg += " - " + (typeof errData.detail === "string"
            ? errData.detail
            : JSON.stringify(errData.detail));
        }
      } catch (_) {
        // Body was not JSON; the status code alone will have to do.
      }
      return msg;
    }

    form.addEventListener("submit", async (ev) => {
      ev.preventDefault();
      statusEl.textContent = "Running evaluation...";
      outputEl.textContent = "{}";
      runBtn.disabled = true;

      const domain = document.getElementById("domain").value;
      const k = parseInt(document.getElementById("k").value || "3", 10);
      const maxExamplesRaw = document.getElementById("max_examples").value;
      const split = document.getElementById("split").value;

      const payload = {
        domain: domain,
        k: k,
        split: split,
      };
      // An empty field omits max_examples so the server-side default applies.
      if (maxExamplesRaw !== "") {
        payload.max_examples = parseInt(maxExamplesRaw, 10);
      }

      try {
        const resp = await fetch("/run_domain", {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify(payload),
        });
        if (!resp.ok) {
          throw new Error(await describeHttpError(resp));
        }
        const data = await resp.json();
        statusEl.textContent = "Done.";
        outputEl.textContent = JSON.stringify(data, null, 2);
      } catch (err) {
        // Use err.message: String(err) would print "Error: Error: ...".
        statusEl.textContent = "Error: " + (err && err.message ? err.message : err);
        outputEl.textContent = "{}";
      } finally {
        runBtn.disabled = false;
      }
    });
  </script>
</body>
</html>
"""
    return HTMLResponse(content=html)