import os
import gc
import torch

import torch.nn.functional as F
import lightning as pl

from typing import Optional
from transformers import AutoModelForMaskedLM, AutoTokenizer

from src.utils.model_utils import _print
from src.utils.optimizer_utils import get_optimizer, get_scheduler


class MembraneDiffusion(pl.LightningModule):
    def __init__(self, config):
        """
        Args:
            config (OmegaConf): config.yaml file with all training parameters
        """
        super().__init__()
        self.config = config
        self.save_hyperparameters(logger=True)

        self.model = AutoModelForMaskedLM.from_pretrained(config.lm.pretrained_evoflow, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(config.lm.pretrained_evoflow)

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

    def forward(self, input_ids, attention_mask, guidance: Optional[bool] = False):
        """
        Forward pass through the language model.

        Args:
            - input_ids (torch.Tensor): [B, L], token ids
            - attention_mask (torch.Tensor): [B, L], pad/non-pad binary mask
            - guidance (Optional[bool]): currently unused
        Returns:
            - logits (torch.Tensor): [B, L, V], unnormalized model outputs
        """
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

    def step(self, batch):
        labels = batch['input_ids']

        t1 = self.sample_t(labels)
        xt, _ = self.noise_x0(labels, t1, maskable_mask=self.is_maskable(labels))
        logits = self.forward(input_ids=xt, attention_mask=batch['attention_mask'])

        weight = self.get_weight(t1, weight_type=self.config.lm.weight_type)
        loss_out = self.compute_loss(logits, labels, weight)

        self.cleanup()
        return loss_out['loss'], loss_out['ppl']

    def sample_t(self, labels, rdm_coupling=False):
        """
        Sample diffusion timesteps. Non-coupling RDM only uses one timestep (t1).
        """
        # Draw B timesteps (2B when coupling) uniformly from {1, ..., num_diffusion_timesteps}.
        timesteps = torch.randint(
            1,
            self.config.lm.num_diffusion_timesteps + 1,
            ((2 if rdm_coupling else 1) * labels.size(0),),
            device=labels.device
        )

        if rdm_coupling:
            return timesteps.chunk(2)
        return timesteps

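    # Example (illustrative, not from the original source): with
    # num_diffusion_timesteps = 500 and a batch of 4 sequences, sample_t returns
    # a [4] tensor of integers in {1, ..., 500}; with rdm_coupling=True it
    # returns two such [4] tensors (t1, t2).
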
    def noise_x0(self, x0, t1, maskable_mask):
        """
        Apply noise to the initial sequence x0.
        """
        u = torch.rand_like(x0, dtype=torch.float)
        t1_mask = (u < (t1 / self.config.lm.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x0.masked_fill(t1_mask, self.mask_id)
        return x_t1, t1_mask

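    # Note (illustrative, not from the original source): each maskable position is
    # replaced by mask_id independently with probability t1 / num_diffusion_timesteps,
    # so t1 = 1 masks roughly 1/T of the maskable tokens and t1 = T masks all of them
    # (u is drawn uniformly from [0, 1)).
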
    def get_weight(self, t, weight_type):
        """
        Compute the weighting factor for the RDM-derived loss (weighted cross-entropy).
        """
        num_timesteps = self.config.lm.num_diffusion_timesteps
        weight = {
            "linear": (num_timesteps - (t - 1)),
            "constant": num_timesteps * torch.ones_like(t),
        }[weight_type].float() / num_timesteps
        return weight

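    # Worked example (illustrative, assuming num_diffusion_timesteps = 500):
    #   linear:   t = 1   -> weight = 500/500 = 1.0;  t = 500 -> weight = 1/500 = 0.002
    #   constant: weight = 1.0 for every t
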
    def compute_loss(self, logits, labels, weight):
        """
        Compute the cross entropy loss per sample.
        First, compute the per-token loss (with no reduction), then reduce over the sequence length for each sample.
        Finally, average over the batch.

        Args:
            logits (torch.Tensor): [B, L, vocab_size], unnormalized model outputs
            labels (torch.Tensor): [B, L], target labels (padding positions hold pad_id and are ignored)
            weight (torch.Tensor): [B], per-sample weight for the loss calculation
        Returns:
            dict with:
                loss (torch.Tensor): weighted loss averaged over the batch
                ppl (torch.Tensor): per-sample perplexity averaged over the batch
        """

        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction='none',
            ignore_index=self.pad_id,
        )

        loss_token = loss_token.view(labels.size(0), labels.size(1))
        valid_mask = (labels != self.pad_id)

        sample_loss = (loss_token * valid_mask.float()).sum(dim=1) / valid_mask.float().sum(dim=1).clamp(min=1)
        # Perplexity is reported from the unweighted per-sample loss; the timestep
        # weight is applied to the training loss only.
        ppl = torch.exp(sample_loss)
        sample_loss = sample_loss * weight

        return {'ppl': ppl.mean(), 'loss': sample_loss.mean()}

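    # Shape walkthrough (illustrative): logits [B, L, V] and labels [B, L] are
    # flattened to [B*L, V] and [B*L] for F.cross_entropy, reshaped back to [B, L],
    # averaged over each sequence's non-pad positions to give a [B] per-sample loss,
    # multiplied by the [B] timestep weights, and finally averaged over the batch.
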
    def training_step(self, batch):
        loss, ppl = self.step(batch)
        self.log("train/loss", loss.item(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl.item(), on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        loss, ppl = self.step(batch)
        self.cleanup()
        self.log("val/loss", loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        loss, ppl = self.step(batch)
        self.cleanup()
        self.log("test/loss", loss.item(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl.item(), on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def is_maskable(self, input_ids: torch.Tensor):
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def configure_optimizers(self):
        """
        Choose which optimizer and lr scheduler to use.
        """
        optimizer = get_optimizer(self.config, self.model.parameters())
        lr_scheduler, extra_kwargs = get_scheduler(self.config, optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": lr_scheduler, **extra_kwargs},
        }

    def validate_config(self):
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.training.mode in ["train", "test", "resume_from_checkpoint"], "invalid mode"

    def get_state_dict(self, ckpt_path):
        def remove_model_prefix(state_dict):
            # str.replace returns a new string, so build a new dict with the
            # leading "model." prefix stripped from each key.
            return {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }

        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)

        return state_dict

    def cleanup(self):
        # Release unreferenced Python objects first, then free cached GPU memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
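

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes the
    # config exposes the fields read above (lm.pretrained_evoflow,
    # lm.num_diffusion_timesteps, lm.weight_type) and that the named checkpoint is
    # reachable; the checkpoint id, timestep count, and sequences below are
    # placeholders, not the project's actual settings.
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "lm": {
            "pretrained_evoflow": "path/or/hub-id/of/evoflow-checkpoint",
            "num_diffusion_timesteps": 500,
            "weight_type": "linear",
        }
    })
    module = MembraneDiffusion(config)

    # Tokenize two dummy sequences (adjust to the tokenizer's alphabet) and run one step.
    batch = module.tokenizer(
        ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKT"],
        return_tensors="pt",
        padding=True,
    )
    loss, ppl = module.step({"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]})
    print(f"loss={loss.item():.4f}  ppl={ppl.item():.4f}")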