Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import json | |
| import torch | |
| import imageio | |
| import numpy as np | |
| class SaverMixin(): | |
| def save_dir(self): | |
| return self.hparams.save_dir | |
| def convert_format(self, data): | |
| if isinstance(data, np.ndarray): | |
| return data | |
| elif isinstance(data, torch.Tensor): | |
| return data.cpu().numpy() | |
| elif isinstance(data, list): | |
| return [self.convert_format(d) for d in data] | |
| elif isinstance(data, dict): | |
| return {k: self.convert_format(v) for k, v in data.items()} | |
| else: | |
| raise TypeError('Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting', type(data)) | |
| def get_save_path(self, filename): | |
| save_path = os.path.join(self.save_dir, filename) | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| return save_path | |
| def save_rgb_image(self, filename, img): | |
| imageio.imwrite(self.get_save_path(filename), img) | |
| def save_rgb_video(self, filename, stage='fit', filter=None): | |
| img_dir = os.path.join(self.logger.log_dir, 'images', stage) | |
| writer_graph = imageio.get_writer(os.path.join(img_dir, filename), fps=1) | |
| for file in sorted(os.listdir(img_dir)): | |
| if file.endswith('.png') and 'gt' not in file: | |
| if filter is not None: | |
| if filter in file: | |
| writer_graph.append_data(imageio.imread(os.path.join(img_dir, file))) | |
| else: | |
| writer_graph.append_data(imageio.imread(os.path.join(img_dir, file))) | |
| writer_graph.close() | |
| def save_json(self, filename, data): | |
| save_path = self.get_save_path(filename) | |
| with open(save_path, 'w') as f: | |
| json.dump(data, f) | |