| """Modified from https://github.com/mlfoundations/open_flamingo""" | |
| import time | |
| from contextlib import suppress | |
| import torch | |
| from tqdm import tqdm | |

def get_cast_dtype(precision: str):
    """Map a precision flag to the dtype inputs are cast to (None = keep original dtype)."""
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype

def get_autocast(precision):
    """Return the autocast context manager for the given precision flag (no-op otherwise)."""
    if precision == "amp":
        return torch.cuda.amp.autocast
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress

def train_one_epoch(
    args,
    model,
    epoch,
    laion_loader,
    mmc4_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches

    assert (
        num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
    ), "Number of batches in laion and mmc4 datasets must be the same"
    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[
        "input_ids"
    ][-1]

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = (
        AverageMeter()
    )  # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)
    end = time.time()

    # loop through dataloader
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### LAION FORWARD PASS ####
        images = (
            batch_laion[0]
            .to(device_id, dtype=cast_dtype, non_blocking=True)
            .unsqueeze(1)
            .unsqueeze(1)
        )

        input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch_laion[1][1].to(
            device_id, dtype=cast_dtype, non_blocking=True
        )

        # mask out padding, the BOS token, and <image> tokens so they do not
        # contribute to the language-modeling loss
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        divided_loss_laion = loss_laion / args.gradient_accumulation_steps

        #### C4 FORWARD PASS ####
        images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(2)
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)

        # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100

        for i in range(labels.shape[0]):
            # remove loss for any token before the first <image> token
            label_idx = 0
            while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
                labels[i][label_idx] = -100
                label_idx += 1

            # get index of all endofchunk tokens in the sequence, then mask any
            # tokens between an <|endofchunk|> and the next <image> token
            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
            for endofchunk_idx in endofchunk_idxs:
                token_idx = endofchunk_idx + 1
                while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
                    labels[i][token_idx] = -100
                    token_idx += 1

        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]

        # if loss is nan, skip this batch
        if torch.isnan(loss_mmc4):
            print("loss is nan, skipping this batch")
            print("input_ids: ", tokenizer.batch_decode(input_ids))
            print("labels: ", labels)
            print("images: ", images)
            optimizer.zero_grad()
            continue
        divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps

        #### BACKWARD PASS ####
        loss = (
            divided_loss_laion * args.loss_multiplier_laion
            + divided_loss_mmc4 * args.loss_multiplier_mmc4
        )
        loss.backward()

        #### MASK GRADIENTS FOR EMBEDDINGS ####
        # Note (anas): Do not apply weight decay to embeddings as it will break this function.
        # Zero the gradient of every embedding row except those of the newly added
        # <image> and <|endofchunk|> tokens, so only those rows get updated.
        def mask_embedding(m):
            if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
                zero_mask = torch.zeros_like(m.weight.grad)
                zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
                zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
                m.weight.grad = m.weight.grad * zero_mask

        model.apply(mask_embedding)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
            num_steps == num_batches_per_epoch - 1
        ):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                laion_samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size_laion
                    * args.world_size
                    / step_time_m.val
                )
                laion_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
                )
                c4_samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size_mmc4
                    * args.world_size
                    / step_time_m.val
                )
                c4_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
                )

                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "c4_samples_per_second": c4_samples_per_second,
                        "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "loss_laion": divided_loss_laion.item(),
                        "global_step": global_step,
                    },
                    commit=False,
                )
                wandb.log(
                    {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step},
                    commit=True,
                )

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. "
                f"Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
            )
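
# How this is typically driven (a sketch, not part of the original file): the
# OpenFlamingo training script builds the two loaders, wraps the model in DDP,
# and calls train_one_epoch once per epoch. Names such as `laion_dataset`,
# `mmc4_dataset`, `ddp_model`, and `args.run_name` are assumptions for illustration.
#
#   for epoch in range(resume_from_epoch, args.num_epochs):
#       train_one_epoch(
#           args=args,
#           model=ddp_model,
#           epoch=epoch,
#           laion_loader=laion_dataset.dataloader,
#           mmc4_loader=mmc4_dataset.dataloader,
#           tokenizer=tokenizer,
#           optimizer=optimizer,
#           lr_scheduler=lr_scheduler,
#           device_id=device_id,
#           wandb=wandb,
#       )
#       if args.rank == 0:
#           torch.save(get_checkpoint(ddp_model.module), f"{args.run_name}/checkpoint_{epoch}.pt")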

def get_checkpoint(model: torch.nn.Module):
    """Return a state dict containing only the model's trainable parameters."""
    state_dict = model.state_dict()
    parameters = {k: v for k, v in model.named_parameters()}

    # remove entries that are not named parameters (e.g. buffers, tied/duplicated weights)
    duplicate_keys = set(state_dict.keys()) - set(parameters.keys())
    for k in duplicate_keys:
        del state_dict[k]

    # remove frozen (non-grad) parameters
    for name, p in parameters.items():
        if not p.requires_grad:
            del state_dict[name]

    return state_dict

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
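
if __name__ == "__main__":
    # Minimal smoke test / usage sketch for the two standalone helpers above
    # (not part of the original OpenFlamingo script): exercises AverageMeter and
    # shows that get_checkpoint drops frozen parameters from a toy model.
    meter = AverageMeter()
    for batch_time in (0.5, 0.7, 0.6):
        meter.update(batch_time)
    print(f"last={meter.val:.2f} avg={meter.avg:.2f}")  # -> last=0.60 avg=0.60

    toy = torch.nn.Linear(4, 2)
    toy.bias.requires_grad = False  # frozen parameters are excluded from the checkpoint
    print(sorted(get_checkpoint(toy).keys()))  # -> ['weight']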