from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from models.base import CaptionMetaMixin
from utils.model_util import init


class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_layer_to_dims: Dict[str, int],
                 loss_type: str = "mse"):
        super().__init__()
        self.model = model
        self.tchr_layers = list(tchr_layer_to_dims.keys())
        # Project the student embedding into a query and a value in the shared space.
        self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
                                       2 * shared_dim)
        self.stdnt_qv_proj.apply(init)
        # One key/value projection per teacher layer, registered as a submodule.
        for layer, dim in tchr_layer_to_dims.items():
            self.add_module(f'tchr_kv_proj_{layer}', nn.Linear(dim, 2 * shared_dim))
            getattr(self, f'tchr_kv_proj_{layer}').apply(init)
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")

    def forward(self, input_dict: Dict):
        output_dict = self.model(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
            stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)
            tchr_output = input_dict["tchr_output"]
            layer_ks, layer_vs = [], []
            for layer in self.tchr_layers:
                layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
                layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
                layer_ks.append(layer_k)
                layer_vs.append(layer_v)
            layer_ks = torch.stack(layer_ks, dim=1)  # (batch, n_layers, shared_dim)
            layer_vs = torch.stack(layer_vs, dim=1)  # (batch, n_layers, shared_dim)
            # Attention of the student query over the teacher-layer keys gives the
            # weights used to aggregate the per-layer distillation losses.
            weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2),
                                    dim=-1)  # (batch, 1, n_layers)
            stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
            loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
            loss = (weights @ loss).mean()
            output_dict["enc_kd_loss"] = loss
        return output_dict
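

# A minimal usage sketch for WmlEncoderKdWrapper, not part of the original module.
# It assumes a student captioning model whose forward returns a dict containing
# "fc_emb" and whose encoder exposes `fc_emb_size`; the input key "waveform", the
# teacher layer names, and all dimensions below are illustrative assumptions.
def _example_wml_kd_step(student_model: nn.Module) -> torch.Tensor:
    wrapper = WmlEncoderKdWrapper(student_model,
                                  shared_dim=1024,
                                  tchr_layer_to_dims={"layer6": 768, "layer12": 768})
    input_dict = {
        "waveform": torch.randn(4, 16000),  # hypothetical raw audio batch
        "tchr_output": {"layer6": torch.randn(4, 768),
                        "layer12": torch.randn(4, 768)},
    }
    output_dict = wrapper(input_dict)
    # The layer-weighted distillation loss is returned under "enc_kd_loss".
    return output_dict["enc_kd_loss"]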


class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            # Without a teacher projection the student is aligned directly to the
            # teacher space, so the two dimensions must match.
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                # Without a teacher projection, project the encoder outputs into the
                # shared space before running the decoder.
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)
            loss = F.mse_loss(stdnt_emb, tchr_emb)
            output_dict["enc_kd_loss"] = loss
        return output_dict
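

# A minimal usage sketch for MseEncoderKdWrapper, not part of the original module.
# It assumes a teacher providing a single pooled embedding of dimension 1024 under
# input_dict["tchr_output"]["embedding"]; the key "waveform" and all sizes are
# illustrative assumptions.
def _example_mse_kd_step(student_model: nn.Module) -> torch.Tensor:
    wrapper = MseEncoderKdWrapper(student_model,
                                  shared_dim=512,
                                  tchr_dim=1024,
                                  l2_norm=True)
    input_dict = {
        "waveform": torch.randn(4, 16000),  # hypothetical raw audio batch
        "tchr_output": {"embedding": torch.randn(4, 1024)},  # teacher embedding
    }
    output_dict = wrapper(input_dict)
    return output_dict["enc_kd_loss"]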


class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.tchr_proj.apply(init)
        # Learnable scale for the cosine-similarity logits, initialized to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            # Symmetric contrastive loss between student and teacher embeddings:
            # matched pairs within the batch are the positives.
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = loss
        return output_dict
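

# A minimal usage sketch for ContraEncoderKdWrapper, not part of the original module.
# Setting "unsup" to True runs only the encoder of the wrapped model, which assumes
# the model exposes an `encoder` attribute; batch size, the key "waveform", and all
# dimensions are illustrative assumptions.
def _example_contra_kd_step(student_model: nn.Module) -> torch.Tensor:
    wrapper = ContraEncoderKdWrapper(student_model, shared_dim=512, tchr_dim=1024)
    input_dict = {
        "waveform": torch.randn(8, 16000),  # hypothetical unlabeled audio batch
        "tchr_output": {"embedding": torch.randn(8, 1024)},
        "unsup": True,  # encoder-only forward pass, no caption decoding
    }
    output_dict = wrapper(input_dict)
    return output_dict["enc_kd_loss"]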


class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            # MSE term, optionally on L2-normalized embeddings.
            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)
            mse_loss = F.mse_loss(stdnt_emb, tchr_emb)
            # Contrastive term always uses L2-normalized embeddings (re-normalizing
            # already-normalized vectors is a no-op).
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            cntr_loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = mse_loss + cntr_loss
        return output_dict
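

# A minimal training-step sketch for ContraMseEncoderKdWrapper, not part of the
# original module. It only illustrates how the combined MSE + contrastive KD term
# could be mixed with a captioning loss computed elsewhere; `caption_loss`,
# `kd_weight`, and the dimensions are assumptions made for illustration.
def _example_contra_mse_kd_step(student_model: nn.Module,
                                input_dict: Dict,
                                caption_loss: torch.Tensor,
                                kd_weight: float = 0.5) -> torch.Tensor:
    wrapper = ContraMseEncoderKdWrapper(student_model,
                                        shared_dim=1024,
                                        tchr_dim=1024,
                                        use_tchr_proj=False,
                                        l2_norm=True)
    output_dict = wrapper(input_dict)
    # Weighted sum of the captioning objective and the encoder KD loss.
    return caption_loss + kd_weight * output_dict["enc_kd_loss"]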