Spaces:
Running
Running
| import torch | |
| from torch import nn | |
| from opendet.modeling.backbones import build_backbone | |
| from opendet.modeling.necks import build_neck | |
| from opendet.modeling.heads import build_head | |
| __all__ = ['BaseDetector'] | |
class BaseDetector(nn.Module):
    """Generic detector assembled from optional Backbone, Neck and Head stages.

    Each stage is built only when its config section is present and not
    ``None``; stages that are not built are skipped in ``forward``.
    """

    def __init__(self, config):
        """Build the enabled stages from ``config``.

        Args:
            config (dict): model configuration. Keys read here:
                ``in_channels`` (default 3): channel count handed to the
                    first built stage.
                ``use_wd`` (default True): when False, ``no_weight_decay``
                    reports nothing.
                ``Backbone`` / ``Neck`` / ``Head`` (dict or None): per-stage
                    configs passed to ``build_backbone`` / ``build_neck`` /
                    ``build_head``.

        Note:
            The stage sub-dicts are modified in place: each receives an
            ``in_channels`` entry equal to the previous built stage's
            ``out_channels`` (or the input channel count for the first one).
        """
        super().__init__()
        in_channels = config.get('in_channels', 3)
        self.use_wd = config.get('use_wd', True)

        # build backbone
        self.use_backbone = config.get('Backbone') is not None
        if self.use_backbone:
            config['Backbone']['in_channels'] = in_channels
            self.backbone = build_backbone(config['Backbone'])
            in_channels = self.backbone.out_channels

        # build neck
        self.use_neck = config.get('Neck') is not None
        if self.use_neck:
            config['Neck']['in_channels'] = in_channels
            self.neck = build_neck(config['Neck'])
            in_channels = self.neck.out_channels

        # build head
        self.use_head = config.get('Head') is not None
        if self.use_head:
            config['Head']['in_channels'] = in_channels
            self.head = build_head(config['Head'])

    def no_weight_decay(self):
        """Collect parameter names that should be excluded from weight decay.

        The backbone and head are queried only if they were built and expose
        a ``no_weight_decay`` method. The neck is not queried.

        Returns:
            set: the merged names. It is empty when ``use_wd`` is False or
            when no built stage provides the hook.
        """
        names = set()
        if not self.use_wd:
            return names

        # Only touch stages that were actually built: a disabled stage has no
        # attribute, and reading it would raise AttributeError.
        stages = []
        if self.use_backbone:
            stages.append(self.backbone)
        if self.use_head:
            stages.append(self.head)

        for stage in stages:
            if hasattr(stage, 'no_weight_decay'):
                # Merge into a fresh set so any iterable of names is accepted
                # (presumably a set, per the common timm convention — confirm
                # per backbone/head) and the sub-module's own object is never
                # mutated.
                names.update(stage.no_weight_decay())
        return names

    def forward(self, x, data=None):
        """Run the input through each built stage in order.

        Args:
            x: input passed to the first built stage.
            data: optional extra argument, forwarded only to the head.

        Returns:
            The output of the last built stage, or ``x`` unchanged when no
            stage was built.
        """
        if self.use_backbone:
            x = self.backbone(x)
        if self.use_neck:
            x = self.neck(x)
        if self.use_head:
            x = self.head(x, data=data)
        return x