| from transformers import PreTrainedModel, PretrainedConfig | |
| import torch.nn as nn | |
| from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights | |
class CheckboxConfig(PretrainedConfig):
    """Configuration for the checkbox image classifier.

    Args:
        num_labels: Number of output classes; forwarded to ``PretrainedConfig``.
        dropout_rate: Dropout probability kept on the config; the classifier
            head reads it when building its dropout layers.
        **kwargs: Any remaining options, passed through to ``PretrainedConfig``.
    """

    model_type = "checkbox-classifier"

    def __init__(
        self,
        num_labels: int = 2,
        dropout_rate: float = 0.3,
        **kwargs,
    ) -> None:
        # The base class owns ``num_labels``, so it travels with the generic
        # options rather than being stored here directly.
        base_options = dict(kwargs, num_labels=num_labels)
        super().__init__(**base_options)
        self.dropout_rate = dropout_rate
class CheckboxClassifier(PreTrainedModel):
    """EfficientNetV2-S backbone with a three-layer MLP classification head.

    The stock torchvision classifier is replaced by
    ``Dropout -> Linear(512) -> SiLU -> BatchNorm -> Dropout -> Linear(256)
    -> SiLU -> BatchNorm -> Dropout -> Linear(num_labels)``.

    Config fields read:
        num_labels: Size of the final linear layer.
        dropout_rate: Probability for the first two dropout layers; the last
            dropout layer uses half of it.
        pretrained_backbone (optional, default ``True``): When ``False`` the
            backbone is built without ImageNet weights. Useful when the
            weights are about to be overwritten anyway (e.g. loading a
            fine-tuned checkpoint) or when no network access is available.
    """

    config_class = CheckboxConfig
    # PreTrainedModel defaults this to "input_ids"; this model consumes images.
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # ImageNet initialisation stays the default so existing callers see
        # no change; getattr keeps configs without the flag working.
        use_pretrained = getattr(config, "pretrained_backbone", True)
        weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if use_pretrained else None
        self.backbone = efficientnet_v2_s(weights=weights)

        # Index 1 of the stock torchvision head is read for its input width,
        # i.e. the size of the pooled backbone features.
        num_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(config.dropout_rate),
            nn.Linear(num_features, 512),
            nn.SiLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(config.dropout_rate),
            nn.Linear(512, 256),
            nn.SiLU(inplace=True),
            nn.BatchNorm1d(256),
            # Lighter regularisation right before the output layer.
            nn.Dropout(config.dropout_rate / 2),
            nn.Linear(256, config.num_labels),
        )

    def forward(self, pixel_values, labels=None):
        """Run the backbone and head, optionally computing a loss.

        Args:
            pixel_values: Image batch fed straight to the torchvision
                backbone (presumably ``(batch, 3, H, W)`` floats, normalised
                as the backbone expects -- confirm against the caller).
            labels: Optional class indices, one per image. When given, a
                cross-entropy loss is computed against the logits.

        Returns:
            dict: ``{"logits": logits}`` when ``labels`` is ``None``;
            otherwise ``{"loss": loss, "logits": logits}``. The last dimension
            of ``logits`` has size ``num_labels``.
        """
        logits = self.backbone(pixel_values)
        if labels is None:
            return {"logits": logits}

        # Flatten so labels shaped (batch,) or (batch, 1) are both accepted.
        loss = nn.CrossEntropyLoss()(
            logits.view(-1, self.num_labels), labels.view(-1)
        )
        return {"loss": loss, "logits": logits}