# -*- coding: utf-8 -*-
# CellViT Experiment Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import argparse
import copy
import datetime
import inspect
import math
import os
import shutil
import sys

import numpy as np
import yaml

currentdir = os.path.dirname(
    os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import uuid
from pathlib import Path
from typing import Callable, Tuple, Union

import albumentations as A
import torch
import torch.nn as nn
import wandb
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ExponentialLR,
    SequentialLR,
    _LRScheduler,
)
from torch.utils.data import (
    DataLoader,
    Dataset,
    RandomSampler,
    Sampler,
    Subset,
    WeightedRandomSampler,
)
from torchinfo import summary
from wandb.sdk.lib.runid import generate_id

from base_ml.base_early_stopping import EarlyStopping
from base_ml.base_experiment import BaseExperiment
from base_ml.base_loss import retrieve_loss_fn
from base_ml.base_trainer import BaseTrainer
from cell_segmentation.datasets.base_cell import CellDataset
from cell_segmentation.datasets.dataset_coordinator import select_dataset
from cell_segmentation.trainer.trainer_cellvit import CellViTTrainer
from models.segmentation.cell_segmentation.cellvit import CellViT
from utils.tools import close_logger

class WarmupCosineAnnealingLR(CosineAnnealingLR):
    """Cosine annealing schedule with a linear warmup over the first warmup_epochs epochs."""

    def __init__(self, optimizer, T_max, eta_min=0, warmup_epochs=0, warmup_factor=0):
        # set warmup attributes before calling the parent constructor, since the parent
        # constructor already performs a first step() that calls get_lr()
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor
        self.initial_lr = [group["lr"] for group in optimizer.param_groups]  # initial learning rates
        super().__init__(optimizer, T_max=T_max, eta_min=eta_min)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # linear ramp from warmup_factor * base_lr up to base_lr
            warmup_factor = self.warmup_factor + (1.0 - self.warmup_factor) * (
                self.last_epoch / self.warmup_epochs
            )
            return [base_lr * warmup_factor for base_lr in self.initial_lr]
        else:
            return [base_lr * self.get_lr_ratio() for base_lr in self.initial_lr]

    def get_lr_ratio(self):
        T_cur = min(self.last_epoch - self.warmup_epochs, self.T_max - self.warmup_epochs)
        return 0.5 * (1 + math.cos(math.pi * T_cur / (self.T_max - self.warmup_epochs)))
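
# Usage sketch (illustrative only): this scheduler is defined here but not wired into
# get_scheduler below; optimizer, epoch count and hyperparameters are assumptions.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
#   scheduler = WarmupCosineAnnealingLR(
#       optimizer, T_max=130, eta_min=1e-5, warmup_epochs=5, warmup_factor=0.1
#   )
#   for epoch in range(130):
#       train_one_epoch(...)  # hypothetical training step
#       scheduler.step()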

class ExperimentCellVitPanNuke(BaseExperiment):
    def __init__(self, default_conf: dict, checkpoint=None) -> None:
        super().__init__(default_conf, checkpoint)
        self.load_dataset_setup(dataset_path=self.default_conf["data"]["dataset_path"])

    def run_experiment(self) -> str:
        """Main experiment code.

        Returns:
            str: Path to the log directory of this run
        """
        ### Setup
        # close loggers
        self.close_remaining_logger()

        # get the config for the current run
        self.run_conf = copy.deepcopy(self.default_conf)
        self.run_conf["dataset_config"] = self.dataset_config
        self.run_name = f"{datetime.datetime.now().strftime('%Y-%m-%dT%H%M%S')}_{self.run_conf['logging']['log_comment']}"

        wandb_run_id = generate_id()
        resume = None
        if self.checkpoint is not None:
            wandb_run_id = self.checkpoint["wandb_id"]
            resume = "must"
            self.run_name = self.checkpoint["run_name"]

        # initialize wandb
        run = wandb.init(
            project=self.run_conf["logging"]["project"],
            tags=self.run_conf["logging"].get("tags", []),
            name=self.run_name,
            notes=self.run_conf["logging"]["notes"],
            dir=self.run_conf["logging"]["wandb_dir"],
            mode=self.run_conf["logging"]["mode"].lower(),
            group=self.run_conf["logging"].get("group", str(uuid.uuid4())),
            allow_val_change=True,
            id=wandb_run_id,
            resume=resume,
            settings=wandb.Settings(start_method="fork"),
        )

        # get ids
        self.run_conf["logging"]["run_id"] = run.id
        self.run_conf["logging"]["wandb_file"] = run.id

        # overwrite configuration with sweep values or leave it as it is
        if self.run_conf["run_sweep"] is True:
            self.run_conf["logging"]["sweep_id"] = run.sweep_id
            self.run_conf["logging"]["log_dir"] = str(
                Path(self.default_conf["logging"]["log_dir"])
                / f"sweep_{run.sweep_id}"
                / f"{self.run_name}_{self.run_conf['logging']['run_id']}"
            )
            self.overwrite_sweep_values(self.run_conf, run.config)
        else:
            self.run_conf["logging"]["log_dir"] = str(
                Path(self.default_conf["logging"]["log_dir"]) / self.run_name
            )

        # update wandb
        wandb.config.update(
            self.run_conf, allow_val_change=True
        )  # this may lead to problems

        # create output folder, instantiate logger and store config
        self.create_output_dir(self.run_conf["logging"]["log_dir"])
        self.logger = self.instantiate_logger()
        self.logger.info("Instantiated Logger. WandB init and config update finished.")
        self.logger.info(f"Run is stored here: {self.run_conf['logging']['log_dir']}")
        self.store_config()

        self.logger.info(
            f"Cuda devices: {[torch.cuda.device(i) for i in range(torch.cuda.device_count())]}"
        )

        ### Machine Learning
        device = f"cuda:{self.run_conf['gpu']}"
        self.logger.info(f"Using device: {device}")

        # loss functions
        loss_fn_dict = self.get_loss_fn(self.run_conf.get("loss", {}))
        self.logger.info("Loss functions:")
        self.logger.info(loss_fn_dict)

        # model
        model = self.get_train_model(
            pretrained_encoder=self.run_conf["model"].get("pretrained_encoder", None),
            pretrained_model=self.run_conf["model"].get("pretrained", None),
            backbone_type=self.run_conf["model"].get("backbone", "default"),
            shared_decoders=self.run_conf["model"].get("shared_decoders", False),
            regression_loss=self.run_conf["model"].get("regression_loss", False),
        )
        model.to(device)

        # optimizer
        optimizer = self.get_optimizer(
            model,
            self.run_conf["training"]["optimizer"].lower(),
            self.run_conf["training"]["optimizer_hyperparameter"],
            self.run_conf["training"]["layer_decay"],
        )

        # scheduler
        scheduler = self.get_scheduler(
            optimizer=optimizer,
            scheduler_type=self.run_conf["training"]["scheduler"]["scheduler_type"],
        )

        # early stopping (no early stopping for basic setup)
        early_stopping = None
        if "early_stopping_patience" in self.run_conf["training"]:
            if self.run_conf["training"]["early_stopping_patience"] is not None:
                early_stopping = EarlyStopping(
                    patience=self.run_conf["training"]["early_stopping_patience"],
                    strategy="maximize",
                )

        ### Data handling
        train_transforms, val_transforms = self.get_transforms(
            self.run_conf["transformations"],
            input_shape=self.run_conf["data"].get("input_shape", 256),
        )
        train_dataset, val_dataset = self.get_datasets(
            train_transforms=train_transforms,
            val_transforms=val_transforms,
        )

        # load sampler
        training_sampler = self.get_sampler(
            train_dataset=train_dataset,
            strategy=self.run_conf["training"].get("sampling_strategy", "random"),
            gamma=self.run_conf["training"].get("sampling_gamma", 1),
        )

        # define dataloaders
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.run_conf["training"]["batch_size"],
            sampler=training_sampler,
            num_workers=16,
            pin_memory=False,
            worker_init_fn=self.seed_worker,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=64,
            num_workers=8,
            pin_memory=True,
            worker_init_fn=self.seed_worker,
        )

        # start training
        self.logger.info("Instantiate Trainer")
        trainer_fn = self.get_trainer()
        trainer = trainer_fn(
            model=model,
            loss_fn_dict=loss_fn_dict,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            logger=self.logger,
            logdir=self.run_conf["logging"]["log_dir"],
            num_classes=self.run_conf["data"]["num_nuclei_classes"],
            dataset_config=self.dataset_config,
            early_stopping=early_stopping,
            experiment_config=self.run_conf,
            log_images=self.run_conf["logging"].get("log_images", False),
            magnification=self.run_conf["data"].get("magnification", 40),
            mixed_precision=self.run_conf["training"].get("mixed_precision", False),
        )

        # load checkpoint if provided
        if self.checkpoint is not None:
            self.logger.info("Checkpoint was provided. Restore ...")
            trainer.resume_checkpoint(self.checkpoint)

        # call fit method
        self.logger.info("Calling Trainer Fit")
        trainer.fit(
            epochs=self.run_conf["training"]["epochs"],
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            metric_init=self.get_wandb_init_dict(),
            unfreeze_epoch=self.run_conf["training"]["unfreeze_epoch"],
            eval_every=self.run_conf["training"].get("eval_every", 1),
        )

        # select best model if not provided by early stopping
        checkpoint_dir = Path(self.run_conf["logging"]["log_dir"]) / "checkpoints"
        if not (checkpoint_dir / "model_best.pth").is_file():
            shutil.copy(
                checkpoint_dir / "latest_checkpoint.pth",
                checkpoint_dir / "model_best.pth",
            )

        # at the end close logger
        self.logger.info(f"Finished run {run.id}")
        close_logger(self.logger)

        return self.run_conf["logging"]["log_dir"]
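
    # Usage sketch (illustrative only; in the CellViT repository the experiment is normally
    # driven by the training CLI, and the config path below is just an example):
    #
    #   with open("configs/examples/cell_segmentation/train_cellvit.yaml", "r") as f:
    #       config = yaml.safe_load(f)
    #   experiment = ExperimentCellVitPanNuke(default_conf=config, checkpoint=None)
    #   log_dir = experiment.run_experiment()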

    def load_dataset_setup(self, dataset_path: Union[Path, str]) -> None:
        """Load the configuration of the cell segmentation dataset.

        The dataset must provide a dataset_config.yaml file in its dataset path with the following entries:
            * tissue_types: present tissue types with their corresponding integer labels
            * nuclei_types: present nuclei types with their corresponding integer labels

        Args:
            dataset_path (Union[Path, str]): Path to dataset folder
        """
        dataset_config_path = Path(dataset_path) / "dataset_config.yaml"
        with open(dataset_config_path, "r") as dataset_config_file:
            yaml_config = yaml.safe_load(dataset_config_file)
            self.dataset_config = dict(yaml_config)
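
    # For illustration only: a minimal (hypothetical) dataset_config.yaml with the two
    # entries described above. The real file ships with the prepared PanNuke dataset and
    # lists all tissue types; the excerpt below shows only the expected structure:
    #
    #   tissue_types:
    #     Breast: 0
    #     Colon: 1
    #   nuclei_types:
    #     Background: 0
    #     Neoplastic: 1
    #     Inflammatory: 2
    #     Connective: 3
    #     Dead: 4
    #     Epithelial: 5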

    def get_loss_fn(self, loss_fn_settings: dict) -> dict:
        """Create a dictionary with loss functions for all branches

        Branches: "nuclei_binary_map", "hv_map", "nuclei_type_map", "tissue_types"

        Args:
            loss_fn_settings (dict): Dictionary with the loss function settings. Structure:
                branch_name(str):
                    loss_name(str):
                        loss_fn(str): String matching one of the loss functions defined in LOSS_DICT (base_ml.base_loss)
                        weight(float): Weighting factor as float value
                        (optional) args: Optional parameters for initializing the loss function
                            arg_name: value

            If a branch is not provided, the default settings (described below) are used.

            For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
            under the section "loss".

            Example:
                nuclei_binary_map:
                    bce:
                        loss_fn: xentropy_loss
                        weight: 1
                    dice:
                        loss_fn: dice_loss
                        weight: 1

        Returns:
            dict: Dictionary with loss functions for each branch. Structure:
                branch_name(str):
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss, since in the end all losses of all branches are added together for the backward pass
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": weight of the loss, since in the end all losses of all branches are added together for the backward pass
                branch_name(str):
                    ...

        Default loss dictionary:
            nuclei_binary_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            hv_map:
                mse:
                    loss_fn: mse_loss_maps
                    weight: 1
                msge:
                    loss_fn: msge_loss_maps
                    weight: 1
            nuclei_type_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            tissue_types:
                ce:
                    loss_fn: nn.CrossEntropyLoss()
                    weight: 1
        """
        loss_fn_dict = {}
        if "nuclei_binary_map" in loss_fn_settings.keys():
            loss_fn_dict["nuclei_binary_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["nuclei_binary_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["nuclei_binary_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["nuclei_binary_map"] = {
                "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
                "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
            }
        if "hv_map" in loss_fn_settings.keys():
            loss_fn_dict["hv_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["hv_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["hv_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["hv_map"] = {
                "mse": {"loss_fn": retrieve_loss_fn("mse_loss_maps"), "weight": 1},
                "msge": {"loss_fn": retrieve_loss_fn("msge_loss_maps"), "weight": 1},
            }
        if "nuclei_type_map" in loss_fn_settings.keys():
            loss_fn_dict["nuclei_type_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["nuclei_type_map"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["nuclei_type_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["nuclei_type_map"] = {
                "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1},
                "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1},
            }
        if "tissue_types" in loss_fn_settings.keys():
            loss_fn_dict["tissue_types"] = {}
            for loss_name, loss_sett in loss_fn_settings["tissue_types"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["tissue_types"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        else:
            loss_fn_dict["tissue_types"] = {
                "ce": {"loss_fn": nn.CrossEntropyLoss(), "weight": 1},
            }
        if "regression_loss" in loss_fn_settings.keys():
            loss_fn_dict["regression_map"] = {}
            for loss_name, loss_sett in loss_fn_settings["regression_loss"].items():
                parameters = loss_sett.get("args", {})
                loss_fn_dict["regression_map"][loss_name] = {
                    "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters),
                    "weight": loss_sett["weight"],
                }
        elif "regression_loss" in self.run_conf["model"].keys():
            loss_fn_dict["regression_map"] = {
                "mse": {"loss_fn": retrieve_loss_fn("mse_loss_maps"), "weight": 1},
            }
        return loss_fn_dict
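
    # Sketch (not part of the training code): the trainer combines the entries of the
    # returned dictionary into a single scalar roughly like this, assuming `predictions`
    # and `targets` are dicts keyed by branch name:
    #
    #   total_loss = 0.0
    #   for branch, branch_losses in loss_fn_dict.items():
    #       for loss_name, sett in branch_losses.items():
    #           total_loss += sett["weight"] * sett["loss_fn"](predictions[branch], targets[branch])
    #   total_loss.backward()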

    def get_scheduler(self, scheduler_type: str, optimizer: Optimizer) -> _LRScheduler:
        """Get the learning rate scheduler for CellViT

        The configuration of the scheduler is given in the "training" -> "scheduler" section.
        Currently, "constant", "exponential" and "cosine" schedulers are implemented.

        Required parameters for the implemented schedulers:
            - "constant": None
            - "exponential": gamma (optional, defaults to 0.95)
            - "cosine": eta_min (optional, defaults to 1e-5)

        Args:
            scheduler_type (str): Type of scheduler as a string. Currently implemented:
                - "constant" (keeps the initial learning rate for 25 epochs, lowers it by a factor of ten for
                  the next 25, restores it for another 25, and lowers it by a factor of ten again from epoch 75 on)
                - "exponential" (ExponentialLR with given gamma, gamma defaults to 0.95)
                - "cosine" (CosineAnnealingLR, eta_min as parameter, defaults to 1e-5)
            optimizer (Optimizer): Optimizer

        Returns:
            _LRScheduler: PyTorch Scheduler
        """
        implemented_schedulers = ["constant", "exponential", "cosine", "default"]
        if scheduler_type.lower() not in implemented_schedulers:
            self.logger.warning(
                f"Unknown scheduler - no scheduler from the list {implemented_schedulers} selected. Using default scheduling."
            )
        if scheduler_type.lower() == "constant":
            scheduler = SequentialLR(
                optimizer=optimizer,
                schedulers=[
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=25),
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=1000),
                ],
                milestones=[24, 49, 74],
            )
        elif scheduler_type.lower() == "exponential":
            scheduler = ExponentialLR(
                optimizer,
                gamma=self.run_conf["training"]["scheduler"].get("gamma", 0.95),
            )
        elif scheduler_type.lower() == "cosine":
            scheduler = CosineAnnealingLR(
                optimizer,
                T_max=self.run_conf["training"]["epochs"],
                eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
            )
        # elif scheduler_type.lower() == "cosinewarmrestarts":
        #     scheduler = CosineAnnealingWarmRestarts(
        #         optimizer,
        #         T_0=self.run_conf["training"]["scheduler"]["T_0"],
        #         T_mult=self.run_conf["training"]["scheduler"]["T_mult"],
        #         eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
        #     )
        else:
            # "default" and unknown types fall back to the BaseExperiment scheduler
            scheduler = super().get_scheduler(optimizer)
        return scheduler
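
    # For illustration only: a possible "training" -> "scheduler" block selecting the cosine
    # schedule (key names as read by get_scheduler above; the values are examples, eta_min is optional):
    #
    #   training:
    #     epochs: 130
    #     scheduler:
    #       scheduler_type: cosine
    #       eta_min: 0.00001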

    def get_datasets(
        self,
        train_transforms: Callable = None,
        val_transforms: Callable = None,
    ) -> Tuple[Dataset, Dataset]:
        """Retrieve training dataset and validation dataset

        Args:
            train_transforms (Callable, optional): PyTorch transformations for train set. Defaults to None.
            val_transforms (Callable, optional): PyTorch transformations for validation set. Defaults to None.

        Returns:
            Tuple[Dataset, Dataset]: Training dataset and validation dataset
        """
        if (
            "val_split" in self.run_conf["data"]
            and "val_folds" in self.run_conf["data"]
        ):
            raise RuntimeError(
                "Provide either val_split or val_folds in the configuration file, not both."
            )
        if (
            "val_split" not in self.run_conf["data"]
            and "val_folds" not in self.run_conf["data"]
        ):
            raise RuntimeError(
                "Provide either val_split or val_folds in the configuration file, one is necessary."
            )
        if "regression_loss" in self.run_conf["model"].keys():
            self.run_conf["data"]["regression_loss"] = True

        full_dataset = select_dataset(
            dataset_name="pannuke",
            split="train",
            dataset_config=self.run_conf["data"],
            transforms=train_transforms,
        )
        if "val_split" in self.run_conf["data"]:
            generator_split = torch.Generator().manual_seed(
                self.default_conf["random_seed"]
            )
            val_split = float(self.run_conf["data"]["val_split"])
            train_dataset, val_dataset = torch.utils.data.random_split(
                full_dataset,
                lengths=[1 - val_split, val_split],
                generator=generator_split,
            )
            val_dataset.dataset = copy.deepcopy(full_dataset)
            val_dataset.dataset.set_transforms(val_transforms)
        else:
            train_dataset = full_dataset
            val_dataset = select_dataset(
                dataset_name="pannuke",
                split="validation",
                dataset_config=self.run_conf["data"],
                transforms=val_transforms,
            )
        return train_dataset, val_dataset
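
    # For illustration only: the "data" section must contain exactly one of the two validation
    # options handled above (all other keys and values here are just examples):
    #
    #   data:
    #     dataset_path: ./data/pannuke        # hypothetical path
    #     num_nuclei_classes: 6
    #     num_tissue_classes: 19
    #     val_split: 0.1                      # random split of the training data ...
    #     # val_folds: [1]                    # ... or fold-based validation, but never both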

    def get_train_model(
        self,
        pretrained_encoder: Union[Path, str] = None,
        pretrained_model: Union[Path, str] = None,
        backbone_type: str = "default",
        shared_decoders: bool = False,
        regression_loss: bool = False,
        **kwargs,
    ) -> CellViT:
        """Return the CellViT training model

        Args:
            pretrained_encoder (Union[Path, str]): Path to a pretrained encoder. Defaults to None.
            pretrained_model (Union[Path, str], optional): Path to a pretrained model. Defaults to None.
            backbone_type (str, optional): Backbone type. Currently supported: default, UniRepLKNet, ViT256, SAM-B, SAM-L, SAM-H. Defaults to "default".
            shared_decoders (bool, optional): If shared skip decoders should be used. Defaults to False.
            regression_loss (bool, optional): If regression loss is used. Defaults to False.

        Returns:
            CellViT: CellViT training model with given setup
        """
        # reseed needed, due to subprocess seeding compatibility
        self.seed_run(self.default_conf["random_seed"])

        # check for backbones (compared against backbone_type.lower(), so list entries are lowercase)
        implemented_backbones = ["default", "unireplknet", "vit256", "sam-b", "sam-l", "sam-h"]
        if backbone_type.lower() not in implemented_backbones:
            raise NotImplementedError(
                f"Unknown Backbone Type - Currently supported are: {implemented_backbones}"
            )
        if backbone_type.lower() == "default":
            model_class = CellViT
            model = model_class(
                model256_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                in_channels=self.run_conf["model"].get("input_channels", 3),
                dropout=self.run_conf["training"].get("drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0.1),
                # the following arguments of the original CellViT constructor are not passed in this variant:
                # embed_dim=self.run_conf["model"]["embed_dim"],
                # depth=self.run_conf["model"]["depth"],  # e.g. (3, 3, 27, 3)
                # num_heads=self.run_conf["model"]["num_heads"],
                # extract_layers=self.run_conf["model"]["extract_layers"],
                # attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
                # regression_loss=regression_loss,
            )
            model.load_pretrained_encoder(model.model256_path)
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model)
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
                self.logger.info("Loaded CellViT model")
        else:
            # only the default backbone is constructed in this experiment class
            raise NotImplementedError(
                f"Backbone '{backbone_type}' is listed but not constructed in this experiment class."
            )

        self.logger.info(f"\nModel: {model}")
        model = model.to("cuda")
        self.logger.info(
            f"\n{summary(model, input_size=(1, 3, 256, 256), device='cuda')}"
        )

        # parameter statistics
        total_params = 0
        trainable_params = 0
        non_trainable_params = 0
        for param in model.parameters():
            multvalue = np.prod(param.size())
            total_params += multvalue
            if param.requires_grad:
                trainable_params += multvalue  # number of trainable parameters
            else:
                non_trainable_params += multvalue  # number of non-trainable parameters
        print(f"Total params: {total_params}")
        print(f"Trainable params: {trainable_params}")
        print(f"Non-trainable params: {non_trainable_params}")
        return model
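
    # For illustration only: a "model" section that this method would read (key names as used
    # above; the pretrained paths are hypothetical):
    #
    #   model:
    #     backbone: default
    #     pretrained_encoder: ./models/pretrained/vit256_encoder.pth
    #     # pretrained: ./models/pretrained/cellvit_checkpoint.pth
    #     input_channels: 3
    #     shared_decoders: false
    #     regression_loss: false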

    def get_wandb_init_dict(self) -> dict:
        pass

    def get_transforms(
        self, transform_settings: dict, input_shape: int = 256
    ) -> Tuple[Callable, Callable]:
        """Get transformations (Albumentations transformations). Return both training and validation transformations.

        The transformation settings are given in the following format:
            key: dict with parameters

        Example:
            colorjitter:
                p: 0.1
                scale_setting: 0.5
                scale_color: 0.1

        For further information on how to set up the dictionary and for default (recommended) values see:
            configs/examples/cell_segmentation/train_cellvit.yaml

        Training Transformations:
            Implemented are:
                - A.RandomRotate90: Key in transform_settings: randomrotate90, parameters: p
                - A.HorizontalFlip: Key in transform_settings: horizontalflip, parameters: p
                - A.VerticalFlip: Key in transform_settings: verticalflip, parameters: p
                - A.Downscale: Key in transform_settings: downscale, parameters: p, scale
                - A.Blur: Key in transform_settings: blur, parameters: p, blur_limit
                - A.GaussNoise: Key in transform_settings: gaussnoise, parameters: p, var_limit
                - A.ColorJitter: Key in transform_settings: colorjitter, parameters: p, scale_setting, scale_color
                - A.Superpixels: Key in transform_settings: superpixels, parameters: p
                - A.ZoomBlur: Key in transform_settings: zoomblur, parameters: p
                - A.RandomSizedCrop: Key in transform_settings: randomsizedcrop, parameters: p
                - A.ElasticTransform: Key in transform_settings: elastictransform, parameters: p
            Always implemented at the end of the pipeline:
                - A.Normalize with given mean (default: (0.5, 0.5, 0.5)) and std (default: (0.5, 0.5, 0.5))

        Validation Transformations:
            A.Normalize with given mean (default: (0.5, 0.5, 0.5)) and std (default: (0.5, 0.5, 0.5))

        Args:
            transform_settings (dict): Dictionary with the transformation settings.
            input_shape (int, optional): Input shape of the images to be used. Defaults to 256.

        Returns:
            Tuple[Callable, Callable]: Train transformations, validation transformations
        """
        transform_list = []
        transform_settings = {k.lower(): v for k, v in transform_settings.items()}
        if "RandomRotate90".lower() in transform_settings:
            p = transform_settings["randomrotate90"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.RandomRotate90(p=p))
        if "HorizontalFlip".lower() in transform_settings.keys():
            p = transform_settings["horizontalflip"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.HorizontalFlip(p=p))
        if "VerticalFlip".lower() in transform_settings:
            p = transform_settings["verticalflip"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.VerticalFlip(p=p))
        if "Downscale".lower() in transform_settings:
            p = transform_settings["downscale"]["p"]
            scale = transform_settings["downscale"]["scale"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.Downscale(p=p, scale_max=scale, scale_min=scale)
                )
        if "Blur".lower() in transform_settings:
            p = transform_settings["blur"]["p"]
            blur_limit = transform_settings["blur"]["blur_limit"]
            if p > 0 and p <= 1:
                transform_list.append(A.Blur(p=p, blur_limit=blur_limit))
        if "GaussNoise".lower() in transform_settings:
            p = transform_settings["gaussnoise"]["p"]
            var_limit = transform_settings["gaussnoise"]["var_limit"]
            if p > 0 and p <= 1:
                transform_list.append(A.GaussNoise(p=p, var_limit=var_limit))
        if "ColorJitter".lower() in transform_settings:
            p = transform_settings["colorjitter"]["p"]
            scale_setting = transform_settings["colorjitter"]["scale_setting"]
            scale_color = transform_settings["colorjitter"]["scale_color"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.ColorJitter(
                        p=p,
                        brightness=scale_setting,
                        contrast=scale_setting,
                        saturation=scale_color,
                        hue=scale_color / 2,
                    )
                )
        if "Superpixels".lower() in transform_settings:
            p = transform_settings["superpixels"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.Superpixels(
                        p=p,
                        p_replace=0.1,
                        n_segments=200,
                        max_size=int(input_shape / 2),
                    )
                )
        if "ZoomBlur".lower() in transform_settings:
            p = transform_settings["zoomblur"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(A.ZoomBlur(p=p, max_factor=1.05))
        if "RandomSizedCrop".lower() in transform_settings:
            p = transform_settings["randomsizedcrop"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.RandomSizedCrop(
                        min_max_height=(input_shape // 2, input_shape),
                        height=input_shape,
                        width=input_shape,
                        p=p,
                    )
                )
        if "ElasticTransform".lower() in transform_settings:
            p = transform_settings["elastictransform"]["p"]
            if p > 0 and p <= 1:
                transform_list.append(
                    A.ElasticTransform(p=p, sigma=25, alpha=0.5, alpha_affine=15)
                )

        if "normalize" in transform_settings:
            mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
            std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
        else:
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        transform_list.append(A.Normalize(mean=mean, std=std))

        train_transforms = A.Compose(transform_list)
        val_transforms = A.Compose([A.Normalize(mean=mean, std=std)])

        return train_transforms, val_transforms
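
    # Usage sketch (not part of the pipeline): the returned objects are Albumentations Compose
    # transforms, so a dataset would apply them as
    #
    #   transformed = train_transforms(image=image)  # image: HxWx3 uint8 numpy array
    #   image = transformed["image"]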

    def get_sampler(
        self, train_dataset: CellDataset, strategy: str = "random", gamma: float = 1
    ) -> Sampler:
        """Return the sampler (either RandomSampler or WeightedRandomSampler)

        Args:
            train_dataset (CellDataset): Dataset for training
            strategy (str, optional): Sampling strategy. Defaults to "random" (random sampling).
                Implemented are "random", "cell", "tissue" and "cell+tissue".
            gamma (float, optional): Gamma scaling factor, between 0 and 1.
                1 means total balancing, 0 means original weights. Defaults to 1.

        Raises:
            NotImplementedError: Raised if an unimplemented sampling strategy is selected

        Returns:
            Sampler: Sampler for training
        """
        if strategy.lower() == "random":
            sampling_generator = torch.Generator().manual_seed(
                self.default_conf["random_seed"]
            )
            sampler = RandomSampler(train_dataset, generator=sampling_generator)
            self.logger.info("Using RandomSampler")
        else:
            # this solution is not accurate when a subset is used, since the weights are calculated on the whole training dataset
            if isinstance(train_dataset, Subset):
                ds = train_dataset.dataset
            else:
                ds = train_dataset
            ds.load_cell_count()
            if strategy.lower() == "cell":
                weights = ds.get_sampling_weights_cell(gamma)
            elif strategy.lower() == "tissue":
                weights = ds.get_sampling_weights_tissue(gamma)
            elif strategy.lower() == "cell+tissue":
                weights = ds.get_sampling_weights_cell_tissue(gamma)
            else:
                raise NotImplementedError(
                    "Unknown sampling strategy - implemented are cell, tissue and cell+tissue"
                )
            if isinstance(train_dataset, Subset):
                weights = torch.Tensor([weights[i] for i in train_dataset.indices])

            sampling_generator = torch.Generator().manual_seed(
                self.default_conf["random_seed"]
            )
            sampler = WeightedRandomSampler(
                weights=weights,
                num_samples=len(train_dataset),
                replacement=True,
                generator=sampling_generator,
            )

            self.logger.info(f"Using Weighted Sampling with strategy: {strategy}")
            self.logger.info(f"Unique-Weights: {torch.unique(weights)}")

        return sampler
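
    # Illustration of the gamma semantics documented above (not necessarily the exact formula
    # used in CellDataset.get_sampling_weights_*): a per-sample weight could be interpolated as
    #
    #   balanced = 1.0 / class_count[label]   # fully balanced weights
    #   original = 1.0                        # original (uniform over samples) weights
    #   weight   = gamma * balanced + (1 - gamma) * original
    #
    # so gamma = 1 yields total balancing and gamma = 0 keeps the original sampling.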

    def get_trainer(self) -> BaseTrainer:
        """Return trainer matching this network

        Returns:
            BaseTrainer: Trainer
        """
        return CellViTTrainer