Renangi's picture
main.py: add separate evaluation and chat pages
2942435
# app/main.py
import os
from typing import List, Tuple, Optional
import requests
from datasets import load_dataset
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from ragbench_eval.pipeline import RagBenchExperiment
from ragbench_eval.retriever import ExampleRetriever
from ragbench_eval.generator import RAGGenerator
from ragbench_eval.judge import RAGJudge
from ragbench_eval.metrics import trace_from_attributes
from ragbench_eval.config import RAGBENCH_DATASET, DOMAIN_TO_SUBSETS
# ---------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------
# Single FastAPI application instance; every route decorator in this module
# registers its handler on this object.
app = FastAPI(
    title="RAGBench Chat + RAG Evaluation API",
)
# ---------------------------------------------------------------------
# Config for Hugging Face router (LLM chat)
# ---------------------------------------------------------------------
# Secrets pasted into env files or hosting dashboards often carry a trailing
# newline or stray spaces. Strip them so the Authorization header stays valid,
# and collapse a blank value to None so every later ``if not HF_TOKEN`` check
# agrees with the startup warning below.
HF_TOKEN: Optional[str] = (os.getenv("HF_TOKEN") or "").strip() or None

# A blank override falls back to the default instead of producing an unusable
# URL; trailing slashes are dropped so f"{HF_CHAT_BASE_URL}/chat/completions"
# never contains a double slash.
HF_CHAT_BASE_URL: str = (
    (os.getenv("HF_CHAT_BASE_URL") or "").strip().rstrip("/")
    or "https://router.huggingface.co/v1"
)

# Model identifier sent in the "model" field of the chat completion payload.
HF_CHAT_MODEL: str = (
    (os.getenv("HF_CHAT_MODEL") or "").strip()
    or "meta-llama/Meta-Llama-3-8B-Instruct"
)

if HF_TOKEN is None:
    # We don't crash the app, but /chat will raise a clear error
    print("[RAGBench Chat] WARNING: HF_TOKEN is not set. /chat will fail.")
# ---------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------
class RunRequest(BaseModel):
    """JSON body accepted by ``POST /run_domain``."""

    # Domain name, e.g. "biomedical", "general_knowledge", ...
    domain: str
    # Passed through as ``k`` to RagBenchExperiment.
    k: int = 3
    # Passed through as ``max_examples``; the schema also accepts null.
    max_examples: Optional[int] = 20
    # Dataset split name: "test" or "validation".
    split: str = "test"
class QAExampleRequest(BaseModel):
    """JSON body accepted by ``POST /qa_example``."""

    # Dataset subset (configuration) name, e.g. "covidqa", "pubmedqa".
    subset: str
    # Which row of that subset to run.
    index: int = 0
    # Passed through as ``k`` to the retriever's ``rank_docs`` call.
    k: int = 3
    # Dataset split to load the row from.
    split: str = "test"
class ChatRequest(BaseModel):
    """JSON body accepted by ``POST /chat``."""

    # Must be one of the DOMAIN_TO_SUBSETS keys.
    domain: str
    # The user's message, sent to the model as the "user" turn.
    question: str
# ---------------------------------------------------------------------
# LLM Chat endpoint (using HF router OpenAI-compatible API)
# ---------------------------------------------------------------------
@app.post("/chat")
def chat(req: ChatRequest):
    """Answer a question with the configured chat model, primed for a domain.

    Uses the Hugging Face router OpenAI-compatible Chat Completions API:
        POST {HF_CHAT_BASE_URL}/chat/completions

    Args:
        req: Chat request carrying the domain and the user's question.

    Returns:
        ``{"answer": <message content of the first returned choice>}``.

    Raises:
        HTTPException: 500 when ``HF_TOKEN`` is unset or blank, when the
            router request fails, or when the response body is not JSON of
            the expected shape; 400 when the domain is unknown or the
            question is blank.
    """
    # Treat a whitespace-only token as missing, matching the startup warning,
    # and never send padding or newlines inside the Authorization header.
    token = (HF_TOKEN or "").strip()
    if not token:
        raise HTTPException(
            status_code=500,
            detail="HF_TOKEN environment variable is not set in the backend.",
        )

    if req.domain not in DOMAIN_TO_SUBSETS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown domain '{req.domain}'. "
            f"Valid domains: {', '.join(DOMAIN_TO_SUBSETS)}",
        )

    # Reject blank questions up front rather than spending an upstream call.
    if not req.question.strip():
        raise HTTPException(
            status_code=400,
            detail="Question must not be empty.",
        )

    system_prompt = (
        "You are an assistant answering questions in the domain: "
        f"{req.domain}. "
        "Answer using correct, verifiable information. "
        "If you are not sure, clearly say that you are not sure instead of "
        "guessing. Be concise and avoid fabricating facts."
    )
    payload = {
        "model": HF_CHAT_MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": req.question},
        ],
        "temperature": 0.2,
        "max_tokens": 512,
    }

    try:
        resp = requests.post(
            # Tolerate a configured base URL that ends with "/".
            f"{HF_CHAT_BASE_URL.rstrip('/')}/chat/completions",
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=60,
        )
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts and non-2xx statuses; the cause
        # is passed to the caller through the response detail.
        raise HTTPException(
            status_code=500,
            detail=f"Hugging Face router request failed: {e}",
        ) from e

    try:
        # ValueError: body is not valid JSON. KeyError / IndexError /
        # TypeError: JSON parsed but lacks choices[0].message.content.
        answer = resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected response format from HF router: {e}",
        ) from e

    return {"answer": answer}
# ---------------------------------------------------------------------
# RAGBench evaluation endpoints
# ---------------------------------------------------------------------
@app.post("/run_domain")
def run_domain(req: RunRequest):
    """Run a full RAGBench evaluation over a domain (biomedical, finance, etc.).

    Args:
        req: Evaluation settings; ``k``, ``max_examples`` and ``split`` are
            handed to ``RagBenchExperiment`` unchanged.

    Returns:
        Whatever ``RagBenchExperiment.run_domain`` returns for the domain.

    Raises:
        HTTPException: 400 when ``req.domain`` is not a key of
            ``DOMAIN_TO_SUBSETS``.
    """
    # Same validation and message as /chat, so a mistyped domain yields a
    # readable 400 instead of an opaque failure inside the pipeline.
    # NOTE(review): assumes the pipeline accepts exactly the
    # DOMAIN_TO_SUBSETS keys - confirm against RagBenchExperiment.run_domain.
    if req.domain not in DOMAIN_TO_SUBSETS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown domain '{req.domain}'. "
            f"Valid domains: {', '.join(DOMAIN_TO_SUBSETS)}",
        )

    exp = RagBenchExperiment(
        k=req.k,
        max_examples=req.max_examples,
        split=req.split,
    )
    return exp.run_domain(req.domain)
@app.post("/qa_example")
def qa_example(req: QAExampleRequest):
    """Run RAG (retriever + generator + judge) on a single RAGBench example.

    Args:
        req: Selects the subset, split and row index, plus the ``k`` passed
            to the retriever.

    Returns:
        A dict with the question, the generated answer, the retrieved
        documents (original document index plus keyed sentences), the judge
        attributes, the metrics derived from them, and the row's ground-truth
        scores.

    Raises:
        HTTPException: 400 when the dataset cannot be loaded for the given
            subset/split, or when ``req.index`` is outside the split.
    """
    try:
        ds = load_dataset(RAGBENCH_DATASET, req.subset, split=req.split)
    except ValueError as e:
        # NOTE(review): `datasets` is expected to raise ValueError for an
        # unknown subset (config) or split name - confirm for the pinned
        # version. Other failures still propagate unchanged.
        raise HTTPException(
            status_code=400,
            detail=(
                f"Could not load subset '{req.subset}' "
                f"(split '{req.split}'): {e}"
            ),
        ) from e

    if req.index < 0 or req.index >= len(ds):
        # Report a client error through the status code; previously this
        # returned HTTP 200 with an {"error": ...} body.
        raise HTTPException(
            status_code=400,
            detail=f"index {req.index} out of range (0..{len(ds) - 1})",
        )
    row = ds[req.index]

    # Build full per-document sentence lists: each document is a list of
    # (sentence_key, sentence_text) tuples.
    docs_sentences_full: List[List[Tuple[str, str]]] = [
        [(key, text) for key, text in doc]
        for doc in row["documents_sentences"]
    ]
    question = row["question"]

    # 1) Retrieve top-k docs; rank_docs returns indices into the full list.
    retriever = ExampleRetriever()
    doc_indices = retriever.rank_docs(question, docs_sentences_full, k=req.k)
    selected_docs = [docs_sentences_full[j] for j in doc_indices]

    # 2) Generate answer from retrieved docs
    generator = RAGGenerator()
    answer = generator.generate(question, selected_docs)

    # 3) Judge + metrics
    judge = RAGJudge()
    attrs = judge.annotate(question, answer, selected_docs)
    pred_metrics = trace_from_attributes(attrs, selected_docs)

    # selected_docs was built from doc_indices, so the two are parallel.
    docs_view = [
        {
            "doc_index": doc_index,
            "sentences": [{"key": key, "text": text} for key, text in doc],
        }
        for doc_index, doc in zip(doc_indices, selected_docs)
    ]

    return {
        "subset": req.subset,
        "index": req.index,
        "question": question,
        "answer": answer,
        "retrieved_docs": docs_view,
        "judge_attributes": attrs,
        "predicted_trace_metrics": pred_metrics,
        "ground_truth": {
            "relevance_score": row.get("relevance_score"),
            "utilization_score": row.get("utilization_score"),
            "completeness_score": row.get("completeness_score"),
            "adherence_score": row.get("adherence_score"),
        },
    }
@app.get("/health")
def health():
    """Health-check endpoint; always returns a fixed ``{"status": "ok"}`` body."""
    status_payload = {"status": "ok"}
    return status_payload
# ---------------------------------------------------------------------
# HTML Chat UI at root "/"
# ---------------------------------------------------------------------
# Static markup for the chat page. The embedded script posts
# {domain, question} to /chat and appends the returned answer to the window.
# On a failed request it shows the backend's `detail` message next to the
# HTTP status, rather than the status alone.
_CHAT_PAGE_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>RAGBench Chat</title>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style>
* { box-sizing: border-box; }
body {
margin: 0;
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI",
sans-serif;
background: #f3f4f6;
color: #111827;
}
.app {
max-width: 960px;
margin: 0 auto;
height: 100vh;
display: flex;
flex-direction: column;
padding: 0.75rem;
}
header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 0.75rem;
}
header h1 {
font-size: 1.1rem;
margin: 0;
}
header small {
color: #6b7280;
font-size: 0.75rem;
}
select {
padding: 0.3rem 0.5rem;
border-radius: 999px;
border: 1px solid #d1d5db;
background: white;
font-size: 0.85rem;
}
.chat-window {
flex: 1;
background: white;
border-radius: 0.75rem;
padding: 0.75rem;
overflow-y: auto;
border: 1px solid #e5e7eb;
}
.message-row {
display: flex;
margin-bottom: 0.5rem;
}
.message-row.user {
justify-content: flex-end;
}
.bubble {
max-width: 70%;
padding: 0.5rem 0.75rem;
border-radius: 0.75rem;
font-size: 0.9rem;
line-height: 1.35;
white-space: pre-wrap;
word-wrap: break-word;
}
.bubble.user {
background: #2563eb;
color: white;
border-bottom-right-radius: 0.15rem;
}
.bubble.assistant {
background: #e5e7eb;
color: #111827;
border-bottom-left-radius: 0.15rem;
}
.status {
margin-top: 0.25rem;
font-size: 0.8rem;
color: #6b7280;
min-height: 1.1rem;
}
form {
margin-top: 0.5rem;
display: flex;
gap: 0.5rem;
}
textarea {
flex: 1;
resize: none;
min-height: 48px;
max-height: 96px;
padding: 0.5rem 0.6rem;
border-radius: 999px;
border: 1px solid #d1d5db;
font-size: 0.9rem;
}
button {
border: none;
border-radius: 999px;
padding: 0 1.2rem;
background: #2563eb;
color: white;
font-size: 0.9rem;
font-weight: 500;
cursor: pointer;
}
button:disabled {
opacity: 0.6;
cursor: default;
}
@media (max-width: 640px) {
.app { padding: 0.5rem; }
.chat-window { padding: 0.5rem; }
.bubble { max-width: 82%; }
}
</style>
</head>
<body>
<div class="app">
<header>
<div>
<h1>RAGBench Chat</h1>
<small>Select a domain, then start chatting.</small>
</div>
<div>
<select id="domain">
<option value="biomedical">Biomedical</option>
<option value="general_knowledge">General knowledge</option>
<option value="customer_support">Customer support</option>
<option value="finance">Finance</option>
<option value="legal">Legal</option>
</select>
</div>
</header>
<div id="chat" class="chat-window"></div>
<div id="status" class="status"></div>
<form id="chat-form">
<textarea id="question" placeholder="Ask a question..."></textarea>
<button type="submit">Send</button>
</form>
</div>
<script>
const form = document.getElementById("chat-form");
const questionEl = document.getElementById("question");
const domainEl = document.getElementById("domain");
const chatEl = document.getElementById("chat");
const statusEl = document.getElementById("status");
function addMessage(role, text) {
const row = document.createElement("div");
row.className = "message-row " + role;
const bubble = document.createElement("div");
bubble.className = "bubble " + role;
bubble.textContent = text;
row.appendChild(bubble);
chatEl.appendChild(row);
chatEl.scrollTop = chatEl.scrollHeight;
}
form.addEventListener("submit", async (e) => {
e.preventDefault();
const question = questionEl.value.trim();
if (!question) return;
const domain = domainEl.value;
addMessage("user", question);
questionEl.value = "";
statusEl.textContent = "Thinking...";
form.querySelector("button").disabled = true;
try {
const resp = await fetch("/chat", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ domain, question }),
});
if (!resp.ok) {
let msg = "HTTP " + resp.status;
try {
const errData = await resp.json();
if (errData && errData.detail) {
msg += " - " + (typeof errData.detail === "string"
? errData.detail
: JSON.stringify(errData.detail));
}
} catch (_) {}
throw new Error(msg);
}
const data = await resp.json();
addMessage("assistant", data.answer || "[No answer returned]");
statusEl.textContent = "";
} catch (err) {
statusEl.textContent = "Error: " + (err && err.message ? err.message : err);
} finally {
form.querySelector("button").disabled = false;
}
});
</script>
</body>
</html>
"""


@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI at the site root.

    Returns:
        An ``HTMLResponse`` whose body is the static chat page markup.
    """
    return HTMLResponse(content=_CHAT_PAGE_HTML)
# ---------------------------------------------------------------------
# HTML RAGBench Evaluation UI at "/eval"
# ---------------------------------------------------------------------
# Static markup for the evaluation page. Its script collects domain, k,
# max_examples and split from the form, posts them to /run_domain, and prints
# the JSON result into the <pre id="output"> element.
_EVAL_PAGE_HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<title>RAGBench RAG Evaluation</title>
<meta name="viewport" content="width=device-width, initial-scale=1" />
<style>
* { box-sizing: border-box; }
body {
margin: 0;
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
background: #f3f4f6;
color: #111827;
}
.app {
max-width: 960px;
margin: 0 auto;
padding: 24px 16px 40px;
}
header {
display: flex;
justify-content: space-between;
align-items: baseline;
margin-bottom: 20px;
}
header h1 {
font-size: 1.6rem;
margin: 0;
}
header nav a {
font-size: 0.9rem;
color: #2563eb;
text-decoration: none;
margin-left: 12px;
}
header nav a:hover {
text-decoration: underline;
}
.card {
background: #ffffff;
border-radius: 12px;
box-shadow: 0 10px 25px rgba(15, 23, 42, 0.08);
padding: 20px 20px 24px;
margin-bottom: 24px;
}
.card-title {
font-size: 1.1rem;
font-weight: 600;
margin-bottom: 12px;
}
.form-row {
display: grid;
grid-template-columns: repeat(4, minmax(0, 1fr));
gap: 12px;
margin-bottom: 12px;
}
label {
display: block;
font-size: 0.8rem;
font-weight: 500;
color: #4b5563;
margin-bottom: 4px;
}
select, input {
width: 100%;
padding: 6px 8px;
border-radius: 8px;
border: 1px solid #d1d5db;
font-size: 0.9rem;
outline: none;
}
select:focus, input:focus {
border-color: #2563eb;
box-shadow: 0 0 0 1px #2563eb33;
}
button {
padding: 8px 16px;
border-radius: 999px;
border: none;
background: #2563eb;
color: #ffffff;
font-weight: 500;
font-size: 0.95rem;
cursor: pointer;
}
button:disabled {
opacity: 0.6;
cursor: default;
}
#status {
font-size: 0.85rem;
color: #6b7280;
margin-left: 12px;
}
pre {
background: #111827;
color: #e5e7eb;
padding: 16px;
border-radius: 10px;
font-size: 0.8rem;
overflow: auto;
max-height: 60vh;
}
@media (max-width: 768px) {
.form-row {
grid-template-columns: repeat(2, minmax(0, 1fr));
}
}
@media (max-width: 480px) {
.form-row {
grid-template-columns: 1fr;
}
}
</style>
</head>
<body>
<div class="app">
<header>
<h1>RAGBench RAG Evaluation</h1>
<nav>
<a href="/">Chat</a>
<a href="/eval"><strong>Evaluation</strong></a>
</nav>
</header>
<div class="card">
<div class="card-title">Run Domain Evaluation</div>
<form id="eval-form">
<div class="form-row">
<div>
<label for="domain">Domain</label>
<select id="domain" name="domain">
<option value="biomedical">Biomedical</option>
<option value="general_knowledge">General Knowledge</option>
<option value="legal">Legal</option>
<option value="customer_support">Customer Support</option>
<option value="finance">Finance</option>
</select>
</div>
<div>
<label for="k">Top-k documents</label>
<input id="k" name="k" type="number" min="1" value="3" />
</div>
<div>
<label for="max_examples">Max examples</label>
<input id="max_examples" name="max_examples" type="number" min="1" value="5" />
</div>
<div>
<label for="split">Dataset split</label>
<select id="split" style="padding:10px;border-radius:8px;border:1px solid #d0d7de;">
<option value="test">test</option>
<option value="train">train</option>
<option value="validation">validation</option>
</select>
</div>
</div>
<button type="submit">Run Domain Evaluation</button>
<span id="status"></span>
</form>
</div>
<div class="card">
<div class="card-title">Results</div>
<pre id="output">{}</pre>
</div>
</div>
<script>
const form = document.getElementById("eval-form");
const statusEl = document.getElementById("status");
const outputEl = document.getElementById("output");
form.addEventListener("submit", async (ev) => {
ev.preventDefault();
statusEl.textContent = "Running evaluation...";
outputEl.textContent = "{}";
form.querySelector("button").disabled = true;
const domain = document.getElementById("domain").value;
const k = parseInt(document.getElementById("k").value || "3", 10);
const maxExamplesRaw = document.getElementById("max_examples").value;
const split = document.getElementById("split").value;
const payload = {
domain: domain,
k: k,
split: split,
};
if (maxExamplesRaw !== "") {
payload.max_examples = parseInt(maxExamplesRaw, 10);
}
try {
const resp = await fetch("/run_domain", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(payload),
});
if (!resp.ok) {
let msg = "HTTP " + resp.status;
try {
const errData = await resp.json();
if (errData.detail) {
msg += " - " + errData.detail;
}
} catch (_) {}
throw new Error(msg);
}
const data = await resp.json();
statusEl.textContent = "Done.";
outputEl.textContent = JSON.stringify(data, null, 2);
} catch (err) {
statusEl.textContent = "Error: " + err;
outputEl.textContent = "{}";
} finally {
form.querySelector("button").disabled = false;
}
});
</script>
</body>
</html>
"""


@app.get("/eval", response_class=HTMLResponse)
def eval_ui():
    """Serve the browser page used to launch ``/run_domain`` evaluations."""
    return HTMLResponse(content=_EVAL_PAGE_HTML)