Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Helper functions for models
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
| from torch import nn | |
def reset_weights(model: nn.Module) -> None:
    """Reset the parameters of the model to avoid weight leakage.

    Every direct child of ``model`` that defines ``reset_parameters`` is
    reset through that method. A child without it (for example an
    ``nn.Sequential`` or another plain container module) is searched
    recursively. Layers nested inside containers are therefore reset too,
    instead of silently keeping their previous weights.

    Args:
        model (nn.Module): PyTorch Model
    """
    for layer in model.children():
        if hasattr(layer, "reset_parameters"):
            # The layer owns its re-initialisation logic. Do not descend
            # further, so a custom reset_parameters is not overridden by
            # its sub-layers' default resets.
            layer.reset_parameters()
        else:
            # Containers have no reset_parameters of their own. Recurse so
            # the layers they hold are reset as well.
            reset_weights(layer)
def initialize_weights(module: nn.Module) -> None:
    """Initialize Module weights according to xavier.

    Walks ``module`` and all of its submodules (``module.modules()``):

    * ``nn.Linear``: the weight is drawn with Xavier (Glorot) normal
      initialization, and the bias is set to zero if the layer has one.
    * ``nn.BatchNorm1d``: the weight is set to 1 and the bias to 0, when
      these affine parameters exist.

    All other module types are left untouched.

    Args:
        module (nn.Module): Model
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # nn.Linear(..., bias=False) stores ``bias`` as None.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm1d):
            # nn.BatchNorm1d(..., affine=False) has no weight/bias tensors.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)