""" ISNet model for transformers library This file is automatically loaded when trust_remote_code=True is used """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from transformers import PreTrainedModel, PretrainedConfig # Import the ISNet model import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.isnet import ISNetDIS class ISNetConfig(PretrainedConfig): """Configuration for ISNet model""" model_type = "isnet" def __init__(self, in_ch=3, out_ch=1, **kwargs): super().__init__(**kwargs) self.in_ch = in_ch self.out_ch = out_ch self.num_labels = out_ch # Required for AutoModelForImageSegmentation self.architectures = ["ISNetForImageSegmentation"] class ISNetForImageSegmentation(PreTrainedModel): """Transformers-compatible ISNet model for image segmentation""" config_class = ISNetConfig base_model_prefix = "isnet" def __init__(self, config): super().__init__(config) self.isnet = ISNetDIS(in_ch=config.in_ch, out_ch=config.out_ch) def forward(self, pixel_values, labels=None, threshold=0.5): """Forward pass""" outputs = self.isnet(pixel_values) # ISNet returns a tuple: (segmentation_masks, feature_maps) if isinstance(outputs, tuple) and len(outputs) == 2: segmentation_masks = outputs[0] # List of 6 sigmoid outputs feature_maps = outputs[1] # List of 6 feature maps # Use the first mask (highest resolution) as the main output mask = segmentation_masks[0] # Shape: [batch_size, 1, height, width] # Return in transformers format - just the mask return mask else: # Fallback for other output formats return outputs @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): """Load model from pretrained weights""" from transformers.utils import cached_file config = kwargs.pop("config", None) if config is None: config = ISNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) model = cls(config) # Load the state dict if "state_dict" in kwargs: state_dict = kwargs["state_dict"] else: # Try to load from Hugging Face Hub try: # Try pytorch_model.bin first model_file = cached_file( pretrained_model_name_or_path, "pytorch_model.bin", **kwargs ) state_dict = torch.load(model_file, map_location="cpu") except: try: # Try model.safetensors model_file = cached_file( pretrained_model_name_or_path, "model.safetensors", **kwargs ) from safetensors import safe_open with safe_open(model_file, framework="pt", device="cpu") as f: state_dict = {key: f.get_tensor(key) for key in f.keys()} except: # Fallback to the original model file model_file = cached_file( pretrained_model_name_or_path, "supplyswap_isnet.pth", **kwargs ) state_dict = torch.load(model_file, map_location="cpu") # Handle different state dict formats if isinstance(state_dict, dict): # Check if the state dict has the expected keys if any(key.startswith('isnet.') for key in state_dict.keys()): # State dict already has the correct prefix pass elif any(key.startswith('conv_in.') or key.startswith('stage') for key in state_dict.keys()): # State dict is from the original ISNet model, needs to be wrapped wrapped_state_dict = {} for key, value in state_dict.items(): wrapped_state_dict[f"isnet.{key}"] = value state_dict = wrapped_state_dict else: # Try to load directly pass # Load the weights into the ISNet model try: model.isnet.load_state_dict(state_dict) except Exception as e: print(f"Warning: Could not load state dict directly: {e}") print("Attempting to load with strict=False...") model.isnet.load_state_dict(state_dict, strict=False) model.eval() return model