Spaces:
Runtime error
Runtime error
import functools
import json
import os
import tempfile

import gradio as gr
import spacy
import torch
from faster_whisper import WhisperModel
from pydub import AudioSegment
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

spacy.cli.download("en_core_web_sm")
# Audio conversion from MP4 to MP3
def convert_mp4_to_mp3(mp4_path, mp3_path):
    """Extract the audio track of an MP4 file and write it to ``mp3_path`` as MP3.

    Args:
        mp4_path: Path of the source MP4 file.
        mp3_path: Destination path for the MP3 file (overwritten if present).

    Raises:
        RuntimeError: If decoding or encoding fails. The underlying exception is
            chained as ``__cause__`` so the real failure stays visible in tracebacks.
    """
    try:
        audio = AudioSegment.from_file(mp4_path, format="mp4")
        # pydub's export() returns the file object it wrote to; close it so the
        # handle is not left open until garbage collection.
        audio.export(mp3_path, format="mp3").close()
    except Exception as e:
        raise RuntimeError(f"Error converting MP4 to MP3: {e}") from e
# Choose where the models run and the Faster-Whisper compute_type to go with it:
# "float16" when a CUDA GPU is available, otherwise "int8" on the CPU.
device, compute_type = ("cuda", "float16") if torch.cuda.is_available() else ("cpu", "int8")
# Load Faster Whisper Model for transcription
@functools.lru_cache(maxsize=1)
def load_faster_whisper():
    """Return the Faster-Whisper transcription model, loading it on first use.

    The model is cached, so later calls (one per uploaded file) reuse the same
    instance instead of reloading the large checkpoint every time.

    Returns:
        WhisperModel: Configured with the module-level ``device`` and ``compute_type``.
    """
    return WhisperModel(
        "deepdml/faster-whisper-large-v3-turbo-ct2",
        device=device,
        compute_type=compute_type,
    )
# Load NLP model and other helpers
nlp = spacy.load("en_core_web_sm")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

MEGABEAM_CHECKPOINT = "aws-prototyping/MegaBeam-Mistral-7B-512k"
tokenizer = AutoTokenizer.from_pretrained(MEGABEAM_CHECKPOINT)
# Use half precision on the GPU (fp32 weights for a ~7B model need roughly 28 GB)
# and actually move the model to the selected device; previously it always
# stayed on the CPU even when CUDA was available.
model = AutoModelForCausalLM.from_pretrained(
    MEGABEAM_CHECKPOINT,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
# NOTE(review): this checkpoint is loaded as a causal LM, while the "summarization"
# pipeline task targets seq2seq models; `summarizer` is not used elsewhere in this
# file and is kept only for backward compatibility.
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
# Reference description of each SOAP section. classify_sentence() assigns a
# sentence to the section whose description embedding is most similar to it.
soap_prompts = {
    "subjective": "Personal reports, symptoms described by patients, or personal health concerns. Details reflecting individual symptoms or health descriptions.",
    "objective": "Observable facts, clinical findings, professional observations, specific medical specialties, and diagnoses.",
    "assessment": "Clinical assessments, expertise-based opinions on conditions, and significance of medical interventions. Focused on medical evaluations or patient condition summaries.",
    "plan": "Future steps, recommendations for treatment, follow-up instructions, and healthcare management plans.",
}
# Pre-computed (tensor) embedding of every section description, keyed like soap_prompts.
soap_embeddings = {
    name: embedder.encode(soap_prompts[name], convert_to_tensor=True)
    for name in soap_prompts
}
# Query function for MegaBeam-Mistral-7B
def megabeam_query(user_prompt, soap_note, max_new_tokens=512):
    """Ask the MegaBeam model to apply ``user_prompt`` to ``soap_note``.

    Args:
        user_prompt: Free-form instructions from the user (e.g. a target template).
        soap_note: The formatted SOAP note text used as context.
        max_new_tokens: Upper bound on generated tokens, independent of prompt length.

    Returns:
        str: Only the newly generated text (the prompt is not echoed back), or a
        string starting with "Error generating response:" if generation fails.
    """
    combined_prompt = f"User Instructions:\n{user_prompt}\n\nContext:\n{soap_note}"
    try:
        inputs = tokenizer(combined_prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        # max_new_tokens (not max_length) so a long SOAP note cannot consume the
        # whole generation budget before any output is produced.
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # A causal LM returns prompt + continuation; decode only the continuation so
        # the prompt's "User Instructions:"/"Context:" lines don't leak into the JSON.
        return tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error generating response: {e}"
# Convert the response to JSON format
def convert_to_json(template):
    """Convert "Header:" / content-line text into a JSON object string.

    A line whose stripped form ends with ":" starts a section; subsequent non-blank
    lines are appended (stripped) to that section's list. Lines before the first
    header are ignored. A header that appears again keeps appending to the same
    list instead of discarding earlier content.

    Args:
        template: Text produced by the model.

    Returns:
        str: Pretty-printed JSON, or a string starting with
        "Error converting to JSON:" if ``template`` cannot be processed.
    """
    try:
        json_data = {}
        section = None
        # splitlines() also handles "\r\n" endings, which split("\n") left attached.
        for raw_line in template.splitlines():
            line = raw_line.strip()
            if not line:
                continue  # blank lines are layout, not content
            if line.endswith(":"):
                section = line[:-1].strip()
                json_data.setdefault(section, [])
            elif section is not None:
                json_data[section].append(line)
        return json.dumps(json_data, indent=2)
    except Exception as e:
        return f"Error converting to JSON: {e}"
# Transcription using Faster Whisper
def transcribe_audio(mp4_path):
    """Transcribe the audio track of an MP4 file with Faster-Whisper.

    The audio is converted to a per-call temporary MP3 file, which is always removed
    afterwards, so concurrent requests never overwrite each other's audio.

    Args:
        mp4_path: Path of the uploaded MP4 file.

    Returns:
        str: The transcript (segment texts joined by single spaces), or a string
        starting with "Error during transcription:" on any failure.
    """
    mp3_path = None
    try:
        print(f"Processing MP4 file: {mp4_path}")
        model = load_faster_whisper()
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            mp3_path = tmp.name
        convert_mp4_to_mp3(mp4_path, mp3_path)
        # faster-whisper returns (segments, info); segments is consumed lazily, so
        # decoding happens while the join below iterates it.
        segments, _info = model.transcribe(mp3_path, beam_size=5)
        texts = (seg.text.strip() for seg in segments)
        return " ".join(text for text in texts if text)
    except Exception as e:
        return f"Error during transcription: {e}"
    finally:
        if mp3_path and os.path.exists(mp3_path):
            os.remove(mp3_path)
# Classify the sentence to the correct SOAP section
def classify_sentence(sentence):
    """Return the SOAP section key whose description is most similar to ``sentence``.

    The sentence is embedded once, as a tensor, which puts it on the same device as
    the tensors in ``soap_embeddings`` (both come from ``embedder``). Ties go to the
    first section in ``soap_prompts`` order.
    """
    sentence_embedding = embedder.encode(sentence, convert_to_tensor=True)
    scores = {
        section: util.pytorch_cos_sim(sentence_embedding, soap_embeddings[section]).item()
        for section in soap_prompts
    }
    return max(scores, key=scores.get)
# Summarize the section if it's too long
def summarize_section(section_text, min_words=50, ratio=0.65):
    """Condense ``section_text`` with the MegaBeam model when it is long.

    Args:
        section_text: Text of one SOAP section.
        min_words: Sections with fewer words than this are returned unchanged.
        ratio: Target summary length as a fraction of the section's word count,
            used as the new-token budget. At least 60% of it is generated.

    Returns:
        str: The original text for short sections, otherwise the generated summary.
    """
    word_count = len(section_text.split())
    if word_count < min_words:
        return section_text
    target_length = max(1, int(word_count * ratio))

    # Truncate the section itself rather than the full prompt, so the trailing
    # "Summary:" cue is never cut off.
    section_ids = tokenizer.encode(
        section_text, add_special_tokens=False, truncation=True, max_length=1024
    )
    clipped_text = tokenizer.decode(section_ids, skip_special_tokens=True)
    prompt = (
        "Summarize the following clinical notes concisely.\n\n"
        f"{clipped_text}\n\nSummary:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    # max_new_tokens/min_new_tokens count only generated tokens. The old
    # max_length/min_length included the input and were smaller than it, so
    # nothing could be generated.
    summary_ids = model.generate(
        **inputs,
        max_new_tokens=target_length,
        min_new_tokens=int(target_length * 0.60),
        length_penalty=1.0,
        num_beams=4,
    )
    # Decoder-only model: drop the echoed prompt and keep only the summary.
    return tokenizer.decode(summary_ids[0, prompt_len:], skip_special_tokens=True).strip()
# Analyze the SOAP content and divide into sections
def soap_analysis(text):
    """Split ``text`` into sentences, bucket them by SOAP section and format the note.

    Each section's sentences are joined with single spaces, passed through
    summarize_section(), and the four results are rendered by format_soap_output().
    """
    buckets = {name: [] for name in soap_prompts}
    for sent in nlp(text).sents:
        buckets[classify_sentence(sent.text)].append(sent.text)

    summarized = {
        name: summarize_section(" ".join(parts).strip())
        for name, parts in buckets.items()
    }
    return format_soap_output(summarized)
# Format the SOAP note output
def format_soap_output(soap_note):
    """Render a SOAP note dict as plain text.

    Every section appears as "<Title>:" on its own line followed by its text.
    Sections are separated by a blank line, and the result ends with a newline.
    Raises KeyError if any of the four section keys is missing.
    """
    headings = (
        ("Subjective", "subjective"),
        ("Objective", "objective"),
        ("Assessment", "assessment"),
        ("Plan", "plan"),
    )
    rendered = [f"{title}:\n{soap_note[key]}" for title, key in headings]
    return "\n\n".join(rendered) + "\n"
# Process file function for audio to SOAP
def process_file(mp4_file, user_prompt):
    """Run the audio-to-SOAP pipeline for an uploaded MP4 file.

    Args:
        mp4_file: The Gradio upload. A filepath (``str``/``os.PathLike``, as
            Gradio 4 passes it) or an object exposing the path as ``.name``.
            ``None`` when nothing was uploaded.
        user_prompt: Instructions for the template-generation step.

    Returns:
        tuple[str, str, str]: (SOAP note, generated template, JSON output). On a
        missing upload or a failed transcription, the first element holds the
        message and the other two are empty.
    """
    if mp4_file is None:
        return "No MP4 file was uploaded.", "", ""
    if isinstance(mp4_file, (str, os.PathLike)):
        mp4_path = os.fspath(mp4_file)
    else:
        mp4_path = mp4_file.name

    transcription = transcribe_audio(mp4_path)
    print("Transcribed Text: ", transcription)
    # transcribe_audio reports failures as a string with this prefix; don't feed
    # that message through SOAP analysis and the language model.
    if transcription.startswith("Error during transcription:"):
        return transcription, "", ""

    soap_note = soap_analysis(transcription)
    print("SOAP Notes: ", soap_note)
    template_output = megabeam_query(user_prompt, soap_note)
    print("Template: ", template_output)
    json_output = convert_to_json(template_output)
    return soap_note, template_output, json_output
# Process text function for text input to SOAP
def process_text(text, user_prompt):
    """Run the text-to-SOAP pipeline.

    Returns:
        tuple[str, str, str]: (SOAP note, generated template, JSON output).
    """
    note = soap_analysis(text)
    print(note)
    template = megabeam_query(user_prompt, note)
    print(template)
    return note, template, convert_to_json(template)
# Launch the Gradio interface
def launch_gradio():
    """Build the two-tab "SOAP Note Generator" UI and launch it (share=True, debug=True)."""

    def _template_prompt_box():
        # Prompt textbox shared by both tabs (each call creates a fresh component).
        return gr.Textbox(
            label="Enter Prompt for Template",
            placeholder="Enter a detailed prompt...",
            lines=6,
        )

    def _result_boxes():
        # The three outputs returned by process_file / process_text, in order.
        return [
            gr.Textbox(label="SOAP Note"),
            gr.Textbox(label="Generated Template from MegaBeam-Mistral-7B"),
            gr.Textbox(label="JSON Output"),
        ]

    with gr.Blocks(theme=gr.themes.Default()) as demo:
        gr.Markdown("# SOAP Note Generator")
        with gr.Tab("Audio to SOAP"):
            gr.Interface(
                fn=process_file,
                inputs=[gr.File(label="Upload MP4 File"), _template_prompt_box()],
                outputs=_result_boxes(),
            )
        with gr.Tab("Text to SOAP"):
            gr.Interface(
                fn=process_text,
                inputs=[
                    gr.Textbox(label="Enter Text", placeholder="Enter medical notes...", lines=6),
                    _template_prompt_box(),
                ],
                outputs=_result_boxes(),
            )
    demo.launch(share=True, debug=True)


# Run the Gradio app only when executed as a script, not when imported.
if __name__ == "__main__":
    launch_gradio()