Spaces:
Runtime error
Runtime error
| import torchvision.transforms as T | |
# Registry mapping a transform factory's ``__name__`` to the factory itself.
# Populated by ``register_transform``; looked up by name elsewhere.
TRANSFORMS = {}


def register_transform(transform):
    """Register a transform factory in ``TRANSFORMS`` under its ``__name__``.

    Returns the factory unchanged, so this works as a decorator
    (``@register_transform``) without rebinding the decorated name to None.

    Args:
        transform: A callable with a ``__name__`` attribute (typically a
            function).

    Returns:
        The same ``transform`` object.

    Raises:
        RuntimeError: If a transform with the same ``__name__`` has already
            been registered.
    """
    name = transform.__name__
    if name in TRANSFORMS:
        raise RuntimeError(f'Transform {name} has already registered.')
    TRANSFORMS[name] = transform
    # Returning the argument is what makes decorator usage safe.
    return transform
def get_transform(type, resolution):
    """Build a composed transform pipeline from a registered factory.

    Args:
        type: Name the factory was registered under (its ``__name__``).
            Note: this parameter shadows the builtin ``type`` inside the
            function; the name is kept for backward compatibility.
        resolution: Passed unchanged to the factory; also stored on the
            result as ``image_size``.

    Returns:
        A ``T.Compose`` built from the sequence the factory returns, with an
        extra ``image_size`` attribute set to ``resolution``.

    Raises:
        KeyError: If no transform is registered under ``type``. The message
            lists the names that are registered.
    """
    try:
        factory = TRANSFORMS[type]
    except KeyError:
        available = ', '.join(sorted(TRANSFORMS)) or '<none>'
        raise KeyError(
            f'Unknown transform {type!r}; registered transforms: {available}'
        ) from None
    # The factory is called outside the try so that errors raised inside
    # the factory itself are not mistaken for an unknown name.
    transform = T.Compose(factory(resolution))
    transform.image_size = resolution
    return transform
def default_train(n_px):
    """Return the list of transforms for the ``default_train`` pipeline.

    Steps, in order:
      1. convert the input image to 3-channel RGB;
      2. ``T.Resize(n_px)`` (with an int, resizes the shorter side to
         ``n_px``; default interpolation, not bicubic);
      3. ``T.CenterCrop(n_px)``;
      4. ``T.ToTensor()`` (float tensor scaled to [0, 1]);
      5. ``T.Normalize([.5], [.5])``: the single mean/std value is applied
         to every channel, mapping [0, 1] to [-1, 1].

    No random augmentation is applied, so the output is deterministic
    despite the ``train`` in the name.

    Args:
        n_px: Target resolution passed to ``Resize`` and ``CenterCrop``.

    Returns:
        A list of transforms, suitable for ``T.Compose``.

    NOTE(review): the input is assumed to be a PIL image (anything with a
    ``convert`` method) — confirm against the dataset code.
    NOTE(review): no ``@register_transform`` is visible on this function
    here; ``get_transform('default_train', ...)`` only works if it is
    registered somewhere.
    NOTE(review): an earlier comment hinted at ``Image.BICUBIC``; pass
    ``interpolation=T.InterpolationMode.BICUBIC`` to ``Resize`` if bicubic
    was intended (that would change outputs).
    """
    return [
        # A named module-level function instead of a lambda keeps the
        # composed pipeline picklable (e.g. for spawn-based DataLoader workers).
        T.Lambda(_convert_image_to_rgb),
        T.Resize(n_px),
        T.CenterCrop(n_px),
        T.ToTensor(),
        T.Normalize([.5], [.5]),
    ]


def _convert_image_to_rgb(image):
    """Return ``image.convert('RGB')``; picklable replacement for a lambda."""
    return image.convert('RGB')