Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import time | |
| from faster_whisper import WhisperModel | |
| from utils import ffmpeg_read, stt, greeting_list | |
| from sentence_transformers import SentenceTransformer, util | |
| import torch | |
# Whisper checkpoint names accepted by faster-whisper's WhisperModel.
# NOTE(review): not referenced anywhere in the visible code.
whisper_models = ["tiny", "base", "small", "medium", "large-v1", "large-v2"]

# Speech recognizer: faster-whisper "base" checkpoint, int8 weights, on CPU.
audio_model = WhisperModel("base", compute_type="int8", device="cpu")

# Sentence encoder used to embed each transcribed chunk in voice_detect.
text_model = SentenceTransformer('all-MiniLM-L6-v2')

# Precomputed embeddings that each chunk's embedding is searched against in
# voice_detect (presumably encodings of example greetings -- confirm).
# map_location="cpu" lets the file load even if it was saved from a GPU
# session: without it torch.load tries to restore CUDA tensors and fails on
# a CPU-only host. util.semantic_search moves the query embedding onto the
# corpus device, so matching still works if text_model runs elsewhere.
corpus_embeddings = torch.load('corpus_embeddings.pt', map_location="cpu")

# "whisper" selects audio_model in speech_to_text; any other value falls
# back to utils.stt.
model_type = "whisper"

# Browser tab title for the Gradio Blocks app.
title = "Greeting detection demo app"
def speech_to_text(upload_audio, language="ja"):
    """
    Transcribe an audio input to text.

    Uses the faster-whisper model ``audio_model`` when the module-level
    ``model_type`` is "whisper"; otherwise delegates to ``utils.stt``.

    Args:
        upload_audio: Audio as delivered by the Gradio component (a file
            path with ``type="filepath"``), or None when there is no audio.
        language: Language code forwarded to Whisper (default "ja").
            Not used by the ``stt`` fallback.

    Returns:
        The transcription as one string; "" when ``upload_audio`` is None.
        Whisper segments are joined with single spaces.
    """
    # Neither backend can transcribe a missing input, so report no text
    # instead of raising inside the model.
    if upload_audio is None:
        return ""
    if model_type != "whisper":
        return stt(upload_audio)
    transcribe_options = dict(
        task="transcribe",
        language=language,
        beam_size=5,
        best_of=5,
        # faster-whisper's VAD filter removes non-speech parts of the audio
        # before decoding.
        vad_filter=True,
    )
    # transcribe() returns a lazy generator of segments plus transcription
    # info (unused here); iterating the generator performs the decoding.
    segments_raw, _info = audio_model.transcribe(upload_audio, **transcribe_options)
    return " ".join(segment.text for segment in segments_raw)
def _split_state(state):
    """Split a state string of the form "<count><transcript>".

    The count is the run of leading decimal digits (0 if there are none).
    voice_detect prefixes every appended chunk with a space, so the
    transcript part is either empty or starts with " " and cannot be
    confused with the count, however many digits the count has.

    Returns:
        (count, transcript) as (int, str).
    """
    transcript = state.lstrip("0123456789")
    digits = state[: len(state) - len(transcript)]
    return (int(digits) if digits else 0), transcript


def _is_greeting(text, threshold=0.8):
    """Return True if *text* looks like a greeting.

    A literal substring match against ``greeting_list`` wins immediately;
    otherwise *text* is embedded with ``text_model`` and searched against
    ``corpus_embeddings``, and the best hit must score above *threshold*.
    """
    if any(greeting in text for greeting in greeting_list):
        return True
    # Only pay for the embedding when the cheap keyword check fails.
    query_embedding = text_model.encode(text, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=1)[0]
    # Guard against an empty hit list (e.g. an empty corpus) rather than
    # raising IndexError.
    return bool(hits) and hits[0]['score'] > threshold


def voice_detect(audio, recongnize_text=""):
    """
    Transcribe one streamed audio chunk and update the greeting counter.

    The Gradio state is a single string, "<count><transcript>": the number
    of greetings detected so far followed by the accumulated transcript, in
    which every chunk's text is preceded by a space.

    Args:
        audio: Audio chunk from the streaming microphone component; passed
            straight to ``speech_to_text``.
        recongnize_text: State string from the previous call ("" -- or
            None -- on the first call).

    Returns:
        ``(transcript, new_state, greeting_count)`` for the ``text_output``,
        ``state`` and ``greeting_count`` components. The count includes a
        greeting detected in this chunk.
    """
    # Parse the whole leading digit run: reading only the first character
    # breaks as soon as the count reaches 10.
    count, transcript = _split_state(recongnize_text or "")
    text = speech_to_text(audio)
    # Chunks containing "ご視聴ありがとうございました" ("thank you for
    # watching") are discarded. NOTE(review): presumably because Whisper
    # hallucinates this phrase on silence/noise -- confirm.
    if "ご視聴ありがとうございました" in text:
        text = ""
    detected = 0
    # Empty chunks add nothing to the transcript and are not run through
    # greeting detection.
    if text.strip():
        transcript = transcript + " " + text
        detected = int(_is_greeting(text))
    new_count = count + detected
    # Report the count including this chunk so the display matches the
    # count stored in the state.
    return transcript, str(new_count) + transcript, new_count
def clear():
    """Return a triple of ``None`` values.

    Not wired to any event in the visible code; presumably meant to reset
    three Gradio outputs.
    """
    nothing = None
    return (nothing,) * 3
def _microphone_kwargs():
    """Return the ``gr.Audio`` keyword argument(s) selecting microphone input.

    Gradio 4.0 replaced ``source="microphone"`` with
    ``sources=["microphone"]`` and no longer accepts the old keyword, so
    choose based on the installed major version.
    """
    major = int(gr.__version__.split(".")[0])
    if major >= 4:
        return {"sources": ["microphone"]}
    return {"source": "microphone"}


demo = gr.Blocks(title=title)
with demo:
    # Page header: "挨拶カウンター" = "Greeting counter".
    gr.Markdown('''
<div>
<h1 style='text-align: center'>挨拶カウンター</h1>
</div>
''')
    with gr.Row():
        with gr.Column():
            # Streaming microphone input; the handler receives file paths.
            audio_source = gr.Audio(type="filepath", streaming=True, **_microphone_kwargs())
            # Holds the "<count><transcript>" string between voice_detect calls.
            state = gr.State(value="")
        with gr.Column():
            # "挨拶回数" = "Number of greetings".
            greeting_count = gr.Number(label="挨拶回数")
    with gr.Row():
        # "認識されたテキスト" = "Recognized text".
        text_output = gr.Textbox(label="認識されたテキスト")
    # Each streamed chunk runs voice_detect, which updates the transcript,
    # the state string and the greeting count.
    audio_source.stream(
        voice_detect,
        inputs=[audio_source, state],
        outputs=[text_output, state, greeting_count],
    )
demo.launch(debug=True)