import torch
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
from typing import List, Optional, Dict, Tuple
from copy import deepcopy

from .vit import vit_names_and_weights, _vit
from .convnext import convnext_names_and_weights, _convnext
from .resnet import resnet_names_and_weights, _resnet
from .mobileclip import mobileclip_names_and_weights, _mobileclip
from .utils import encode_text, optimize_text_prompts
from ..utils import conv1x1
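
# Merge the (model name -> available pretrained weights) registries of all
# supported backbone families into a single lookup table.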
supported_models_and_weights = deepcopy(vit_names_and_weights)
supported_models_and_weights.update(convnext_names_and_weights)
supported_models_and_weights.update(resnet_names_and_weights)
supported_models_and_weights.update(mobileclip_names_and_weights)


class CLIP_EBC(nn.Module):
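    """CLIP-based Enhanced Blockwise Classification (CLIP-EBC) crowd-counting model.

    The CLIP image encoder produces a per-block feature map; each block is
    classified into a count bin by cosine similarity against the text
    embeddings of the bins' prompts, and the density map is decoded as the
    expectation over the bin centers. With ``zero_inflated=True``, a separate
    "pi" head first predicts whether a block is empty, while the "lambda" head
    models only the non-zero bins.
    """
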
    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: Optional[int] = None,
        bins: Optional[List[Tuple[float, float]]] = None,
        bin_centers: Optional[List[float]] = None,
        zero_inflated: Optional[bool] = True,
        num_vpt: Optional[int] = None,
        vpt_drop: Optional[float] = None,
        input_size: Optional[int] = None,
        text_prompts: Optional[Dict[str, List[str]]] = None,
        norm: Optional[str] = "none",
        act: Optional[str] = "none",
    ) -> None:
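        """
        Args:
            model_name: backbone name; must be a key of supported_models_and_weights.
            weight_name: pretrained weight tag for the chosen backbone.
            block_size: size of the blocks the density map is predicted over.
            bins: count intervals, e.g. [(0, 0), (1, 3), (4, inf)].
            bin_centers: representative count per bin, used to decode the
                classification into a density value.
            zero_inflated: if True, model P(zero) separately from the non-zero counts.
            num_vpt, vpt_drop, input_size: visual-prompt-tuning options (ViT backbones only).
            text_prompts: optional pre-computed prompts; generated via
                optimize_text_prompts when omitted.
            norm, act: normalization/activation options passed to the backbone.
        """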
        super().__init__()
        if "mobileclip" in model_name.lower() or "vit" in model_name.lower():
            model_name = model_name.replace("_", "-")
        assert model_name in supported_models_and_weights, f"Model name should be one of {list(supported_models_and_weights.keys())}, but got {model_name}."
        assert weight_name in supported_models_and_weights[model_name], f"Pretrained weights should be one of {supported_models_and_weights[model_name]}, but got {weight_name}."
        assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}."
        assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}."
        assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}."
        bins = [(float(b[0]), float(b[1])) for b in bins]
        assert all(b[0] <= c <= b[1] for b, c in zip(bins, bin_centers)), f"Expected each bin center to lie within its bin, got {bins} and {bin_centers}."
        self.model_name = model_name
        self.weight_name = weight_name
        self.block_size = block_size
        self.bins = bins
        self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32).view(1, -1, 1, 1))
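        # Stored as shape (1, N, 1, 1) so it broadcasts against (B, N, H, W) logit maps.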
        self.zero_inflated = zero_inflated
        self.text_prompts = text_prompts
        # Image encoder
        if model_name in vit_names_and_weights:
            assert num_vpt is not None and num_vpt >= 0, f"Number of VPT tokens should be a non-negative integer, but got {num_vpt}."
            vpt_drop = 0. if vpt_drop is None else vpt_drop
            self.backbone = _vit(
                model_name=model_name,
                weight_name=weight_name,
                num_vpt=num_vpt,
                vpt_drop=vpt_drop,
                block_size=block_size,
                input_size=(input_size, input_size),
                norm=norm,
                act=act,
            )
        elif model_name in convnext_names_and_weights:
            self.backbone = _convnext(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                norm=norm,
                act=act,
            )
        elif model_name in resnet_names_and_weights:
            self.backbone = _resnet(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                norm=norm,
                act=act,
            )
        elif model_name in mobileclip_names_and_weights:
            self.backbone = _mobileclip(
                model_name=model_name,
                weight_name=weight_name,
                block_size=block_size,
                norm=norm,
                act=act,
            )
        self._build_text_feats()
        self._build_head()

    def _build_text_feats(self) -> None:
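        """Encode the text prompts for each count bin into frozen CLIP text features.

        If no prompts are given, they are generated (and optimized) from the bin
        edges. In the zero-inflated case two prompt sets are built: "pi" prompts
        that distinguish zero from non-zero blocks, and "lambda" prompts for the
        non-zero count bins. The resulting features are registered as buffers.
        """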
        model_name, weight_name = self.model_name, self.weight_name
        text_prompts = self.text_prompts
        if text_prompts is None:
            bins = [b[0] if b[0] == b[1] else b for b in self.bins]  # if the bin is a single value (e.g., (0, 0)), use that value
            if self.zero_inflated:  # separate 0 from the rest
                assert bins[0] == 0, f"Expected the first bin to be 0, got {bins[0]}."
                bins_pi = [0, (1, float("inf"))]
                bins_lambda = bins[1:]
                pi_text_prompts = optimize_text_prompts(model_name, weight_name, bins_pi)
                lambda_text_prompts = optimize_text_prompts(model_name, weight_name, bins_lambda)
                self.text_prompts = {"pi": pi_text_prompts, "lambda": lambda_text_prompts}
                pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
                lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
                pi_text_feats.requires_grad = False
                lambda_text_feats.requires_grad = False
                self.register_buffer("pi_text_feats", pi_text_feats)
                self.register_buffer("lambda_text_feats", lambda_text_feats)
            else:
                text_prompts = optimize_text_prompts(model_name, weight_name, bins)
                self.text_prompts = text_prompts
                text_feats = encode_text(model_name, weight_name, text_prompts)
                text_feats.requires_grad = False
                self.register_buffer("text_feats", text_feats)
        else:
            if self.zero_inflated:
                assert "pi" in text_prompts and "lambda" in text_prompts, f"Expected text_prompts to have keys 'pi' and 'lambda', got {text_prompts.keys()}."
                pi_text_prompts = text_prompts["pi"]
                lambda_text_prompts = text_prompts["lambda"]
                pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
                lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
                pi_text_feats.requires_grad = False
                lambda_text_feats.requires_grad = False
                self.register_buffer("pi_text_feats", pi_text_feats)
                self.register_buffer("lambda_text_feats", lambda_text_feats)
            else:
                text_feats = encode_text(model_name, weight_name, text_prompts)
                text_feats.requires_grad = False
                self.register_buffer("text_feats", text_feats)

    def _build_head(self) -> None:
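        """Build the 1x1 projection head(s) and learnable logit scale(s)."""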
        in_channels = self.backbone.in_features
        out_channels = self.backbone.out_features
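        # Logit scales follow CLIP's initialization: temperature 0.07, stored as log(1 / 0.07).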
        if self.zero_inflated:
            self.pi_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.lambda_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.pi_head = conv1x1(in_channels, out_channels, bias=False)
            self.lambda_head = conv1x1(in_channels, out_channels, bias=False)
        else:
            self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
            self.head = conv1x1(in_channels, out_channels, bias=False)

    def forward(self, image: Tensor):
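        """Predict a block-wise density map for `image`.

        In eval mode, returns the density map of shape (B, 1, H, W), where H and
        W are the feature-map dimensions. In training mode, the intermediate
        logit maps are returned as well so classification losses can be applied.
        """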
        image_feats = self.backbone(image)
        # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
        if self.zero_inflated:
            pi_image_feats, lambda_image_feats = self.pi_head(image_feats), self.lambda_head(image_feats)
            pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1)  # shape (B, H, W, C)
            pi_text_feats, lambda_text_feats = self.pi_text_feats, self.lambda_text_feats
            pi_logit_scale, lambda_logit_scale = self.pi_logit_scale.exp(), self.lambda_logit_scale.exp()
            pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t()  # (B, H, W, 2), logits per image
            lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t()  # (B, H, W, N - 1), logits per image
            pi_logit_map = pi_logit_map.permute(0, 3, 1, 2)  # (B, 2, H, W)
            lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2)  # (B, N - 1, H, W)
            lambda_map = (lambda_logit_map.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            # pi_logit_map.softmax(dim=1)[:, 0] is the probability of zeros
            den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map  # (B, 1, H, W)
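            # i.e. den_map = P(non-zero) * E[count | non-zero], the zero-inflated expected count.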
            if self.training:
                return pi_logit_map, lambda_logit_map, lambda_map, den_map
            else:
                return den_map
        else:
            image_feats = self.head(image_feats)
            image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)
            text_feats = self.text_feats
            logit_scale = self.logit_scale.exp()
            logit_map = logit_scale * image_feats @ text_feats.t()  # (B, H, W, N), logits per image
            logit_map = logit_map.permute(0, 3, 1, 2)  # (B, N, H, W)
            den_map = (logit_map.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True)  # (B, 1, H, W)
            if self.training:
                return logit_map, den_map
            else:
                return den_map


def _clip_ebc(
    model_name: str,
    weight_name: str,
    block_size: Optional[int] = None,
    bins: Optional[List[Tuple[float, float]]] = None,
    bin_centers: Optional[List[float]] = None,
    zero_inflated: Optional[bool] = True,
    num_vpt: Optional[int] = None,
    vpt_drop: Optional[float] = None,
    input_size: Optional[int] = None,
    text_prompts: Optional[Dict[str, List[str]]] = None,
    norm: Optional[str] = "none",
    act: Optional[str] = "none",
) -> CLIP_EBC:
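    """Factory that builds a CLIP_EBC model, forwarding all arguments unchanged."""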
    return CLIP_EBC(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        bins=bins,
        bin_centers=bin_centers,
        zero_inflated=zero_inflated,
        num_vpt=num_vpt,
        vpt_drop=vpt_drop,
        input_size=input_size,
        text_prompts=text_prompts,
        norm=norm,
        act=act,
    )
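

# Example usage (a minimal sketch, not from the original configs: the model and
# weight names are hypothetical and must be keys of supported_models_and_weights,
# and the bins/bin_centers below are illustrative placeholders):
#
#     model = _clip_ebc(
#         model_name="resnet50",
#         weight_name="openai",
#         block_size=16,
#         bins=[(0, 0), (1, 1), (2, 2), (3, float("inf"))],
#         bin_centers=[0.0, 1.0, 2.0, 4.0],
#         zero_inflated=True,
#     )
#     model.eval()
#     with torch.no_grad():
#         den_map = model(torch.randn(1, 3, 224, 224))  # (1, 1, H, W)
#     count = den_map.sum().item()  # total predicted count for the image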