| import gradio as gr | |
| from rex.utils.initialization import set_seed_and_log_path | |
| from rex.utils.logging import logger | |
| from src.task import MrcQaTask, SchemaGuidedInstructBertTask | |
# Runs once at import time, before any pipeline is built.
# NOTE(review): presumably seeds the RNGs and routes log output to "app.log";
# confirm against rex.utils.initialization.
set_seed_and_log_path(log_path="app.log")
class MrcQaPipeline:
    """Single-example inference wrapper around an ``MrcQaTask``."""

    def __init__(self, task_dir: str, load_path: str = None) -> None:
        """Restore the task from ``task_dir``.

        Args:
            task_dir: Directory handed to ``MrcQaTask.from_taskdir``.
            load_path: Optional checkpoint path. When it is ``None`` the task
                is asked to load its best model; when it is truthy the
                checkpoint is loaded explicitly (without history) instead.
        """
        want_best_model = load_path is None
        self.task = MrcQaTask.from_taskdir(
            task_dir, load_best_model=want_best_model, initialize=False
        )
        if load_path:
            self.task.load(load_path, load_history=False)

    def predict(self, query, context, background=None):
        """Run the task on one example and return its first prediction.

        Args:
            query: Value stored under the ``"query"`` key of the example.
            context: Value stored under the ``"context"`` key.
            background: Value stored under the ``"background"`` key
                (defaults to ``None``).

        Returns:
            The first element of ``self.task.predict``'s result.
        """
        batch = [dict(query=query, context=context, background=background)]
        prediction = self.task.predict(batch)[0]
        # Attach the prediction so the debug log carries input and output
        # together in one record.
        batch[0]["pred"] = prediction
        logger.opt(colors=False).debug(batch[0])
        return prediction
class InstructBertPipeline:
    """Single-example inference wrapper around ``SchemaGuidedInstructBertTask``."""

    def __init__(self, task_dir: str, load_path: str = None) -> None:
        """Restore the task from ``task_dir``.

        Args:
            task_dir: Directory handed to
                ``SchemaGuidedInstructBertTask.from_taskdir``.
            load_path: Optional checkpoint path. When it is ``None`` the task
                is asked to load its best model; when it is truthy the
                checkpoint is loaded explicitly (without history) instead.
        """
        self.task = SchemaGuidedInstructBertTask.from_taskdir(
            task_dir, load_best_model=load_path is None, initialize=False
        )
        if load_path:
            self.task.load(load_path, load_history=False)

    def predict(self, instruction, schema, text, background):
        """Run the task on one example and return its first prediction.

        Args:
            instruction: Instruction for the example.
            schema: Schema guiding the prediction.
            text: Input text.
            background: Background information for the example.

        Returns:
            The first element of ``self.task.predict``'s result.
        """
        # The record is built from this method's own parameters. It used to
        # reference ``query``/``context``, which do not exist in this scope
        # and raised NameError on every call.
        # NOTE(review): key names mirror the parameter names; confirm they
        # match the instance schema SchemaGuidedInstructBertTask expects.
        data = [
            {
                "instruction": instruction,
                "schema": schema,
                "text": text,
                "background": background,
            }
        ]
        results = self.task.predict(data)
        ret = results[0]
        # Attach the prediction so the debug log carries input and output
        # together in one record.
        data[0]["pred"] = ret
        logger.opt(colors=False).debug(data[0])
        return ret
def mrc_qa(task_dir: str = "outputs/RobertaBase_data20230314v2"):
    """Build and launch the Gradio demo for the MRC-QA pipeline.

    Args:
        task_dir: Task directory used to construct the ``MrcQaPipeline``.
            Defaults to the checkpoint directory that was previously
            hard-coded here.
    """
    # Previously instantiated ``Pipeline``, a name that is not defined in
    # this module; ``MrcQaPipeline.predict`` matches the inputs wired below.
    pipe = MrcQaPipeline(task_dir)

    with gr.Blocks() as demo:
        gr.Markdown("# 🪞 Mirror Mirror")

        with gr.Row():
            # Left column: the three text inputs.
            with gr.Column():
                with gr.Row():
                    query = gr.Textbox(
                        label="Query", placeholder="Mirror Mirror, tell me ..."
                    )
                with gr.Row():
                    context = gr.TextArea(
                        label="Candidates",
                        placeholder="Separated by comma (,) without spaces.",
                    )
                with gr.Row():
                    background = gr.TextArea(
                        label="Background",
                        placeholder="Background explanation, could be empty",
                    )

            # Right column: trigger button and the prediction output.
            with gr.Column():
                with gr.Row():
                    trigger_button = gr.Button("Tell me the truth", variant="primary")
                with gr.Row():
                    output = gr.TextArea(label="Output")

        # Event listeners have to be registered inside the Blocks context.
        trigger_button.click(
            pipe.predict, inputs=[query, context, background], outputs=output
        )

    demo.launch(show_error=True, share=False)
| def instruct_bert_pipeline(): | |
| task = SchemaGuidedInstructBertTask.from_taskdir() | |