"""Demo script: load a trained ``MrcTaggingTask`` and run prediction on sample inputs.

The task is restored from a task directory (best checkpoint, training skipped),
``task.predict`` is called on a list of strings, and both the inputs and the
predictions are logged.
"""
import os

from rex.utils.logging import logger
from src.task import MrcTaggingTask


def main(task_dir="outputs/bert_mrc_ner", cases=None):
    """Load the task saved in *task_dir* and predict on *cases*.

    Args:
        task_dir: Directory passed to ``MrcTaggingTask.from_taskdir``.
        cases: Input strings handed to ``task.predict``. Defaults to the demo
            inputs ``["123123", "123123"]``.

    Returns:
        Whatever ``task.predict(cases)`` returns.
    """
    if cases is None:
        cases = ["123123", "123123"]

    # An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA.
    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet (e.g. as a side effect of the imports above) — confirm src.task
    # does not touch CUDA at import time.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    task = MrcTaggingTask.from_taskdir(
        task_dir,
        load_best_model=True,
        update_config={
            # presumably overrides the saved config so no training is run
            "skip_train": True,
            "debug_mode": False,
        },
    )
    logger.info(f"Cases: {cases}")
    ents = task.predict(cases)
    logger.info(f"Results: {ents}")
    return ents


if __name__ == "__main__":
    main()