Spaces:
Sleeping
Sleeping
import os

# Opt out of ChromaDB's anonymized telemetry; set before chromadb is imported.
os.environ['ANONYMIZED_TELEMETRY'] = 'False'

import zipfile
from typing import Optional

import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr
from fastapi import FastAPI
from pydantic import BaseModel, Field
# Locate the Chroma database, unpacking the bundled archive on first run.
DB_PATH = "./medqa_db"
DB_ZIP_PATH = "./medqa_db.zip"
if not os.path.exists(DB_PATH) and os.path.exists(DB_ZIP_PATH):
    print("π¦ Extracting database...")
    with zipfile.ZipFile(DB_ZIP_PATH, 'r') as z:
        # Extracts into the working directory; this only produces DB_PATH if
        # the archive's top-level entry is a medqa_db/ folder (checked below).
        z.extractall(".")
    print("β Database extracted")
if not os.path.isdir(DB_PATH):
    # Fail fast with a clear message. Otherwise chromadb.PersistentClient would
    # (presumably) create an empty store at DB_PATH, get_collection("medqa")
    # would fail with a less obvious error, and the leftover directory could
    # make later runs skip the extraction step above.
    raise FileNotFoundError(
        f"ChromaDB directory {DB_PATH!r} not found and could not be extracted "
        f"from {DB_ZIP_PATH!r}"
    )

# Open the persisted "medqa" collection and the encoder used to embed queries.
print("π Loading ChromaDB...")
client = chromadb.PersistentClient(path=DB_PATH)
collection = client.get_collection("medqa")
print(f"β Loaded {collection.count()} questions")
print("π§ Loading MedCPT model...")
# NOTE(review): query-side encoder; assumes the stored vectors were built with
# the matching MedCPT encoder — confirm against the database build script.
model = SentenceTransformer('ncbi/MedCPT-Query-Encoder')
print("β Model ready")
def search(query, num_results=3, source_filter=None):
    """Run a nearest-neighbour query against the loaded Chroma collection.

    Args:
        query: Free-text query; embedded with the module-level ``model``.
        num_results: Number of hits to request; coerced with ``int()``.
        source_filter: Metadata ``source`` value to restrict hits to.
            ``None``, an empty value, or ``"all"`` means no restriction.

    Returns:
        The raw result object returned by ``collection.query``.
    """
    restrict = bool(source_filter) and source_filter != "all"
    where = {"source": source_filter} if restrict else None
    query_vector = model.encode(query).tolist()
    return collection.query(
        query_embeddings=[query_vector],
        n_results=int(num_results),
        where=where,
    )
def _source_label(source):
    """Return the ``(icon, display_name)`` pair used to label a hit's source tag."""
    if source == 'medgemini':
        return "π¬", "Med-Gemini"
    if source.startswith('medqa_'):
        # e.g. "medqa_train" -> "MedQA TRAIN"
        split = source.replace('medqa_', '').upper()
        return "π", f"MedQA {split}"
    return "π", source.upper()


def ui_search(query, num_results=3, source_filter="all"):
    """Run a search for the Gradio UI and format the hits as a text report.

    Args:
        query: Free-text query from the textbox. ``None``, empty, or
            whitespace-only input returns a prompt message without searching.
        num_results: Number of hits to request (may arrive as a float from the
            slider; ``search`` coerces it with ``int()``).
        source_filter: Radio value; ``"all"`` disables metadata filtering.

    Returns:
        A human-readable report string. Any exception raised while searching
        or formatting is caught and returned as an error string, so the UI
        always receives text.
    """
    # Guard against None as well as blank text; the original called .strip()
    # outside the try block, so a None query raised AttributeError.
    if not query or not query.strip():
        return "π‘ Enter a medical query to search"
    try:
        r = search(query, num_results, source_filter if source_filter != "all" else None)
        documents = r['documents'][0]
        if not documents:
            return "β No results found"
        # Collect pieces and join once instead of repeated string +=.
        parts = [f"π Found {len(documents)} results\n\n"]
        hits = zip(documents, r['metadatas'][0], r['distances'][0])
        for rank, (document, meta, distance) in enumerate(hits, start=1):
            meta = meta or {}  # a hit may come back without metadata (None)
            icon, name = _source_label(meta.get('source', 'unknown'))
            # NOTE(review): 1 - distance is a cosine similarity only if the
            # collection uses cosine space; under Chroma's default L2 space it
            # is just a rank-preserving score — confirm how the DB was built.
            similarity = 1 - distance
            parts.append(f"\n{'='*70}\n")
            parts.append(f"{icon} Result {rank} | {name} | Similarity: {similarity:.3f}\n")
            parts.append(f"{'='*70}\n\n")
            parts.append(document)
            parts.append(f"\n\nβ CORRECT ANSWER: {meta.get('answer', 'N/A')}\n")
            # Only some sources (e.g. Med-Gemini) carry an explanation.
            explanation = meta.get('explanation', '')
            if explanation and explanation.strip():
                parts.append(f"\nπ‘ EXPLANATION:\n{explanation}\n")
            parts.append("\n")
        return "".join(parts)
    except Exception as e:
        return f"β Error: {e}"
# Gradio interface: query box + result-count slider, a source filter, a
# search button, and a text output. Both the button click and pressing
# Enter in the query box call ui_search with the same inputs.
with gr.Blocks(theme=gr.themes.Soft(), title="MedQA Search") as demo:
    gr.Markdown("""
# π₯ MedQA Semantic Search
Search across **Med-Gemini** (expert explanations) and **MedQA** (USMLE questions) databases.
Uses medical-specific embeddings (MedCPT) for accurate retrieval.
""")
    with gr.Row():
        with gr.Column(scale=3):
            query_input = gr.Textbox(
                label="Medical Query",
                placeholder="e.g., hyponatremia, myocardial infarction, diabetes management...",
                lines=2,
            )
        with gr.Column(scale=1):
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=3,
                step=1,
                label="Number of Results",
            )
    with gr.Row():
        # Values must match the "source" metadata tags stored in the collection.
        source_filter = gr.Radio(
            choices=["all", "medgemini", "medqa_train", "medqa_dev", "medqa_test"],
            value="all",
            label="Filter by Source",
        )
    search_btn = gr.Button("π Search", variant="primary", size="lg")
    output = gr.Textbox(
        label="Search Results",
        lines=25,
        max_lines=50,
    )

    # One input list shared by every trigger and by the examples table, in
    # ui_search's positional-parameter order.
    search_inputs = [query_input, num_results, source_filter]
    search_btn.click(fn=ui_search, inputs=search_inputs, outputs=output)
    query_input.submit(fn=ui_search, inputs=search_inputs, outputs=output)

    # Show the real collection size (evaluated once, when the UI is built)
    # instead of the previous placeholder text.
    gr.Markdown(f"""
### π Database Info
**Med-Gemini**: Expert-relabeled questions with detailed explanations
**MedQA**: USMLE-style questions (Train/Dev/Test splits)
**Total Questions**: {collection.count():,}
""")
    # No fn/outputs are given, so clicking an example only fills in the inputs.
    gr.Examples(
        examples=[
            ["hyponatremia", 3, "all"],
            ["myocardial infarction treatment", 2, "medgemini"],
            ["diabetes complications", 3, "all"],
            ["antibiotics for pneumonia", 2, "medqa_train"],
        ],
        inputs=search_inputs,
    )
# FastAPI application that serves the JSON search endpoint; the Gradio UI is
# mounted onto it further down.
app = FastAPI()


class SearchRequest(BaseModel):
    """JSON body for the search endpoint; fields mirror ``search``'s parameters."""

    # Free-text medical query to embed and look up.
    query: str
    # Number of hits to return; must be >= 1 because Chroma rejects
    # n_results <= 0 (previously surfaced as a 500 instead of a 422).
    num_results: int = Field(3, ge=1)
    # Metadata "source" tag to filter on; None or "all" disables filtering.
    # Optional so an explicit JSON null validates (plain `str` rejected it).
    source_filter: Optional[str] = None
# NOTE(review): this handler previously had no route decorator, so it was never
# reachable over HTTP. The "/search" path is chosen here — confirm it matches
# what API clients expect.
@app.post("/search")
def api_search(req: SearchRequest):
    """Run a search from a JSON request and return structured hits.

    Args:
        req: Validated request body (query, result count, optional source filter).

    Returns:
        ``{"results": [...]}`` with one dict per hit containing
        ``result_number`` (1-based rank), ``question`` (the stored document),
        ``answer`` and ``source`` (from metadata, with "N/A"/"unknown"
        fallbacks) and ``similarity`` (computed as ``1 - distance``).
    """
    r = search(req.query, req.num_results, req.source_filter)
    hits = zip(r['documents'][0], r['metadatas'][0], r['distances'][0])
    results = []
    for rank, (document, meta, distance) in enumerate(hits, start=1):
        meta = meta or {}  # a hit may come back without metadata (None)
        results.append({
            "result_number": rank,
            "question": document,
            "answer": meta.get('answer', 'N/A'),
            "source": meta.get('source', 'unknown'),
            "similarity": 1 - distance,
        })
    return {"results": results}
# Mount the Gradio UI at the root path of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")

# Serve with uvicorn only when executed as a script, not when imported.
if __name__ == "__main__":
    from uvicorn import run as run_server

    run_server(app, port=7860, host="0.0.0.0")