# isnet-background-remover / modeling_isnet.py
"""
ISNet model for transformers library
This file is automatically loaded when trust_remote_code=True is used
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import PreTrainedModel, PretrainedConfig
# Import the ISNet model
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.isnet import ISNetDIS


class ISNetConfig(PretrainedConfig):
    """Configuration for the ISNet model."""

    model_type = "isnet"

    def __init__(self, in_ch=3, out_ch=1, **kwargs):
        super().__init__(**kwargs)
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.num_labels = out_ch  # Required for AutoModelForImageSegmentation
        self.architectures = ["ISNetForImageSegmentation"]
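
# A minimal sketch of the config.json that would pair with this file. The field
# values and the auto_map entry below are assumptions (not copied from this
# repository); auto_map is the standard mechanism that makes AutoConfig/AutoModel
# resolve to the classes in this module when trust_remote_code=True.
#
#   {
#     "model_type": "isnet",
#     "in_ch": 3,
#     "out_ch": 1,
#     "architectures": ["ISNetForImageSegmentation"],
#     "auto_map": {
#       "AutoConfig": "modeling_isnet.ISNetConfig",
#       "AutoModel": "modeling_isnet.ISNetForImageSegmentation"
#     }
#   }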


class ISNetForImageSegmentation(PreTrainedModel):
    """Transformers-compatible ISNet model for image segmentation."""

    config_class = ISNetConfig
    base_model_prefix = "isnet"

    def __init__(self, config):
        super().__init__(config)
        self.isnet = ISNetDIS(in_ch=config.in_ch, out_ch=config.out_ch)

    def forward(self, pixel_values, labels=None, threshold=0.5):
        """Forward pass. `labels` and `threshold` are accepted for API
        compatibility but are not used; the raw sigmoid mask is returned."""
        outputs = self.isnet(pixel_values)
        # ISNet returns a tuple: (segmentation_masks, feature_maps)
        if isinstance(outputs, tuple) and len(outputs) == 2:
            segmentation_masks = outputs[0]  # List of 6 sigmoid side outputs
            feature_maps = outputs[1]        # List of 6 feature maps (unused here)
            # Use the first mask (highest resolution) as the main output.
            mask = segmentation_masks[0]  # Shape: [batch_size, 1, height, width]
            # Return just the mask, at the network's input resolution; callers
            # typically resize it back to the source image size and threshold it.
            return mask
        else:
            # Fallback for other output formats
            return outputs
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model from pretrained weights."""
        from transformers.utils import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ISNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)
        # Load the state dict
        if "state_dict" in kwargs:
            state_dict = kwargs["state_dict"]
        else:
            # Try to load from the Hugging Face Hub: pytorch_model.bin first,
            # then model.safetensors, then the original checkpoint file.
            try:
                model_file = cached_file(
                    pretrained_model_name_or_path,
                    "pytorch_model.bin",
                    **kwargs,
                )
                state_dict = torch.load(model_file, map_location="cpu")
            except Exception:
                try:
                    model_file = cached_file(
                        pretrained_model_name_or_path,
                        "model.safetensors",
                        **kwargs,
                    )
                    from safetensors import safe_open

                    with safe_open(model_file, framework="pt", device="cpu") as f:
                        state_dict = {key: f.get_tensor(key) for key in f.keys()}
                except Exception:
                    # Fallback to the original model file
                    model_file = cached_file(
                        pretrained_model_name_or_path,
                        "supplyswap_isnet.pth",
                        **kwargs,
                    )
                    state_dict = torch.load(model_file, map_location="cpu")
        # Handle different state dict formats: normalize keys so they match the
        # inner ISNetDIS module, since the weights are loaded into model.isnet.
        if isinstance(state_dict, dict):
            if any(key.startswith("isnet.") for key in state_dict):
                # Keys were saved from this wrapper; strip the "isnet." prefix.
                state_dict = {
                    (key[len("isnet."):] if key.startswith("isnet.") else key): value
                    for key, value in state_dict.items()
                }
            # Keys from the original ISNet checkpoint ("conv_in.*", "stage*", ...)
            # already match ISNetDIS and are loaded as-is.
        # Load the weights into the ISNet model
        try:
            model.isnet.load_state_dict(state_dict)
        except Exception as e:
            print(f"Warning: Could not load state dict directly: {e}")
            print("Attempting to load with strict=False...")
            model.isnet.load_state_dict(state_dict, strict=False)

        model.eval()
        return model
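

# A minimal end-to-end usage sketch. The repository id, the 1024x1024 input size,
# and the 0.5/1.0 normalization are assumptions (the common IS-Net recipe), not
# values read from this repository; adjust them to match the actual checkpoint.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoModel

    # Assumed repo id; requires the auto_map wiring sketched near the top of this file.
    model = AutoModel.from_pretrained(
        "mateenahmed/isnet-background-remover", trust_remote_code=True
    )

    # Preprocess: RGB image -> float tensor in [0, 1], resized to 1024x1024, normalized.
    image = Image.open("input.jpg").convert("RGB")
    x = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
    x = F.interpolate(x.unsqueeze(0), size=(1024, 1024), mode="bilinear")
    x = (x - 0.5) / 1.0  # assumed mean/std

    with torch.no_grad():
        mask = model(x)  # [1, 1, 1024, 1024], values in [0, 1]

    # Resize the mask back to the original image size (H, W) and binarize it.
    mask = F.interpolate(mask, size=image.size[::-1], mode="bilinear")
    alpha = (mask[0, 0] > 0.5).float()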