|
|
"""
|
|
|
ISNet image-segmentation model wrapper for the Hugging Face ``transformers`` library.

This file is loaded automatically when the model is loaded with ``trust_remote_code=True``.
|
|
|
"""
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import numpy as np
|
|
|
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
|
|
|
|
|
|
import sys
|
|
|
import os
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
from models.isnet import ISNetDIS
|
|
|
|
|
|
class ISNetConfig(PretrainedConfig):
    """Configuration container for ``ISNetForImageSegmentation``.

    Args:
        in_ch: Input channel count handed to ``ISNetDIS`` as ``in_ch``.
        out_ch: Output channel count handed to ``ISNetDIS`` as ``out_ch``;
            also mirrored into the generic ``num_labels`` field.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "isnet"

    def __init__(self, in_ch=3, out_ch=1, **kwargs):
        super().__init__(**kwargs)
        # Backbone channel counts, read by the model wrapper when it builds ISNetDIS.
        self.in_ch, self.out_ch = in_ch, out_ch
        # Keep the generic transformers label count in step with the output channels.
        self.num_labels = self.out_ch
        # Always advertise the wrapper class as this config's architecture
        # (overrides any value supplied through kwargs).
        self.architectures = ["ISNetForImageSegmentation"]
|
|
|
|
|
|
class ISNetForImageSegmentation(PreTrainedModel):
    """Transformers-compatible wrapper around ``ISNetDIS`` for image segmentation.

    The backbone lives in ``self.isnet``, so this wrapper's own ``state_dict()``
    keys carry an ``isnet.`` prefix, whereas a bare ``ISNetDIS`` checkpoint has
    none. ``from_pretrained`` accepts both layouts and always loads the weights
    into the backbone.
    """

    config_class = ISNetConfig
    base_model_prefix = "isnet"

    # Weight files probed, in order, by ``from_pretrained`` when no
    # ``state_dict`` is supplied; the first one that resolves is used.
    _WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors", "supplyswap_isnet.pth")

    def __init__(self, config):
        super().__init__(config)
        self.isnet = ISNetDIS(in_ch=config.in_ch, out_ch=config.out_ch)

    def forward(self, pixel_values, labels=None, threshold=0.5):
        """Run the ISNet backbone on ``pixel_values``.

        Args:
            pixel_values: Input passed straight to ``ISNetDIS``.
            labels: Accepted for transformers API compatibility; unused.
            threshold: Accepted for API compatibility; currently unused (no
                binarisation is applied to the returned mask).

        Returns:
            If the backbone returns a 2-tuple, element ``[0]`` of its first
            item; otherwise the backbone output unchanged.
        """
        outputs = self.isnet(pixel_values)
        # NOTE(review): the 2-tuple is presumably (side-output masks, feature
        # maps) and [0] the primary mask -- confirm against models/isnet.py.
        if isinstance(outputs, tuple) and len(outputs) == 2:
            segmentation_masks, _feature_maps = outputs
            return segmentation_masks[0]
        return outputs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Build the model and load its weights from a hub repo id or local path.

        Args:
            pretrained_model_name_or_path: Hub repository id or local directory.
            *model_args: Accepted for signature compatibility; ignored.
            **kwargs: ``config`` (an ``ISNetConfig``) and ``state_dict`` (a
                ready-made mapping of weights) are consumed here; everything
                else is forwarded to ``ISNetConfig.from_pretrained`` and to
                ``transformers.utils.cached_file``.

        Returns:
            The model, switched to eval mode.

        Raises:
            OSError: if no state dict is given and none of ``_WEIGHT_FILES``
                can be resolved.
        """
        import warnings

        config = kwargs.pop("config", None)
        # Pop the weights so they are not forwarded as a kwarg to the config/hub helpers.
        state_dict = kwargs.pop("state_dict", None)
        if config is None:
            config = ISNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if state_dict is None:
            state_dict = cls._load_isnet_weights(pretrained_model_name_or_path, **kwargs)

        if isinstance(state_dict, dict):
            # Weights are loaded into ``model.isnet`` (the bare backbone), so
            # wrapper-level ``isnet.*`` keys must lose their prefix; raw
            # ISNetDIS keys already match and are left alone.
            state_dict = cls._strip_isnet_prefix(state_dict)

        try:
            model.isnet.load_state_dict(state_dict)
        except RuntimeError as err:
            # Best-effort fallback kept from the original behaviour, but the
            # mismatch is now reported instead of being silently accepted.
            warnings.warn(
                f"Could not load state dict strictly: {err}. Retrying with strict=False.",
                stacklevel=2,
            )
            result = model.isnet.load_state_dict(state_dict, strict=False)
            if result.missing_keys or result.unexpected_keys:
                warnings.warn(
                    f"Partial ISNet weight load: {len(result.missing_keys)} missing and "
                    f"{len(result.unexpected_keys)} unexpected keys "
                    f"(first missing: {result.missing_keys[:5]}, "
                    f"first unexpected: {result.unexpected_keys[:5]}).",
                    stacklevel=2,
                )

        model.eval()
        return model

    @classmethod
    def _load_isnet_weights(cls, pretrained_model_name_or_path, **hub_kwargs):
        """Return the raw weights from the first entry of ``_WEIGHT_FILES`` that resolves.

        Raises:
            OSError: if no candidate file can be resolved.
        """
        from transformers.utils import cached_file

        failures = []
        for filename in cls._WEIGHT_FILES:
            try:
                model_file = cached_file(pretrained_model_name_or_path, filename, **hub_kwargs)
            except OSError as err:
                # cached_file reports a missing/unreachable file as OSError
                # (EnvironmentError); try the next candidate.
                failures.append(f"{filename}: {err}")
                continue
            if model_file is None:
                # cached_file may return None instead of raising (e.g. when
                # missing-entry errors are disabled by the caller).
                failures.append(f"{filename}: not found")
                continue
            if filename.endswith(".safetensors"):
                try:
                    from safetensors import safe_open
                except ImportError as err:
                    failures.append(f"{filename}: {err}")
                    continue
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    return {key: f.get_tensor(key) for key in f.keys()}
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            return torch.load(model_file, map_location="cpu")
        raise OSError(
            f"Could not find ISNet weights for {pretrained_model_name_or_path!r}; tried "
            + "; ".join(failures)
        )

    @staticmethod
    def _strip_isnet_prefix(state_dict, prefix="isnet."):
        """Map wrapper-level keys (``isnet.<name>``) to backbone keys (``<name>``).

        Keys without the prefix are kept unchanged. When no key carries the
        prefix, the original mapping object is returned untouched.
        """
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }