imurra's picture
You need to use YOUR ORIGINAL app.py that works with ChromaDB!
9f98759 verified
raw
history blame
5.9 kB
import os
import zipfile
from typing import Optional

# Disable ChromaDB's anonymized telemetry; set before chromadb is imported so
# the setting is already in the environment when the library loads.
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import chromadb
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
# Extract and load database
DB_PATH = "./medqa_db"
# Archive shipped next to the app; unpacked on first start when DB_PATH is absent.
DB_ZIP_PATH = "./medqa_db.zip"
if not os.path.exists(DB_PATH) and os.path.exists(DB_ZIP_PATH):
    print("πŸ“¦ Extracting database...")
    # Extracted into the working directory.
    # NOTE(review): assumes the archive contains a top-level "medqa_db/" folder.
    with zipfile.ZipFile(DB_ZIP_PATH, 'r') as z:
        z.extractall(".")
    print("βœ… Database extracted")
# Fail fast with an actionable message. Otherwise PersistentClient would
# presumably open a fresh, empty store at the missing path, leaving a
# directory behind that makes later starts skip extraction, and
# get_collection() would then fail with a less helpful error.
if not os.path.isdir(DB_PATH):
    raise FileNotFoundError(
        f"ChromaDB directory {DB_PATH!r} not found; expected it to exist or to be "
        f"extracted from {DB_ZIP_PATH!r}."
    )
print("πŸ”Œ Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"βœ… Loaded {collection.count()} questions")
MEDCPT_QUERY_ENCODER = 'ncbi/MedCPT-Query-Encoder'
print("🧠 Loading MedCPT model...")
# Encodes user queries in search(); the resulting vectors are compared with the
# embeddings already stored in the "medqa" collection.
# NOTE(review): confirm the stored embeddings were produced with this same
# SentenceTransformer wrapper (and pooling) so both sides share one vector space.
model = SentenceTransformer(MEDCPT_QUERY_ENCODER)
print("βœ… Model ready")
# Search function
def search(query, num_results=3, source_filter=None):
    """Embed *query* with the MedCPT encoder and fetch its nearest stored questions.

    Args:
        query: Free-text query to embed.
        num_results: Number of neighbours to request; coerced with ``int()``.
        source_filter: Keep only records whose ``source`` metadata equals this
            value. Any falsy value (``None``, ``""``) or ``"all"`` disables
            filtering.

    Returns:
        The raw ``collection.query`` result for a single query embedding.
    """
    query_embedding = model.encode(query).tolist()
    wants_filter = bool(source_filter) and source_filter != "all"
    return collection.query(
        query_embeddings=[query_embedding],
        n_results=int(num_results),
        where={"source": source_filter} if wants_filter else None,
    )
# Enhanced Gradio UI
def _describe_source(source):
    """Map a ``source`` metadata value to the ``(icon, display name)`` shown in results.

    ``'medgemini'`` becomes Med-Gemini, ``'medqa_<split>'`` becomes
    ``MedQA <SPLIT>``, and any other value is shown upper-cased with a generic icon.
    """
    if source == 'medgemini':
        return "πŸ”¬", "Med-Gemini"
    if source.startswith('medqa_'):
        split = source.replace('medqa_', '').upper()
        return "πŸ“š", f"MedQA {split}"
    return "πŸ“„", source.upper()


def ui_search(query, num_results=3, source_filter="all"):
    """Run a semantic search and render the hits as plain text for the UI.

    Args:
        query: Free-text medical query from the textbox.
        num_results: Number of hits to request (passed through to ``search``).
        source_filter: ``"all"`` for no filtering, otherwise a ``source``
            metadata value.

    Returns:
        A formatted string: a prompt when the query is blank, a "no results"
        notice, the rendered hits, or an error message. Exceptions raised while
        searching or formatting are reported as text, not propagated to Gradio.
    """
    # Guard None as well as blank/whitespace-only text.
    if not query or not query.strip():
        return "πŸ’‘ Enter a medical query to search"
    try:
        r = search(query, num_results, source_filter if source_filter != "all" else None)
        documents = r['documents'][0]
        if not documents:
            return "❌ No results found"
        rule = '=' * 70
        parts = [f"πŸ” Found {len(documents)} results\n\n"]
        hits = zip(documents, r['metadatas'][0], r['distances'][0])
        for i, (document, meta, distance) in enumerate(hits, start=1):
            # Records stored without metadata/document text may come back as
            # None; fall back to empty values so one record can't break the page.
            meta = meta or {}
            source_icon, source_name = _describe_source(meta.get('source', 'unknown'))
            # NOTE(review): 1 - distance is a cosine similarity only if the
            # collection uses the cosine distance space; confirm its config.
            similarity = 1 - distance
            parts.append(f"\n{rule}\n")
            parts.append(f"{source_icon} Result {i} | {source_name} | Similarity: {similarity:.3f}\n")
            parts.append(f"{rule}\n\n")
            parts.append(document or "")
            # Show answer
            answer = meta.get('answer', 'N/A')
            parts.append(f"\n\nβœ… CORRECT ANSWER: {answer}\n")
            # Show explanation if available (Med-Gemini). Non-string values
            # (e.g. a numeric placeholder) are skipped instead of crashing .strip().
            explanation = meta.get('explanation', '')
            if isinstance(explanation, str) and explanation.strip():
                parts.append(f"\nπŸ’‘ EXPLANATION:\n{explanation}\n")
            parts.append("\n")
        return "".join(parts)
    except Exception as e:
        return f"❌ Error: {e}"
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("""
# πŸ₯ MedQA Semantic Search
Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
Uses medical-specific embeddings (MedCPT) for accurate retrieval.
""")
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Medical Query",
                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
                lines=2,
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of Results",
            )
    with gr.Row():
        # Choices are matched against the ``source`` metadata in the collection;
        # "all" is translated to "no filter" by ui_search.
        source_filter = gr.Radio(
            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
            value="all",
            label="Filter by Source",
        )
    search_btn = gr.Button("πŸ” Search", variant="primary", size="lg")
    output = gr.Textbox(
        label="Search Results",
        lines=25,
        max_lines=50,
    )
    # Clicking the button and pressing Enter in the query box run the same search.
    _search_inputs = [query_input, num_results, source_filter]
    search_btn.click(fn=ui_search, inputs=_search_inputs, outputs=output)
    query_input.submit(fn=ui_search, inputs=_search_inputs, outputs=output)
    # Each entry is a list item: plain consecutive lines are soft breaks in
    # Markdown and would otherwise render run together as one paragraph.
    gr.Markdown(f"""
### πŸ“Š Database Info
- **Med-Gemini**: Expert-relabeled questions with detailed explanations
- **MedQA**: USMLE-style questions (Train/Dev/Test splits)
- **Total Questions**: {collection.count():,} (from the database built with `build_combined_db.py`)
""")
    gr.Examples(
        examples=[
            ["hyponatremia", 3, "all"],
            ["myocardial infarction treatment", 2, "medgemini"],
            ["diabetes complications", 3, "all"],
            ["antibiotics for pneumonia", 2, "medqa_train"],
        ],
        inputs=_search_inputs,
    )
# FastAPI
app = FastAPI()


class SearchRequest(BaseModel):
    """JSON request body for ``POST /search_medqa``."""

    # Free-text medical query to embed and search for.
    query: str
    # Number of hits to return; values below 1 are rejected by request
    # validation (HTTP 422) instead of being forwarded to collection.query.
    num_results: int = Field(default=3, ge=1)
    # Optional ``source`` metadata value to filter on (e.g. "medgemini",
    # "medqa_train"); None or "all" means no filtering. Declared Optional so an
    # explicit JSON null is accepted, not only an omitted field.
    source_filter: Optional[str] = None
@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Run a semantic search for a JSON request and return the hits.

    ``req.source_filter`` is passed straight to ``search``, which treats ``None``
    and ``"all"`` as "no filter".

    Returns:
        ``{"results": [...]}``. Each item has:
            - a 1-based ``result_number``;
            - the stored ``question`` text;
            - ``answer`` and ``source`` from metadata, falling back to
              ``'N/A'`` / ``'unknown'``;
            - ``similarity`` computed as ``1 - distance``.
    """
    r = search(req.query, req.num_results, req.source_filter)
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    results = []
    for rank, (question, meta, distance) in enumerate(hits, start=1):
        # Records stored without metadata may come back as None; treat them as
        # empty so the defaults apply instead of failing the whole request.
        meta = meta or {}
        results.append({
            "result_number": rank,
            "question": question,
            "answer": meta.get('answer', 'N/A'),
            "source": meta.get('source', 'unknown'),
            "similarity": 1 - distance,
        })
    return {"results": results}
# Serve the Gradio UI (`demo`) from the root path of the FastAPI app; the
# combined application replaces `app` and is what uvicorn serves below.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when this file is executed
    # directly as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)