Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Base Machine Learning Experiment | |
| # | |
| # @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
# Institute for Artificial Intelligence in Medicine,
| # University Medicine Essen | |
| import logging | |
# Module-level logger.
# NOTE(review): fetched by the fixed name "__main__" rather than __name__,
# presumably so records reach the handlers configured by the entry-point
# script — confirm with the training scripts.
# The NullHandler means that, when nothing has configured logging, records
# from this logger are discarded instead of reaching logging's last-resort
# stderr handler.
logger = logging.getLogger(name="__main__")
logger.addHandler(hdlr=logging.NullHandler())
| import wandb | |
class EarlyStopping:
    """Early Stopping Class

    Keeps track of the best value of a monitored metric and sets
    ``early_stop`` once the metric has not improved for ``patience``
    consecutive calls. A value equal to the current best counts as an
    improvement (it resets the counter).

    Args:
        patience (int): Patience to wait before stopping, i.e. the number of
            consecutive non-improving calls after which ``early_stop`` is set.
        strategy (str, optional): Optimization strategy.
            Please select 'minimize' or 'maximize' for strategy (case-insensitive).
            Defaults to "minimize".

    Attributes:
        counter (int): Consecutive calls without improvement.
        best_metric (float | None): Best metric so far, None before the first call.
        best_epoch (int | None): Epoch at which ``best_metric`` was reached.
        early_stop (bool): Set to True once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience: int, strategy: str = "minimize"):
        strategy = strategy.lower()
        assert strategy in [
            "minimize",
            "maximize",
        ], "Please select 'minimize' or 'maximize' for strategy"
        self.patience = patience
        self.counter = 0
        self.strategy = strategy
        self.best_metric = None
        self.best_epoch = None
        self.early_stop = False

        logger.info(
            f"Using early stopping with a range of {self.patience} and {self.strategy} strategy"
        )

    def __call__(self, metric: float, epoch: int) -> bool:
        """Early stopping update call

        Args:
            metric (float): Metric for early stopping
            epoch (int): Current epoch

        Returns:
            bool: True if the model performs at least as well as the current
                best model (always True on the first call), otherwise False.
                ``early_stop`` is set to True once ``patience`` consecutive
                calls returned False.
        """
        # The first call always establishes the baseline best value.
        if self.best_metric is None or self._is_improvement(metric):
            self.best_metric = metric
            self.best_epoch = epoch
            self.counter = 0
            # Logged on the first call as well, so the run summary names a
            # best epoch even if the metric never improves afterwards.
            self._log_best(metric, epoch)
            return True

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False

    def _is_improvement(self, metric: float) -> bool:
        """Return True if ``metric`` is at least as good as ``best_metric``.

        Ties count as improvements. Any comparison involving NaN is False, so
        a NaN metric never counts as an improvement.
        NOTE(review): if the very first metric is NaN it becomes the best
        value and no later metric can beat it — confirm this is acceptable.
        """
        if self.strategy == "minimize":
            return metric <= self.best_metric
        return metric >= self.best_metric

    @staticmethod
    def _log_best(metric: float, epoch: int) -> None:
        """Write the best epoch and metric to the active wandb run summary, if any."""
        # wandb.run is None while no run is active (wandb.init() not called);
        # indexing its summary would then raise AttributeError.
        if wandb.run is not None:
            wandb.run.summary["Best-Epoch"] = epoch
            wandb.run.summary["Best-Metric"] = metric