Spaces:
Build error
Build error
| # Copyright (c) Open-MMLab. All rights reserved. | |
| import os.path as osp | |
| import time | |
| from tempfile import TemporaryDirectory | |
| import torch | |
| from torch.optim import Optimizer | |
| import mmcv | |
| from mmcv.parallel import is_module_wrapper | |
| from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict | |
# apex is an optional dependency: it is only used to store the AMP
# (mixed-precision) state in checkpoints. Catch ImportError only, so that
# KeyboardInterrupt / SystemExit raised during import are not swallowed.
try:
    import apex
except ImportError:
    print('apex is not installed')
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint has up to 4 fields: ``meta`` and ``state_dict`` are
    always written; ``optimizer`` is written when an optimizer (or a dict
    of optimizers) is given; ``amp`` is written when ``apex`` was imported
    successfully by this module. ``meta`` always contains the mmcv version
    and the save time, plus ``CLASSES`` when the model defines it.

    Args:
        model (Module): Module whose params are to be saved. If it is a
            module wrapper, the wrapped ``model.module`` is saved instead.
        filename (str): Checkpoint filename. A name starting with
            ``pavi://`` is uploaded through pavi's modelcloud; anything
            else is written to the local filesystem.
        optimizer (:obj:`Optimizer` or dict, optional): Optimizer, or a
            dict mapping names to optimizers, whose state is to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint. The
            passed dict is not modified.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If ``filename`` starts with ``pavi://`` but pavi is
            not installed.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Work on a shallow copy so the caller's dict is not mutated by the
        # updates below.
        meta = dict(meta)
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # Unwrap module wrappers so the saved keys belong to the inner module.
    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model)),
    }

    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()
        }

    # save amp state dict in the checkpoint
    # The module-level ``import apex`` is allowed to fail, in which case the
    # name ``apex`` is never bound. Look it up defensively so that saving
    # still works without apex instead of raising NameError.
    apex_module = globals().get('apex')
    if apex_module is not None:
        # NOTE(review): presumably requires amp to have been initialized by
        # the training code before the first save - confirm against callers.
        checkpoint['amp'] = apex_module.amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError as err:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.'
            ) from err
        # Strip the 'pavi://' scheme prefix (7 characters).
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        # Reuse the remote model node if it exists, otherwise create it.
        try:
            cloud_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            cloud_model = root.create_training_model(model_dir)
        # Serialize to a temporary local file first, then upload that file.
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            cloud_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()