Spaces:
Sleeping
Sleeping
New code
Browse files
app.py
CHANGED
|
@@ -505,22 +505,6 @@ def compute_ragbench_metrics(judge_response: dict, retrieved_sentence_keys: list
|
|
| 505 |
"Adherence": adherence
|
| 506 |
}
|
| 507 |
|
| 508 |
-
# --- Dataset dictionary ---
|
| 509 |
-
domain_datasets = {
|
| 510 |
-
"Legal": legal_dataset,
|
| 511 |
-
"Medical": med_dataset,
|
| 512 |
-
"GK": gk_dataset,
|
| 513 |
-
"CS": cs_dataset,
|
| 514 |
-
"Finance": fin_dataset
|
| 515 |
-
}
|
| 516 |
-
|
| 517 |
-
# --- Get questions for selected domain ---
|
| 518 |
-
def get_questions_for_domain(domain):
|
| 519 |
-
dataset = domain_datasets.get(domain, [])
|
| 520 |
-
if not dataset:
|
| 521 |
-
return "β οΈ No dataset found for the selected domain."
|
| 522 |
-
|
| 523 |
-
return "\n".join([f"{i}. {item['question']}" for i, item in enumerate(dataset)])
|
| 524 |
|
| 525 |
def evaluate_rag_pipeline(domain, q_indices):
|
| 526 |
import torch
|
|
@@ -613,47 +597,39 @@ def evaluate_rag_pipeline(domain, q_indices):
|
|
| 613 |
|
| 614 |
# Updated wrapper
|
| 615 |
def evaluate_rag_gradio(domain, q_indices_str):
|
|
|
|
| 616 |
log_stream = io.StringIO()
|
| 617 |
sys.stdout = log_stream
|
| 618 |
|
| 619 |
try:
|
|
|
|
| 620 |
q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
|
| 621 |
results = evaluate_rag_pipeline(domain, q_indices)
|
|
|
|
| 622 |
logs = log_stream.getvalue()
|
| 623 |
return results, logs
|
|
|
|
| 624 |
except Exception as e:
|
| 625 |
traceback.print_exc()
|
| 626 |
return {"error": str(e)}, log_stream.getvalue()
|
| 627 |
-
finally:
|
| 628 |
-
sys.stdout = sys.__stdout__
|
| 629 |
-
|
| 630 |
-
# === Gradio UI using Blocks ===
|
| 631 |
-
with gr.Blocks(title="RAG Evaluation Dashboard") as demo:
|
| 632 |
-
gr.Markdown("## π RAG Evaluation Dashboard")
|
| 633 |
-
gr.Markdown("Evaluate your RAG pipeline and also browse the questions available for each domain.")
|
| 634 |
-
|
| 635 |
-
with gr.Row():
|
| 636 |
-
domain_input = gr.Dropdown(choices=list(domain_datasets.keys()), label="Select Domain")
|
| 637 |
-
q_index_input = gr.Textbox(label="Enter Query Indices (e.g., 89,121,245)", lines=1)
|
| 638 |
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
)
|
| 657 |
|
| 658 |
-
#
|
| 659 |
-
|
|
|
|
| 505 |
"Adherence": adherence
|
| 506 |
}
|
| 507 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
def evaluate_rag_pipeline(domain, q_indices):
|
| 510 |
import torch
|
|
|
|
| 597 |
|
| 598 |
# Updated wrapper
|
| 599 |
def evaluate_rag_gradio(domain, q_indices_str):
|
| 600 |
+
# Capture logs
|
| 601 |
log_stream = io.StringIO()
|
| 602 |
sys.stdout = log_stream
|
| 603 |
|
| 604 |
try:
|
| 605 |
+
# Parse comma-separated indices
|
| 606 |
q_indices = [int(x.strip()) for x in q_indices_str.split(",") if x.strip().isdigit()]
|
| 607 |
results = evaluate_rag_pipeline(domain, q_indices)
|
| 608 |
+
|
| 609 |
logs = log_stream.getvalue()
|
| 610 |
return results, logs
|
| 611 |
+
|
| 612 |
except Exception as e:
|
| 613 |
traceback.print_exc()
|
| 614 |
return {"error": str(e)}, log_stream.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 615 |
|
| 616 |
+
finally:
|
| 617 |
+
sys.stdout = sys.__stdout__ # Restore stdout
|
| 618 |
+
|
| 619 |
+
# Gradio interface
|
| 620 |
+
iface = gr.Interface(
|
| 621 |
+
fn=evaluate_rag_gradio,
|
| 622 |
+
inputs=[
|
| 623 |
+
gr.Dropdown(choices=["Legal", "Medical", "GK", "CS", "Finance"], label="Domain"),
|
| 624 |
+
gr.Textbox(label="Comma-separated Query Indices (e.g. 89,121,245)", lines=1),
|
| 625 |
+
],
|
| 626 |
+
outputs=[
|
| 627 |
+
gr.JSON(label="Evaluation Metrics (RMSE & AUC-ROC)"),
|
| 628 |
+
gr.Textbox(label="Execution Log", lines=10, interactive=True),
|
| 629 |
+
],
|
| 630 |
+
title="RAG Evaluation Dashboard",
|
| 631 |
+
description="Evaluate your RAG pipeline across selected queries using GPT-based generation and judgment."
|
| 632 |
+
)
|
|
|
|
| 633 |
|
| 634 |
+
# Launch app
|
| 635 |
+
iface.launch(server_name="0.0.0.0", server_port=7860, debug=True)
|