Spaces: Running on Zero
| import os, sys | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) | |
| import torch | |
| from my_utils.misc import dump_config | |
| from lightning.pytorch.callbacks.callback import Callback | |
| from lightning.pytorch.utilities.rank_zero import rank_zero_only | |
class ConfigSnapshotCallback(Callback):
    """Save a snapshot of the run configuration when fitting starts.

    The config is written to ``<exp_dir>/config/parsed.yaml`` via
    ``dump_config``. ``exp_dir`` is read from the LightningModule's
    ``hparams`` during ``setup``.
    """

    def __init__(self, config):
        """
        Args:
            config: Configuration object to persist. It is passed unchanged to
                ``dump_config`` (assumed to be whatever config type that helper
                serializes -- confirm against ``my_utils.misc``).
        """
        super().__init__()
        self.config = config

    def setup(self, trainer, pl_module, stage) -> None:
        """Resolve the snapshot directory from ``pl_module.hparams.exp_dir``."""
        self.savedir = os.path.join(pl_module.hparams.exp_dir, 'config')

    @rank_zero_only
    def save_config_snapshot(self):
        """Write the config to ``<savedir>/parsed.yaml``, creating the directory if needed.

        Guarded by ``rank_zero_only`` so that, in multi-process (e.g. DDP) runs,
        only the global-zero process touches the filesystem. Without the guard,
        every rank would race to create the directory and overwrite the same file.
        """
        os.makedirs(self.savedir, exist_ok=True)
        dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config)

    def on_fit_start(self, trainer, pl_module):
        """Persist the config snapshot as training begins."""
        self.save_config_snapshot()
class GPUCacheCleanCallback(Callback):
    """Empty PyTorch's CUDA memory cache at the start of every batch.

    Applies to train, validation, test and predict batches alike. The hook
    arguments supplied by Lightning are accepted but not used.
    """

    @staticmethod
    def _release_cuda_cache():
        # Ask PyTorch's CUDA caching allocator to release unoccupied cached memory.
        torch.cuda.empty_cache()

    def on_train_batch_start(self, *args, **kwargs):
        self._release_cuda_cache()

    def on_validation_batch_start(self, *args, **kwargs):
        self._release_cuda_cache()

    def on_test_batch_start(self, *args, **kwargs):
        self._release_cuda_cache()

    def on_predict_batch_start(self, *args, **kwargs):
        self._release_cuda_cache()