# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
from typing import Callable
import argparse
import os
from functools import partial

from omegaconf import OmegaConf
from torchinfo import summary
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader

from revq.utils.init import initiate_from_config_recursively
from revq.data.dataloader import maybe_get_subset
import revq.utils.logger as L


def setup_config(opt: argparse.Namespace):
    L.log.info("\n\n### Setting up the configurations. ###")
    # load the config file
    config = OmegaConf.load(opt.config)
    # overwrite certain config entries according to the config.args_map mapping
    for key, value in config.args_map.items():
        if hasattr(opt, key) and getattr(opt, key) is not None:
            msg = f"config.{value} = opt.{key}"
            L.log.info(f"Overwrite the config: {msg}")
            exec(msg)
    return config
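

# Hypothetical sketch of the config fragment that setup_config expects:
# `args_map` maps command-line argument names to dotted config paths, and any
# CLI value that is not None overwrites the corresponding config entry, e.g.
#
#   args_map:
#     batch_size: data.batch_size          # --batch_size 32 -> config.data.batch_size = 32
#     learning_rate: train.learning_rate   # --learning_rate 1e-4 -> config.train.learning_rate = 1e-4
#
# The keys shown are illustrative; only the args_map mechanism itself comes
# from the code above.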


def setup_dataloader(data, batch_size, is_distributed: bool = True,
                     is_train: bool = True, num_workers: int = 8):
    # training loaders shuffle and drop the last incomplete batch;
    # validation loaders keep every sample in a fixed order
    sampler = None
    if is_distributed:
        # setup the distributed sampler
        sampler = torch.utils.data.distributed.DistributedSampler(
            data, shuffle=is_train, drop_last=is_train)
    # setup the dataloader (shuffling is delegated to the sampler when present)
    loader = DataLoader(
        dataset=data, batch_size=batch_size, num_workers=num_workers,
        drop_last=is_train, sampler=sampler,
        shuffle=(is_train and sampler is None),
        persistent_workers=True, pin_memory=True,
    )
    return loader
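

# Usage sketch (an assumed calling pattern, not taken verbatim from the repo):
# in the distributed case each rank builds its own loader and should call
# sampler.set_epoch(epoch) so that shuffling differs across epochs, e.g.
#
#   train_loader = setup_dataloader(train_data, batch_size=16,
#                                   is_distributed=True, is_train=True)
#   for epoch in range(num_epochs):
#       if isinstance(train_loader.sampler,
#                     torch.utils.data.distributed.DistributedSampler):
#           train_loader.sampler.set_epoch(epoch)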


def setup_dataset(config: OmegaConf):
    L.log.info("\n\n### Setting up the datasets. ###")
    # setup the training dataset
    train_data = initiate_from_config_recursively(config.data.train)
    if config.data.use_train_subset is not None:
        train_data = maybe_get_subset(
            train_data, subset_size=config.data.use_train_subset,
            num_data_repeat=config.data.use_train_repeat)
    L.log.info(f"Training dataset size: {len(train_data)}")
    # setup the validation dataset
    val_data = initiate_from_config_recursively(config.data.val)
    if config.data.use_val_subset is not None:
        val_data = maybe_get_subset(val_data, subset_size=config.data.use_val_subset)
    L.log.info(f"Validation dataset size: {len(val_data)}")
    return train_data, val_data


def setup_model(config: OmegaConf, device):
    L.log.info("\n\n### Setting up the models. ###")
    # setup the model
    model = initiate_from_config_recursively(config.model.autoencoder)
    if config.is_distributed:
        # apply SyncBN
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # move the model to the local device and wrap it with DDP
        model = model.to(device)
        find_unused_parameters = True
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[config.gpu],
            find_unused_parameters=find_unused_parameters
        )
        model_ori = model.module
    else:
        model = model.to(device)
        model_ori = model
    # summarize the model on the device it actually lives on
    input_size = config.data.train.params.transform.params.resize
    in_channels = getattr(model_ori.encoder, "in_dim", 3)
    sout = summary(model_ori, (1, in_channels, input_size, input_size),
                   device=device, verbose=0)
    L.log.info(sout)
    # count the number of parameters per top-level module
    for name, module in model_ori.named_children():
        num_params = sum(p.numel() for p in module.parameters())
        num_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
        L.log.info(f"Module: {name}, Total params: {num_params}, Trainable params: {num_trainable}")
    return model


### factory functions
def get_setup_optimizers(config):
    name = config.train.pipeline
    func_name = "setup_optimizers_" + name
    return globals()[func_name]


def get_pipeline(config):
    name = config.train.pipeline
    func_name = "pipeline_" + name
    return globals()[func_name]
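

# Both factories resolve functions by name from this module's globals(), so a
# config with train.pipeline == "ae" selects setup_optimizers_ae / pipeline_ae,
# and "ae_disc" selects setup_optimizers_ae_disc / pipeline_ae_disc (the only
# two pipelines defined below).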


def _forward_backward(
    config,
    x: torch.Tensor,
    forward: Callable,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    scaler: torch.cuda.amp.GradScaler,
):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
                        enabled=config.use_amp):
        # forward pass
        loss, *output = forward(x)
    # scale the loss for gradient accumulation and backpropagate
    loss_acc = loss / config.data.gradient_accumulate
    scaler.scale(loss_acc).backward()
    # only step the optimizer/scheduler once every gradient_accumulate calls
    if L.log.total_steps % config.data.gradient_accumulate == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        optimizer.zero_grad()
        scaler.update()
        if scheduler is not None:
            scheduler.step()
    return loss, output
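

# Worked example of the accumulation logic above: with
# config.data.gradient_accumulate = 4, each call backpropagates loss / 4, and
# the unscale / clip / step / update sequence only runs on calls where
# L.log.total_steps % 4 == 0, i.e. roughly once every four micro-batches
# (assuming L.log.total_steps is incremented once per call by the training
# loop, which is not shown in this file).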


### autoencoder version
def _find_weight_decay_id(modules: list, params_ids: list,
                          include_class: tuple = (nn.Linear, nn.Conv2d,
                                                  nn.ConvTranspose2d,
                                                  nn.MultiheadAttention),
                          include_name: list = ["weight"]):
    # collect the ids of "weight" parameters owned by the listed layer types
    for mod in modules:
        for sub_mod in mod.modules():
            if isinstance(sub_mod, include_class):
                for name, param in sub_mod.named_parameters():
                    if any([k in name for k in include_name]):
                        params_ids.append(id(param))
    # deduplicate (shared parameters may be visited more than once)
    params_ids = list(set(params_ids))
    return params_ids


def set_weight_decay(modules: list):
    weight_decay_ids = _find_weight_decay_id(modules, [])
    wd_params, wd_names, no_wd_params, no_wd_names = [], [], [], []
    for mod in modules:
        for name, param in mod.named_parameters():
            if id(param) in weight_decay_ids:
                wd_params.append(param)
                wd_names.append(name)
            else:
                no_wd_params.append(param)
                no_wd_names.append(name)
    return wd_params, wd_names, no_wd_params, no_wd_names
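

# In effect, only the weight matrices/kernels of Linear, Conv2d,
# ConvTranspose2d, and MultiheadAttention layers receive weight decay; biases
# and the parameters of normalization/embedding layers fall into the
# "no weight decay" group, a common convention for AdamW-style training.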


def setup_optimizers_ae(config: OmegaConf, model: nn.Module, total_steps: int):
    L.log.info("\n\n### Setting up the optimizers and schedulers. ###")
    # compute the total batch size and the learning rates
    total_batch_size = config.data.batch_size * config.world_size * config.data.gradient_accumulate
    total_learning_rate = config.train.learning_rate * total_batch_size
    multiplied_learning_rate = total_learning_rate * config.train.mul_learning_rate
    L.log.info(f"Total batch size: {total_batch_size} = {config.data.batch_size} * {config.world_size} * {config.data.gradient_accumulate}")
    L.log.info(f"Total learning rate: {total_learning_rate} = {config.train.learning_rate} * {total_batch_size}")
    L.log.info(f"Multiplied learning rate: {multiplied_learning_rate} = {total_learning_rate} * {config.train.mul_learning_rate}")
    # setup the optimizer: four parameter groups
    param_group = []
    ## base learning rate (encoder, decoder, and the pre/post quant convolutions)
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay(
        [model.encoder, model.decoder, model.quant_conv, model.post_quant_conv])
    param_group.append({
        "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    ## multiplied learning rate (the quantize module)
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.quantize])
    param_group.append({
        "params": wd_params, "lr": multiplied_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": multiplied_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    optimizer_ae = torch.optim.AdamW(param_group)
    optimizer_dict = {"optimizer_ae": optimizer_ae}
    # setup the scheduler: one max_lr entry per parameter group, in order
    scheduler_ae = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer_ae,
        max_lr=[total_learning_rate, total_learning_rate,
                multiplied_learning_rate, multiplied_learning_rate],
        total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
    )
    scheduler_dict = {"scheduler_ae": scheduler_ae}
    # setup the gradient scaler
    scaler_dict = {"scaler_ae": torch.GradScaler(enabled=config.use_amp)}
    L.log.info(f"Enable AMP: {config.use_amp}")
    return optimizer_dict, scheduler_dict, scaler_dict
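

# Note on the group ordering above: OneCycleLR's max_lr list is matched
# positionally to optimizer_ae.param_groups, so entries 0-1 (base learning
# rate, with and without weight decay) anneal around total_learning_rate and
# entries 2-3 around multiplied_learning_rate.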


def pipeline_ae(
    config,
    x: torch.Tensor,
    model: nn.Module,
    optimizers: dict,
    schedulers: dict,
    scalers: dict,
):
    assert "optimizer_ae" in optimizers
    assert "scheduler_ae" in schedulers
    assert "scaler_ae" in scalers
    optimizer = optimizers["optimizer_ae"]
    scheduler = schedulers["scheduler_ae"]
    scaler = scalers["scaler_ae"]
    forward = partial(model, mode=0)
    _, (loss_ae_dict, indices) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
    log_per_step = loss_ae_dict
    log_per_epoch = {"indices": indices}
    return log_per_step, log_per_epoch


### autoencoder + discriminator version
def setup_optimizers_ae_disc(config: OmegaConf, model: nn.Module, total_steps: int):
    L.log.info("\n\n### Setting up the optimizers and schedulers. ###")
    # compute the total batch size and the learning rates
    total_batch_size = config.data.batch_size * config.world_size * config.data.gradient_accumulate
    total_learning_rate = config.train.learning_rate * total_batch_size
    multiplied_learning_rate = total_learning_rate * config.train.mul_learning_rate
    L.log.info(f"Total batch size: {total_batch_size} = {config.data.batch_size} * {config.world_size} * {config.data.gradient_accumulate}")
    L.log.info(f"Total learning rate: {total_learning_rate} = {config.train.learning_rate} * {total_batch_size}")
    L.log.info(f"Multiplied learning rate: {multiplied_learning_rate} = {total_learning_rate} * {config.train.mul_learning_rate}")
    # setup the autoencoder optimizer: four parameter groups
    param_group = []
    ## base learning rate
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay(
        [model.encoder, model.decoder, model.quant_conv, model.post_quant_conv])
    param_group.append({
        "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    ## multiplied learning rate
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.quantize])
    param_group.append({
        "params": wd_params, "lr": multiplied_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": multiplied_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    optimizer_ae = torch.optim.AdamW(param_group)
    # setup the discriminator optimizer: two parameter groups at the base learning rate
    param_group = []
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.loss.discriminator])
    param_group.append({
        "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    optimizer_disc = torch.optim.AdamW(param_group)
    optimizer_dict = {"optimizer_ae": optimizer_ae, "optimizer_disc": optimizer_disc}
    # setup the schedulers (one max_lr entry per parameter group, in order)
    scheduler_ae = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer_ae,
        max_lr=[total_learning_rate, total_learning_rate,
                multiplied_learning_rate, multiplied_learning_rate],
        total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
    )
    scheduler_disc = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer_disc,
        max_lr=[total_learning_rate, total_learning_rate],
        total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
    )
    scheduler_dict = {"scheduler_ae": scheduler_ae, "scheduler_disc": scheduler_disc}
    # setup the gradient scalers
    scaler_dict = {"scaler_ae": torch.GradScaler(enabled=config.use_amp),
                   "scaler_disc": torch.GradScaler(enabled=config.use_amp)}
    L.log.info(f"Enable AMP: {config.use_amp}")
    return optimizer_dict, scheduler_dict, scaler_dict


def pipeline_ae_disc(
    config,
    x: torch.Tensor,
    model: nn.Module,
    optimizers: dict,
    schedulers: dict,
    scalers: dict,
):
    # autoencoder step
    assert "optimizer_ae" in optimizers
    assert "scheduler_ae" in schedulers
    assert "scaler_ae" in scalers
    optimizer = optimizers["optimizer_ae"]
    scheduler = schedulers["scheduler_ae"]
    scaler = scalers["scaler_ae"]
    forward = partial(model, mode=0)
    _, (loss_ae_dict, indices) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
    log_per_step = loss_ae_dict
    log_per_epoch = {"indices": indices}
    # discriminator step
    assert "optimizer_disc" in optimizers
    assert "scheduler_disc" in schedulers
    assert "scaler_disc" in scalers
    optimizer = optimizers["optimizer_disc"]
    scheduler = schedulers["scheduler_disc"]
    scaler = scalers["scaler_disc"]
    forward = partial(model, mode=1)
    _, (loss_disc_dict, _) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
    log_per_step.update(loss_disc_dict)
    return log_per_step, log_per_epoch
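

# As the partial calls above suggest, the autoencoder's forward is expected to
# take a `mode` argument: mode=0 computes the reconstruction/quantization
# losses (returned as loss_ae_dict plus codebook indices), and mode=1 computes
# the discriminator losses (loss_disc_dict). The model definition itself lives
# elsewhere in the repository; this description is inferred from how the model
# is called in this file.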