Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Base CLI to parse Arguments | |
| # | |
| # @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
| # Institute for Artificial Intelligence in Medicine, | |
| # University Medicine Essen | |
| import argparse | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import Tuple, Union | |
| import yaml | |
| from pydantic import BaseModel | |
class ABCParser(ABC):
    """Blueprint for Argument Parser

    Outlines the methods an argument parser is expected to offer. Every
    method in this class is an empty placeholder that returns None.

    NOTE(review): the methods are not decorated with ``abstractmethod``, so
    this class can be instantiated directly -- confirm whether that is intended.
    """

    def __init__(self) -> None:
        """Create the blueprint; there is no state to set up."""

    def get_config(self) -> Tuple[Union[BaseModel, dict], logging.Logger]:
        """Load configuration and create a logger

        Placeholder without logic, returns None in this class.

        Returns:
            Tuple[Union[BaseModel, dict], logging.Logger]: Configuration and Logger
        """
        return None

    def store_config(self) -> None:
        """Store the config file in the logging directory to keep track of the configuration.

        Placeholder without logic in this class.
        """
        return None
class ExperimentBaseParser:
    """Configuration Parser for Machine Learning Experiments

    Builds an ``argparse`` command line interface and merges the parsed CLI
    options into the YAML configuration file passed via ``--config``.

    Attributes:
        parser (argparse.ArgumentParser): CLI parser, created in ``__init__``.
        config (dict): Merged configuration. Only available after
            ``parse_arguments`` has been called.
    """

    def __init__(self) -> None:
        """Set up the CLI: ``--config`` (required), ``--gpu`` and the mutually
        exclusive run modes ``--sweep``, ``--agent`` and ``--checkpoint``."""
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            description="Start an experiment with given configuration file.",
        )
        required_named = parser.add_argument_group("required named arguments")
        required_named.add_argument(
            "--config", type=str, help="Path to a config file", required=True
        )
        parser.add_argument("--gpu", type=int, help="Cuda-GPU ID")

        # at most one of sweep / agent / checkpoint may be passed
        group = parser.add_mutually_exclusive_group(required=False)
        group.add_argument(
            "--sweep",
            action="store_true",
            help="Starting a sweep. For this the configuration file must be structured according to WandB sweeping. "
            "Compare https://docs.wandb.ai/guides/sweeps and https://community.wandb.ai/t/nested-sweep-configuration/3369/3 "
            "for further information. This parameter cannot be set in the config file!",
        )
        group.add_argument(
            "--agent",
            type=str,
            help="Add a new agent to the sweep. "
            "Please pass the sweep ID as argument in the way entity/project/sweep_id, e.g., user1/test_project/v4hwbijh. "
            "The agent configuration can be found in the WandB dashboard for the running sweep in the sweep overview tab "
            "under launch agent. Just paste the entity/project/sweep_id given there. The provided config file must be a sweep config file. "
            "This parameter cannot be set in the config file!",
        )
        group.add_argument(
            "--checkpoint",
            type=str,
            help="Path to a PyTorch checkpoint file. "
            "The file is loaded and continued to train with the provided settings. "
            "If this is passed, no sweeps are possible. "
            "This parameter cannot be set in the config file!",
        )
        self.parser = parser

    def parse_arguments(self, args: Union[list, None] = None) -> dict:
        """Parse the arguments from CLI and load yaml config

        The keys ``run_sweep``, ``agent`` and ``checkpoint`` are controlled by
        the command line only: values for them in the config file are discarded.

        Args:
            args (Union[list, None], optional): Argument strings to parse. If None,
                argparse reads them from ``sys.argv``. Defaults to None.

        Returns:
            dict: YAML configuration merged with the CLI options. ``run_sweep``
                (bool) and ``agent`` (sweep ID or None) are always present,
                ``checkpoint`` only if passed, ``gpu`` is overwritten if passed.
                The same dict is stored in ``self.config``.
        """
        # parse the arguments
        opt = self.parser.parse_args(args)

        # YAML is UTF-8 by specification, so do not rely on the platform default encoding
        with open(opt.config, "r", encoding="utf-8") as config_file:
            yaml_config = yaml.safe_load(config_file)
        # safe_load returns None for an empty file; treat it as an empty configuration
        yaml_config_dict = dict(yaml_config or {})

        # a GPU ID passed via CLI overwrites the one from the config file
        if opt.gpu is not None:
            yaml_config_dict["gpu"] = opt.gpu

        # check if either training, sweep, checkpoint or start agent should be called
        # first step: remove such keys from the config file
        for cli_only_key in ("run_sweep", "agent", "checkpoint"):
            yaml_config_dict.pop(cli_only_key, None)

        # select one of the options
        yaml_config_dict["run_sweep"] = opt.sweep is True
        # "agent" is always written (None if --agent was not passed),
        # whereas "checkpoint" only appears if a checkpoint was passed
        yaml_config_dict["agent"] = opt.agent
        if opt.checkpoint is not None:
            yaml_config_dict["checkpoint"] = opt.checkpoint

        self.config = yaml_config_dict
        return self.config