# DIPO / my_utils/callbacks.py
# xinjie.wang
# init commit
# c28dddb
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))
import torch
from my_utils.misc import dump_config
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.rank_zero import rank_zero_only
class ConfigSnapshotCallback(Callback):
    """Persist the run configuration as ``<exp_dir>/config/parsed.yaml``.

    The destination directory is resolved in ``setup`` from
    ``pl_module.hparams.exp_dir``. The snapshot is written from the
    ``on_fit_start`` hook. ``save_config_snapshot`` is wrapped with
    ``rank_zero_only``, so only the rank-zero process performs the write.
    """

    def __init__(self, config):
        """Keep ``config``; it is handed unchanged to ``dump_config`` later.

        Args:
            config: The configuration object to snapshot. Its expected type is
                whatever ``my_utils.misc.dump_config`` accepts.
        """
        super().__init__()
        self.config = config

    def setup(self, trainer, pl_module, stage) -> None:
        """Compute the snapshot directory from the module's ``exp_dir`` hparam.

        ``save_config_snapshot`` reads ``self.savedir``, which is only
        assigned here, so this hook must have run before the snapshot is saved.
        """
        experiment_root = pl_module.hparams.exp_dir
        self.savedir = os.path.join(experiment_root, 'config')

    @rank_zero_only
    def save_config_snapshot(self):
        """Create the snapshot directory if missing and dump the config there."""
        snapshot_path = os.path.join(self.savedir, 'parsed.yaml')
        # exist_ok=True: re-running into an existing experiment dir is fine.
        os.makedirs(self.savedir, exist_ok=True)
        dump_config(snapshot_path, self.config)

    def on_fit_start(self, trainer, pl_module):
        """Lightning hook: write the config snapshot when ``fit`` begins."""
        self.save_config_snapshot()
class GPUCacheCleanCallback(Callback):
    """Call ``torch.cuda.empty_cache()`` at the start of every batch.

    Applies to the train, validation, test and predict loops. All hook
    arguments are accepted and ignored; each hook only triggers the flush.
    """

    @staticmethod
    def _flush_cuda_cache():
        # Single shared implementation for every batch-start hook below.
        torch.cuda.empty_cache()

    def on_train_batch_start(self, *args, **kwargs):
        """Flush the CUDA cache before a training batch."""
        self._flush_cuda_cache()

    def on_validation_batch_start(self, *args, **kwargs):
        """Flush the CUDA cache before a validation batch."""
        self._flush_cuda_cache()

    def on_test_batch_start(self, *args, **kwargs):
        """Flush the CUDA cache before a test batch."""
        self._flush_cuda_cache()

    def on_predict_batch_start(self, *args, **kwargs):
        """Flush the CUDA cache before a prediction batch."""
        self._flush_cuda_cache()