# -*- coding: utf-8 -*-
# Base Trainer Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import logging
from abc import abstractmethod
from pathlib import Path
from typing import Tuple, Union

import torch
import torch.nn as nn
import wandb
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from base_ml.base_early_stopping import EarlyStopping
from utils.tools import flatten_dict


class BaseTrainer:
    """
    Base class for all trainers with important ML components

    Args:
        model (nn.Module): Model that should be trained
        loss_fn (_Loss): Loss function
        optimizer (Optimizer): Optimizer
        scheduler (_LRScheduler): Learning rate scheduler
        device (str): CUDA device to use, e.g., cuda:0.
        logger (logging.Logger): Logger module
        logdir (Union[Path, str]): Logging directory
        experiment_config (dict): Configuration of this experiment
        early_stopping (EarlyStopping, optional): Early Stopping Class. Defaults to None.
        accum_iter (int, optional): Accumulation steps for gradient accumulation.
            Provide a number greater than 1 to activate gradient accumulation. Defaults to 1.
        mixed_precision (bool, optional): If mixed precision should be used. Defaults to False.
        log_images (bool, optional): If images should be logged to WandB. Defaults to False.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: _Loss,
        optimizer: Optimizer,
        scheduler: _LRScheduler,
        device: str,
        logger: logging.Logger,
        logdir: Union[Path, str],
        experiment_config: dict,
        early_stopping: EarlyStopping = None,
        accum_iter: int = 1,
        mixed_precision: bool = False,
        log_images: bool = False,
        # model_ema: bool = True,
    ) -> None:
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.logger = logger
        self.logdir = Path(logdir)
        self.early_stopping = early_stopping
        self.accum_iter = accum_iter
        self.start_epoch = 0
        self.experiment_config = experiment_config
        self.log_images = log_images
        self.mixed_precision = mixed_precision

        if self.mixed_precision:
            self.scaler = torch.cuda.amp.GradScaler(enabled=True)
        else:
            self.scaler = None

    def train_epoch(
        self, epoch: int, train_loader: DataLoader, **kwargs
    ) -> Tuple[dict, dict]:
        """Training logic for a training epoch

        Args:
            epoch (int): Current epoch number
            train_loader (DataLoader): Train dataloader

        Raises:
            NotImplementedError: Needs to be implemented

        Returns:
            Tuple[dict, dict]: wandb logging dictionaries
                * Scalar metrics
                * Image metrics
        """
        raise NotImplementedError

    def validation_epoch(
        self, epoch: int, val_dataloader: DataLoader
    ) -> Tuple[dict, dict, float]:
        """Validation logic for a validation epoch

        Args:
            epoch (int): Current epoch number
            val_dataloader (DataLoader): Validation dataloader

        Raises:
            NotImplementedError: Needs to be implemented

        Returns:
            Tuple[dict, dict, float]: wandb logging dictionaries and early_stopping_metric
                * Scalar metrics
                * Image metrics
                * Early stopping metric as float
        """
        raise NotImplementedError
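
    # Note (sketch, derived from fit() below): when a ReduceLROnPlateau scheduler
    # is used, the scalar metric dictionary returned by validation_epoch is
    # expected to contain the key "Loss/Validation", e.g.
    # {"Loss/Validation": 0.41, ...}; all other keys are implementation-specific.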

    def train_step(self, batch: object, batch_idx: int, num_batches: int):
        """Training logic for one training batch

        Args:
            batch (object): A training batch
            batch_idx (int): Current batch index
            num_batches (int): Maximum number of batches

        Raises:
            NotImplementedError: Needs to be implemented
        """
        raise NotImplementedError
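
    # A minimal sketch (not part of this class) of how a subclass might implement
    # train_step with the mixed-precision scaler and gradient accumulation set up
    # in __init__; the batch layout (imgs, targets) and the loss call are assumptions:
    #
    #   def train_step(self, batch, batch_idx, num_batches):
    #       imgs, targets = (t.to(self.device) for t in batch)
    #       with torch.cuda.amp.autocast(enabled=self.mixed_precision):
    #           predictions = self.model(imgs)
    #           loss = self.loss_fn(predictions, targets) / self.accum_iter
    #       if self.scaler is not None:
    #           self.scaler.scale(loss).backward()
    #       else:
    #           loss.backward()
    #       if (batch_idx + 1) % self.accum_iter == 0 or (batch_idx + 1) == num_batches:
    #           if self.scaler is not None:
    #               self.scaler.step(self.optimizer)
    #               self.scaler.update()
    #           else:
    #               self.optimizer.step()
    #           self.optimizer.zero_grad()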

    def validation_step(self, batch, batch_idx: int):
        """Validation logic for one validation batch

        Args:
            batch (object): A validation batch
            batch_idx (int): Current batch index

        Raises:
            NotImplementedError: Needs to be implemented
        """
        raise NotImplementedError

    def fit(
        self,
        epochs: int,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        metric_init: dict = None,
        eval_every: int = 1,
        **kwargs,
    ):
        """Fitting function to start training and validation of the trainer

        Args:
            epochs (int): Number of epochs the network should be trained for
            train_dataloader (DataLoader): Dataloader with training data
            val_dataloader (DataLoader): Dataloader with validation data
            metric_init (dict, optional): Initialization dictionary with scalar metrics that should be logged at startup.
                This is only relevant for wandb logging, so that the plots are scaled properly.
                The values in the metric dictionary are logged for epoch 0 (before training has started).
                If not provided, step 0 (epoch 0) is not logged. The dictionary should have the same scalar keys
                as the training and validation epochs report. For more information, have a look at the
                train_epoch and validation_epoch methods, where the wandb logging dicts are assembled.
                Defaults to None.
            eval_every (int, optional): How often the network should be evaluated (after how many epochs). Defaults to 1.
            **kwargs
        """
| self.logger.info(f"Starting training, total number of epochs: {epochs}") | |
| if metric_init is not None and self.start_epoch == 0: | |
| wandb.log(metric_init, step=0) | |
| for epoch in range(self.start_epoch, epochs): | |
| # training epoch | |
| #train_sampler.set_epoch(epoch) # for distributed training | |
| self.logger.info(f"Epoch: {epoch+1}/{epochs}") | |
| train_scalar_metrics, train_image_metrics = self.train_epoch( | |
| epoch, train_dataloader, **kwargs | |
| ) | |
| wandb.log(train_scalar_metrics, step=epoch + 1) | |
| if self.log_images: | |
| wandb.log(train_image_metrics, step=epoch + 1) | |
| if epoch >=95 and ((epoch + 1)) % eval_every == 0: | |
| # validation epoch | |
| ( | |
| val_scalar_metrics, | |
| val_image_metrics, | |
| early_stopping_metric, | |
| ) = self.validation_epoch(epoch, val_dataloader) | |
| wandb.log(val_scalar_metrics, step=epoch + 1) | |
| if self.log_images: | |
| wandb.log(val_image_metrics, step=epoch + 1) | |
| #self.save_checkpoint(epoch, f"checkpoint_{epoch}.pth") | |
| # log learning rate | |
| curr_lr = self.optimizer.param_groups[0]["lr"] | |
| wandb.log( | |
| { | |
| "Learning-Rate/Learning-Rate": curr_lr, | |
| }, | |
| step=epoch + 1, | |
| ) | |
| if epoch >=95 and ((epoch + 1)) % eval_every == 0: | |
| # early stopping | |
| if self.early_stopping is not None: | |
| best_model = self.early_stopping(early_stopping_metric, epoch) | |
| if best_model: | |
| self.logger.info("New best model - save checkpoint") | |
| self.save_checkpoint(epoch, "model_best.pth") | |
| elif self.early_stopping.early_stop: | |
| self.logger.info("Performing early stopping!") | |
| break | |
| self.save_checkpoint(epoch, "latest_checkpoint.pth") | |
| # scheduling | |
| if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau: | |
| self.scheduler.step(float(val_scalar_metrics["Loss/Validation"])) | |
| else: | |
| self.scheduler.step() | |
| new_lr = self.optimizer.param_groups[0]["lr"] | |
| self.logger.debug(f"Old lr: {curr_lr:.6f} - New lr: {new_lr:.6f}") | |

    def save_checkpoint(self, epoch: int, checkpoint_name: str):
        if self.early_stopping is None:
            best_metric = None
            best_epoch = None
        else:
            best_metric = self.early_stopping.best_metric
            best_epoch = self.early_stopping.best_epoch

        arch = type(self.model).__name__
        state = {
            "arch": arch,
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_metric": best_metric,
            "best_epoch": best_epoch,
            "config": flatten_dict(wandb.config),
            "wandb_id": wandb.run.id,
            "logdir": str(self.logdir.resolve()),
            "run_name": str(Path(self.logdir).name),
            "scaler_state_dict": self.scaler.state_dict()
            if self.scaler is not None
            else None,
        }

        checkpoint_dir = self.logdir / "checkpoints"
        checkpoint_dir.mkdir(exist_ok=True, parents=True)
        filename = str(checkpoint_dir / checkpoint_name)
        torch.save(state, filename)

    def resume_checkpoint(self, checkpoint):
        self.logger.info("Loading checkpoint")
        self.logger.info("Loading Model")
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.logger.info("Loading Optimizer state dict")
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        if self.early_stopping is not None:
            self.early_stopping.best_metric = checkpoint["best_metric"]
            self.early_stopping.best_epoch = checkpoint["best_epoch"]
        if self.scaler is not None:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Checkpoint epoch: {int(checkpoint['epoch'])}")
        self.start_epoch = int(checkpoint["epoch"])
        self.logger.info(f"Next epoch is: {self.start_epoch + 1}")