# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil

import torch
from torch.optim import Optimizer

import mmcv
from mmcv.runner import RUNNERS, EpochBasedRunner

from .checkpoint import save_checkpoint

try:
    import apex
except ImportError:
    print('apex is not installed')


@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
    """Epoch-based Runner with AMP support.

    This runner trains models epoch by epoch.
    """

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = dict(epoch=self.epoch + 1, iter=self.iter)
        elif isinstance(meta, dict):
            meta.update(epoch=self.epoch + 1, iter=self.iter)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        # the custom save_checkpoint (from .checkpoint) is expected to also
        # store the apex amp state under the 'amp' key, which resume() below
        # reads back
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model, optimizer and apex amp state from a checkpoint.

        Args:
            checkpoint (str): Path of the checkpoint to resume from.
            resume_optimizer (bool, optional): Whether to resume the optimizer
                state. Defaults to True.
            map_location (str, optional): Where to load the checkpoint; with
                'default', tensors are mapped to the current CUDA device if
                available. Defaults to 'default'.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        if 'amp' in checkpoint:
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
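

# Minimal usage sketch (not part of the original file; `model`, `optimizer`,
# `logger` and `data_loaders` are assumed to be set up elsewhere). Because
# the class is registered in the RUNNERS registry, it can be built from an
# mmcv-style config once the model/optimizer pair has been wrapped by apex:
#
#   import apex
#   from mmcv.runner import build_runner
#
#   model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O1')
#   runner = build_runner(
#       dict(type='EpochBasedRunnerAmp', max_epochs=12),
#       default_args=dict(
#           model=model,
#           optimizer=optimizer,
#           work_dir='./work_dir',
#           logger=logger))
#   runner.run(data_loaders, workflow=[('train', 1)])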