"""MedQA semantic search service.

Loads a persistent ChromaDB collection named "medqa" and a MedCPT query
encoder at import time, then exposes the same search through a Gradio UI
(mounted at "/") and a FastAPI JSON endpoint (POST /search_medqa).
"""

import os

# Set before chromadb is imported, as in the original ordering.
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import zipfile
from typing import Optional

import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field

# NOTE(review): the icon prefixes in the user-facing strings below look like
# mis-decoded UTF-8 emoji. They are kept exactly as found; confirm the file
# encoding and restore the intended characters if needed.

# Extract and load database
DB_PATH = "./medqa_db"
if not os.path.exists(DB_PATH) and os.path.exists("./medqa_db.zip"):
    print("šŸ“¦ Extracting database...")
    with zipfile.ZipFile("./medqa_db.zip", 'r') as z:
        z.extractall(".")
    print("āœ… Database extracted")

print("šŸ”Œ Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"āœ… Loaded {collection.count()} questions")

print("🧠 Loading MedCPT model...")
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("āœ… Model ready")


# Search function
def search(query, num_results=3, source_filter=None):
    """Embed *query* and return the raw ChromaDB query result.

    Args:
        query: Free-text query to encode with the MedCPT query encoder.
        num_results: Number of hits to request; coerced with ``int()``.
        source_filter: Value to match against the ``source`` metadata field.
            A falsy value or ``"all"`` disables filtering.

    Returns:
        Whatever ``collection.query`` returns for the single query embedding.
    """
    emb = model.encode(query).tolist()

    # Apply source filter if specified
    where_clause = None
    if source_filter and source_filter != "all":
        where_clause = {"source": source_filter}

    return collection.query(
        query_embeddings=[emb],
        n_results=int(num_results),
        where=where_clause,
    )


def _describe_source(source):
    """Return the ``(icon, display_name)`` pair for a ``source`` value."""
    if source == 'medgemini':
        return "šŸ”¬", "Med-Gemini"
    if source.startswith('medqa_'):
        split = source.replace('medqa_', '').upper()
        return "šŸ“š", f"MedQA {split}"
    return "šŸ“„", source.upper()


# Enhanced Gradio UI
def ui_search(query, num_results=3, source_filter="all"):
    """Run a search and format the hits as plain text for the Gradio UI.

    Args:
        query: Text from the query box; ``None`` or blank yields a hint.
        num_results: Number of hits to request.
        source_filter: Source name, or ``"all"`` for no filtering.

    Returns:
        A formatted multi-line string: the results, a hint, a "no results"
        message, or an error message if the search raised.
    """
    if not query or not query.strip():
        return "šŸ’” Enter a medical query to search"

    # UI boundary: any failure is reported to the user instead of raised.
    try:
        # search() already treats "all" as "no filter".
        r = search(query, num_results, source_filter)
        documents = r['documents'][0]
        if not documents:
            return "āŒ No results found"

        rule = '=' * 70
        parts = [f"šŸ” Found {len(documents)} results\n\n"]
        rows = zip(documents, r['metadatas'][0], r['distances'][0])
        for rank, (document, metadata, distance) in enumerate(rows, start=1):
            metadata = metadata or {}
            source = metadata.get('source', 'unknown')
            # NOTE(review): 1 - distance is only a true similarity if the
            # collection uses a bounded metric such as cosine - confirm.
            similarity = 1 - distance
            source_icon, source_name = _describe_source(source)

            parts.append(f"\n{rule}\n")
            parts.append(
                f"{source_icon} Result {rank} | {source_name} | "
                f"Similarity: {similarity:.3f}\n"
            )
            parts.append(f"{rule}\n\n")
            parts.append(document)

            # Show answer
            answer = metadata.get('answer', 'N/A')
            parts.append(f"\n\nāœ… CORRECT ANSWER: {answer}\n")

            # Show explanation if available (Med-Gemini)
            explanation = metadata.get('explanation', '')
            if explanation and explanation.strip():
                parts.append(f"\nšŸ’” EXPLANATION:\n{explanation}\n")

            parts.append("\n")

        return "".join(parts)
    except Exception as e:
        return f"āŒ Error: {e}"


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("""
    # šŸ„ MedQA Semantic Search

    Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
    Uses medical-specific embeddings (MedCPT) for accurate retrieval.
    """)

    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Medical Query",
                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
                lines=2
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of Results"
            )

    with gr.Row():
        source_filter = gr.Radio(
            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
            value="all",
            label="Filter by Source"
        )

    search_btn = gr.Button("šŸ” Search", variant="primary", size="lg")

    output = gr.Textbox(
        label="Search Results",
        lines=25,
        max_lines=50
    )

    # Both the button and pressing Enter in the query box run the same search.
    search_btn.click(
        fn=ui_search,
        inputs=[query_input, num_results, source_filter],
        outputs=output
    )

    query_input.submit(
        fn=ui_search,
        inputs=[query_input, num_results, source_filter],
        outputs=output
    )

    gr.Markdown("""
    ### šŸ“Š Database Info

    **Med-Gemini**: Expert-relabeled questions with detailed explanations

    **MedQA**: USMLE-style questions (Train/Dev/Test splits)

    **Total Questions**: Use the database you built with `build_combined_db.py`
    """)

    gr.Examples(
        examples=[
            ["hyponatremia", 3, "all"],
            ["myocardial infarction treatment", 2, "medgemini"],
            ["diabetes complications", 3, "all"],
            ["antibiotics for pneumonia", 2, "medqa_train"]
        ],
        inputs=[query_input, num_results, source_filter]
    )

# FastAPI
app = FastAPI()


class SearchRequest(BaseModel):
    """Request body for POST /search_medqa."""

    query: str
    # Non-positive counts are rejected at validation time rather than
    # failing inside the vector-store query.
    num_results: int = Field(3, ge=1)
    # Optional so that an explicit JSON null is accepted as "no filter".
    source_filter: Optional[str] = None


@app.post("/search_medqa")
def api_search(req: SearchRequest):
    """Run a search and return the hits as a JSON-serialisable dict.

    Returns:
        ``{"results": [...]}`` where each entry carries ``result_number``,
        ``question``, ``answer``, ``source`` and ``similarity``.
    """
    r = search(req.query, req.num_results, req.source_filter)
    rows = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    results = []
    for rank, (document, metadata, distance) in enumerate(rows, start=1):
        metadata = metadata or {}
        results.append({
            "result_number": rank,
            "question": document,
            "answer": metadata.get('answer', 'N/A'),
            "source": metadata.get('source', 'unknown'),
            # NOTE(review): same metric assumption as in ui_search - confirm.
            "similarity": 1 - distance,
        })
    return {"results": results}


app = gr.mount_gradio_app(app, demo, path="/")

# Launch
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)