# -*- coding: utf-8 -*-
# StarDist Experiment Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

import inspect
import os
import sys

import yaml

# Make the project root importable before pulling in project-local modules
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from pathlib import Path
from typing import Callable, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
    ConstantLR,
    CosineAnnealingLR,
    ExponentialLR,
    ReduceLROnPlateau,
    SequentialLR,
    _LRScheduler,
)
from torch.utils.data import Dataset
from torchinfo import summary
from base_ml.base_loss import retrieve_loss_fn
from base_ml.base_trainer import BaseTrainer
from cell_unireplknet.cell_segmentation.experiments.experiment_cellvit_pannuke_origin import (
    ExperimentCellVitPanNuke,
)
from cell_segmentation.trainer.trainer_stardist import CellViTStarDistTrainer
from models.segmentation.cell_segmentation.cellvit_stardist import (
    CellViTStarDist,
    CellViT256StarDist,
    CellViTSAMStarDist,
)
from models.segmentation.cell_segmentation.cellvit_stardist_shared import (
    CellViTStarDistShared,
    CellViT256StarDistShared,
    CellViTSAMStarDistShared,
)
from models.segmentation.cell_segmentation.cpp_net_stardist_rn50 import StarDistRN50


class ExperimentCellViTStarDist(ExperimentCellVitPanNuke):
    def load_dataset_setup(self, dataset_path: Union[Path, str]) -> None:
        """Load the configuration of the PanNuke cell segmentation dataset.

        The dataset must provide a dataset_config.yaml file in its dataset path with the following entries:
            * tissue_types: describing the present tissue types with corresponding integer
            * nuclei_types: describing the present nuclei types with corresponding integer

        Args:
            dataset_path (Union[Path, str]): Path to dataset folder
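
        A minimal dataset_config.yaml sketch (entries are illustrative, not the
        full PanNuke mapping):

            tissue_types:
                Breast: 0
                Colon: 1
            nuclei_types:
                Background: 0
                Neoplastic: 1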
| """ | |
| dataset_config_path = Path(dataset_path) / "dataset_config.yaml" | |
| with open(dataset_config_path, "r") as dataset_config_file: | |
| yaml_config = yaml.safe_load(dataset_config_file) | |
| self.dataset_config = dict(yaml_config) | |

    def get_loss_fn(self, loss_fn_settings: dict) -> dict:
        """Create a dictionary with loss functions for all branches

        Branches: "dist_map", "stardist_map", "nuclei_type_map", "tissue_types"

        Args:
            loss_fn_settings (dict): Dictionary with the loss function settings. Structure:
                branch_name(str):
                    loss_name(str):
                        loss_fn(str): String matching a loss function defined in the LOSS_DICT (base_ml.base_loss)
                        weight(float): Weighting factor as float value
                        (optional) args: Optional parameters for initializing the loss function
                            arg_name: value

                If a branch is not provided, the default settings (described below) are used.

                For further information, please have a look at the file configs/examples/cell_segmentation/train_cellvit.yaml
                under the section "loss".

                Example:
                    nuclei_type_map:
                        bce:
                            loss_fn: xentropy_loss
                            weight: 1
                        dice:
                            loss_fn: dice_loss
                            weight: 1

        Returns:
            dict: Dictionary with loss functions for each branch. Structure:
                branch_name(str):
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": Weight of the loss; all weighted branch losses are summed for the backward pass
                    loss_name(str):
                        "loss_fn": Callable loss function
                        "weight": Weight of the loss; all weighted branch losses are summed for the backward pass
                branch_name(str):
                    ...

        Default loss dictionary:
            dist_map:
                bceweighted:
                    loss_fn: BCEWithLogitsLoss
                    weight: 1
            stardist_map:
                L1LossWeighted:
                    loss_fn: L1LossWeighted
                    weight: 1
            nuclei_type_map:
                bce:
                    loss_fn: xentropy_loss
                    weight: 1
                dice:
                    loss_fn: dice_loss
                    weight: 1
            tissue_types has no default loss and might be skipped
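
        Example usage of the returned dictionary (sketch; `predictions` and
        `ground_truth` are hypothetical dicts keyed by branch name):

            total_loss = 0
            for branch, branch_losses in loss_fn_dict.items():
                for loss_name, sett in branch_losses.items():
                    total_loss += sett["weight"] * sett["loss_fn"](predictions[branch], ground_truth[branch])
            total_loss.backward()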
| """ | |
| loss_fn_dict = {} | |
| if "dist_map" in loss_fn_settings.keys(): | |
| loss_fn_dict["dist_map"] = {} | |
| for loss_name, loss_sett in loss_fn_settings["dist_map"].items(): | |
| parameters = loss_sett.get("args", {}) | |
| loss_fn_dict["dist_map"][loss_name] = { | |
| "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters), | |
| "weight": loss_sett["weight"], | |
| } | |
| else: | |
| loss_fn_dict["dist_map"] = { | |
| "bceweighted": { | |
| "loss_fn": retrieve_loss_fn("BCEWithLogitsLoss"), | |
| "weight": 1, | |
| }, | |
| } | |
| if "stardist_map" in loss_fn_settings.keys(): | |
| loss_fn_dict["stardist_map"] = {} | |
| for loss_name, loss_sett in loss_fn_settings["stardist_map"].items(): | |
| parameters = loss_sett.get("args", {}) | |
| loss_fn_dict["stardist_map"][loss_name] = { | |
| "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters), | |
| "weight": loss_sett["weight"], | |
| } | |
| else: | |
| loss_fn_dict["stardist_map"] = { | |
| "L1LossWeighted": { | |
| "loss_fn": retrieve_loss_fn("L1LossWeighted"), | |
| "weight": 1, | |
| }, | |
| } | |
| if "nuclei_type_map" in loss_fn_settings.keys(): | |
| loss_fn_dict["nuclei_type_map"] = {} | |
| for loss_name, loss_sett in loss_fn_settings["nuclei_type_map"].items(): | |
| parameters = loss_sett.get("args", {}) | |
| loss_fn_dict["nuclei_type_map"][loss_name] = { | |
| "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters), | |
| "weight": loss_sett["weight"], | |
| } | |
| else: | |
| loss_fn_dict["nuclei_type_map"] = { | |
| "bce": {"loss_fn": retrieve_loss_fn("xentropy_loss"), "weight": 1}, | |
| "dice": {"loss_fn": retrieve_loss_fn("dice_loss"), "weight": 1}, | |
| } | |
| if "tissue_types" in loss_fn_settings.keys(): | |
| loss_fn_dict["tissue_types"] = {} | |
| for loss_name, loss_sett in loss_fn_settings["tissue_types"].items(): | |
| parameters = loss_sett.get("args", {}) | |
| loss_fn_dict["tissue_types"][loss_name] = { | |
| "loss_fn": retrieve_loss_fn(loss_sett["loss_fn"], **parameters), | |
| "weight": loss_sett["weight"], | |
| } | |
| # skip default tissue loss! | |
| return loss_fn_dict | |

    def get_scheduler(self, scheduler_type: str, optimizer: Optimizer) -> _LRScheduler:
        """Get the learning rate scheduler for CellViT

        The configuration of the scheduler is given in the "training" -> "scheduler" section.
        Currently, "constant", "exponential", "cosine" and "reducelronplateau" schedulers are implemented.

        Required parameters for the implemented schedulers:
            - "constant": None
            - "exponential": gamma (optional, defaults to 0.95)
            - "cosine": eta_min (optional, defaults to 1e-5)
            - "reducelronplateau": None (all parameters are currently hardcoded; monitors the validation loss)

        Args:
            scheduler_type (str): Type of scheduler as a string. Currently implemented:
                - "constant" (lowering by a factor of ten after 25 epochs, increasing again after 50, decreasing again after 75)
                - "exponential" (ExponentialLR with given gamma, gamma defaults to 0.95)
                - "cosine" (CosineAnnealingLR, eta_min as parameter, defaults to 1e-5)
                - "reducelronplateau" (ReduceLROnPlateau on the validation loss)
            optimizer (Optimizer): Optimizer

        Returns:
            _LRScheduler: PyTorch Scheduler
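
        Example "training" section of the run configuration (values are
        illustrative; the key carrying the scheduler name may differ in your
        config layout):

            training:
                epochs: 100
                scheduler:
                    scheduler_type: exponential
                    gamma: 0.95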
| """ | |
| implemented_schedulers = [ | |
| "constant", | |
| "exponential", | |
| "cosine", | |
| "reducelronplateau", | |
| ] | |
| if scheduler_type.lower() not in implemented_schedulers: | |
| self.logger.warning( | |
| f"Unknown Scheduler - No scheduler from the list {implemented_schedulers} select. Using default scheduling." | |
| ) | |
| if scheduler_type.lower() == "constant": | |
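            # Effective LR pattern with milestones [24, 49, 74]: full LR for the
            # first 25 epochs, 0.1x LR for epochs 25-49, full LR for epochs 50-74,
            # then 0.1x LR for the remainder of training.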
            scheduler = SequentialLR(
                optimizer=optimizer,
                schedulers=[
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=25),
                    ConstantLR(optimizer, factor=1, total_iters=25),
                    ConstantLR(optimizer, factor=0.1, total_iters=1000),
                ],
                milestones=[24, 49, 74],
            )
        elif scheduler_type.lower() == "exponential":
            scheduler = ExponentialLR(
                optimizer,
                gamma=self.run_conf["training"]["scheduler"].get("gamma", 0.95),
            )
        elif scheduler_type.lower() == "cosine":
            scheduler = CosineAnnealingLR(
                optimizer,
                T_max=self.run_conf["training"]["epochs"],
                eta_min=self.run_conf["training"]["scheduler"].get("eta_min", 1e-5),
            )
        elif scheduler_type.lower() == "reducelronplateau":
            scheduler = ReduceLROnPlateau(
                optimizer,
                mode="min",
                factor=0.5,
                min_lr=0.0000001,
                patience=10,
                threshold=1e-20,
            )
        else:
            scheduler = super().get_scheduler(optimizer)
        return scheduler

    def get_datasets(
        self,
        train_transforms: Callable = None,
        val_transforms: Callable = None,
    ) -> Tuple[Dataset, Dataset]:
        """Retrieve training dataset and validation dataset

        Args:
            train_transforms (Callable, optional): PyTorch transformations for the train set. Defaults to None.
            val_transforms (Callable, optional): PyTorch transformations for the validation set. Defaults to None.

        Returns:
            Tuple[Dataset, Dataset]: Training dataset and validation dataset
        """
        # enable StarDist-specific target maps in the parent dataset setup
        self.run_conf["data"]["stardist"] = True
        train_dataset, val_dataset = super().get_datasets(
            train_transforms=train_transforms,
            val_transforms=val_transforms,
        )
        return train_dataset, val_dataset

    def get_train_model(
        self,
        pretrained_encoder: Union[Path, str] = None,
        pretrained_model: Union[Path, str] = None,
        backbone_type: str = "default",
        shared_decoders: bool = False,
        **kwargs,
    ) -> nn.Module:
        """Return the CellViTStarDist training model

        Args:
            pretrained_encoder (Union[Path, str], optional): Path to a pretrained encoder. Defaults to None.
            pretrained_model (Union[Path, str], optional): Path to a pretrained model. Defaults to None.
            backbone_type (str, optional): Backbone type. Currently supported: "default", "ViT256", "SAM-B", "SAM-L", "SAM-H", "RN50". Defaults to "default".
            shared_decoders (bool, optional): If shared skip decoders should be used. Defaults to False.

        Returns:
            nn.Module: StarDist training model with given setup
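
        Example "model" section of the run configuration for the default
        backbone (keys match what this method reads; values are illustrative):

            model:
                embed_dim: 768
                input_channels: 3
                depth: 12
                num_heads: 12
                extract_layers: [3, 6, 9, 12]
                nrays: 32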
| """ | |
| # reseed needed, due to subprocess seeding compatibility | |
| self.seed_run(self.default_conf["random_seed"]) | |
| # check for backbones | |
| implemented_backbones = ["default", "vit256", "sam-b", "sam-l", "sam-h", "rn50"] | |
| if backbone_type.lower() not in implemented_backbones: | |
| raise NotImplementedError( | |
| f"Unknown Backbone Type - Currently supported are: {implemented_backbones}" | |
| ) | |
| if backbone_type.lower() == "default": | |
| if shared_decoders: | |
| model_class = CellViTStarDistShared | |
| else: | |
| model_class = CellViTStarDist | |
| model = model_class( | |
| num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"], | |
| num_tissue_classes=self.run_conf["data"]["num_tissue_classes"], | |
| embed_dim=self.run_conf["model"]["embed_dim"], | |
| input_channels=self.run_conf["model"].get("input_channels", 3), | |
| depth=self.run_conf["model"]["depth"], | |
| num_heads=self.run_conf["model"]["num_heads"], | |
| extract_layers=self.run_conf["model"]["extract_layers"], | |
| drop_rate=self.run_conf["training"].get("drop_rate", 0), | |
| attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0), | |
| drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0), | |
| nrays=self.run_conf["model"].get("nrays", 32), | |
| ) | |
| if pretrained_model is not None: | |
| self.logger.info( | |
| f"Loading pretrained CellViT model from path: {pretrained_model}" | |
| ) | |
| cellvit_pretrained = torch.load(pretrained_model) | |
| self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True)) | |
| self.logger.info("Loaded CellViT model") | |
        if backbone_type.lower() == "vit256":
            if shared_decoders:
                model_class = CellViT256StarDistShared
            else:
                model_class = CellViT256StarDist
            model = model_class(
                model256_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                attn_drop_rate=self.run_conf["training"].get("attn_drop_rate", 0),
                drop_path_rate=self.run_conf["training"].get("drop_path_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model256_path)
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
            model.freeze_encoder()
            self.logger.info("Loaded CellViT256 model")
        if backbone_type.lower() in ["sam-b", "sam-l", "sam-h"]:
            if shared_decoders:
                model_class = CellViTSAMStarDistShared
            else:
                model_class = CellViTSAMStarDist
            model = model_class(
                model_path=pretrained_encoder,
                num_nuclei_classes=self.run_conf["data"]["num_nuclei_classes"],
                num_tissue_classes=self.run_conf["data"]["num_tissue_classes"],
                vit_structure=backbone_type,
                drop_rate=self.run_conf["training"].get("drop_rate", 0),
                nrays=self.run_conf["model"].get("nrays", 32),
            )
            model.load_pretrained_encoder(model.model_path)
            if pretrained_model is not None:
                self.logger.info(
                    f"Loading pretrained CellViT model from path: {pretrained_model}"
                )
                cellvit_pretrained = torch.load(pretrained_model, map_location="cpu")
                self.logger.info(model.load_state_dict(cellvit_pretrained, strict=True))
            model.freeze_encoder()
            self.logger.info(f"Loaded CellViT-SAM model with backbone: {backbone_type}")
        if backbone_type.lower() == "rn50":
            model = StarDistRN50(
                n_rays=self.run_conf["model"].get("nrays", 32),
                n_seg_cls=self.run_conf["data"]["num_nuclei_classes"],
            )

        self.logger.info(f"\nModel: {model}")
        model = model.to("cpu")
        self.logger.info(
            f"\n{summary(model, input_size=(1, 3, 256, 256), device='cpu')}"
        )
        return model

    def get_trainer(self) -> BaseTrainer:
        """Return the Trainer matching this network

        Returns:
            BaseTrainer: Trainer
        """
        return CellViTStarDistTrainer