File size: 5,009 Bytes
09d5196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6a010
09d5196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6a010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09d5196
 
7e6a010
 
 
 
 
 
 
09d5196
 
7e6a010
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""

ISNet model for transformers library

This file is automatically loaded when trust_remote_code=True is used

"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import PreTrainedModel, PretrainedConfig

# Import the ISNet model
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.isnet import ISNetDIS

class ISNetConfig(PretrainedConfig):
    """Settings object for the ISNet segmentation model.

    Holds the input and output channel counts and mirrors ``out_ch`` into
    ``num_labels``. Any extra keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "isnet"

    def __init__(self, in_ch=3, out_ch=1, **kwargs):
        """Store the channel counts once the base config has consumed ``kwargs``.

        Args:
            in_ch: Number of input channels (default 3).
            out_ch: Number of output channels (default 1).
            **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.in_ch, self.out_ch = in_ch, out_ch
        # Required for AutoModelForImageSegmentation.
        self.num_labels = out_ch
        # Assigned after the base initializer, so this value always wins over
        # any ``architectures`` entry supplied through ``kwargs``.
        self.architectures = ["ISNetForImageSegmentation"]

class ISNetForImageSegmentation(PreTrainedModel):
    """Transformers-compatible wrapper around ``ISNetDIS`` for image segmentation.

    The wrapped network is stored in ``self.isnet``. ``forward`` returns only
    the first mask produced by the network, and ``from_pretrained`` loads
    weights either from a caller-supplied state dict or from the first
    checkpoint file that can be fetched and read.
    """

    config_class = ISNetConfig
    base_model_prefix = "isnet"

    # Checkpoint file names probed, in this order, when no state dict is given.
    _WEIGHT_FILENAMES = (
        "pytorch_model.bin",
        "model.safetensors",
        "supplyswap_isnet.pth",
    )

    def __init__(self, config):
        """Build the wrapped network from ``config.in_ch`` and ``config.out_ch``."""
        super().__init__(config)
        self.isnet = ISNetDIS(in_ch=config.in_ch, out_ch=config.out_ch)

    def forward(self, pixel_values, labels=None, threshold=0.5):
        """Run the wrapped network and return its primary segmentation mask.

        Args:
            pixel_values: Input batch, passed unchanged to the wrapped network.
            labels: Accepted for interface compatibility; not used.
            threshold: Accepted for interface compatibility; not used (the
                mask is returned without thresholding).

        Returns:
            ``outputs[0][0]`` when the network returns a 2-tuple, otherwise
            the raw network output.
        """
        outputs = self.isnet(pixel_values)

        # ISNet returns a tuple (segmentation_masks, feature_maps). Only the
        # first mask is exposed; per the original author's notes it is the
        # highest-resolution one.
        if isinstance(outputs, tuple) and len(outputs) == 2:
            segmentation_masks = outputs[0]
            return segmentation_masks[0]

        # Fallback for other output formats.
        return outputs

    @staticmethod
    def _read_weights(model_file, filename):
        """Deserialize one checkpoint file into a state dict held on the CPU."""
        if filename.endswith(".safetensors"):
            from safetensors import safe_open

            with safe_open(model_file, framework="pt", device="cpu") as f:
                return {key: f.get_tensor(key) for key in f.keys()}

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this loader at checkpoints you trust.
        return torch.load(model_file, map_location="cpu")

    @classmethod
    def _fetch_state_dict(cls, pretrained_model_name_or_path, **kwargs):
        """Return the state dict from the first candidate file that loads.

        Raises the error from the last candidate when none can be loaded.
        """
        from transformers.utils import cached_file

        last_error = None
        for filename in cls._WEIGHT_FILENAMES:
            try:
                model_file = cached_file(
                    pretrained_model_name_or_path,
                    filename,
                    **kwargs
                )
                return cls._read_weights(model_file, filename)
            except Exception as err:
                # Deliberately broad: probing is best-effort, and a missing
                # file, a missing optional dependency or an unreadable file
                # should all simply move on to the next candidate.
                last_error = err
        raise last_error

    @staticmethod
    def _strip_wrapper_prefix(state_dict, prefix="isnet."):
        """Drop ``prefix`` from keys so they match the wrapped network's names.

        Checkpoints saved from this wrapper carry keys such as
        ``isnet.conv_in.weight``, whereas ``self.isnet`` expects
        ``conv_in.weight``. Anything that is not a dict, or has no prefixed
        keys, is returned unchanged.
        """
        if not isinstance(state_dict, dict):
            return state_dict
        if not any(key.startswith(prefix) for key in state_dict):
            return state_dict
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Create the model and load pretrained weights into ``model.isnet``.

        Args:
            pretrained_model_name_or_path: Identifier or path handed to
                ``ISNetConfig.from_pretrained`` and ``cached_file``.
            *model_args: Accepted for interface compatibility; not used.
            **kwargs: May contain ``config`` (a ready configuration) and
                ``state_dict`` (weights to use instead of downloading). The
                remaining entries are forwarded to the config loader and to
                ``cached_file``.

        Returns:
            The model in evaluation mode.

        Raises:
            Exception: Whatever the last checkpoint candidate raised when no
                weight file could be loaded, or the error from the non-strict
                load when even that fails.
        """
        config = kwargs.pop("config", None)
        # Popped so the tensors are not forwarded to the config loader or to
        # cached_file; an explicit None is treated the same as "not given".
        state_dict = kwargs.pop("state_dict", None)

        if config is None:
            config = ISNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if state_dict is None:
            state_dict = cls._fetch_state_dict(pretrained_model_name_or_path, **kwargs)

        # The weights are loaded into the ``isnet`` submodule directly, so the
        # keys must NOT carry the wrapper's "isnet." prefix. Plain ISNet
        # checkpoints (keys such as "conv_in.*" / "stage*") already match.
        state_dict = cls._strip_wrapper_prefix(state_dict)

        try:
            model.isnet.load_state_dict(state_dict)
        except RuntimeError as e:
            # Best-effort fallback: load whatever matches, but report how much
            # did not so that a mostly-uninitialized model is not silent.
            print(f"Warning: Could not load state dict directly: {e}")
            print("Attempting to load with strict=False...")
            result = model.isnet.load_state_dict(state_dict, strict=False)
            print(
                f"Warning: loaded with {len(result.missing_keys)} missing and "
                f"{len(result.unexpected_keys)} unexpected keys"
            )

        model.eval()

        return model