Spaces:
Runtime error
Runtime error
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.multi_headed_attention import MultiHeadAttention
import math
class SkipConnection(nn.Module):
    """Residual wrapper: returns ``input + module(input)`` as a (tensor, mask) pair.

    Accepts either a bare tensor or a ``(tensor, mask)`` tuple and always
    returns a ``(tensor, mask)`` tuple, so wrapped sub-layers can be chained
    uniformly. When ``use_mask`` is True the wrapped module is called with a
    ``mask=`` keyword; otherwise the mask is only passed through.
    """

    def __init__(self, module, use_mask=True):
        super(SkipConnection, self).__init__()
        self.use_mask = use_mask  # whether the wrapped module accepts a mask kwarg
        self.module = module

    def forward(self, input):
        # Normalize the argument into (tensor, mask) form.
        mask = None
        if isinstance(input, tuple):
            if len(input) > 1:
                mask = input[1]
            input = input[0]
        # Residual connection; only mask-aware modules receive the mask.
        if self.use_mask:
            out = input + self.module(input, mask=mask)
        else:
            out = input + self.module(input)
        return out, mask
class Normalization(nn.Module):
    """Applies batch or instance normalization over the embedding dimension.

    Accepts either a tensor or a ``(tensor, mask)`` tuple; the mask is passed
    through unchanged. An unrecognized ``normalization`` name yields an
    identity transform (matching the fall-through branch in ``forward``).
    """

    def __init__(self, embed_dim, normalization='batch'):
        super(Normalization, self).__init__()
        normalizer_class = {
            'batch': nn.BatchNorm1d,
            'instance': nn.InstanceNorm1d
        }.get(normalization, None)
        # Bug fix: previously an unknown normalization name left
        # normalizer_class as None and `None(embed_dim, affine=True)` raised a
        # TypeError in __init__, even though forward() already handles the
        # "no normalizer" case with an identity branch.
        if normalizer_class is not None:
            self.normalizer = normalizer_class(embed_dim, affine=True)
        else:
            self.normalizer = None

    def forward(self, input):
        # Normalize the argument into (tensor, mask) form.
        if isinstance(input, tuple):
            if len(input) > 1:
                input, mask = input[0], input[1]
            else:
                input = input[0]
                mask = None
        else:
            mask = None
        if isinstance(self.normalizer, nn.BatchNorm1d):
            # BatchNorm1d expects (N, C): flatten batch and sequence dims.
            return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()), mask
        elif isinstance(self.normalizer, nn.InstanceNorm1d):
            # InstanceNorm1d expects (N, C, L): move embedding dim to channels.
            return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1), mask
        else:
            # No (or unknown) normalizer: identity.
            return input, mask
class MultiHeadAttentionLayer(nn.Module):
    """One transformer-style encoder layer.

    Masked multi-head self-attention followed by a position-wise feed-forward
    sub-layer; each sub-layer is wrapped in a SkipConnection and followed by
    Normalization. A ``(hidden, mask)`` tuple is threaded through all four
    stages and returned.
    """

    def __init__(self, n_heads, embed_dim, feed_forward_hidden=512, normalization='batch'):
        super(MultiHeadAttentionLayer, self).__init__()
        self.attention = SkipConnection(
            MultiHeadAttention(n_heads, input_dim=embed_dim, embed_dim=embed_dim),
            use_mask=True
        )
        self.norm1 = Normalization(embed_dim, normalization)
        # Feed-forward block; collapses to a single linear map when no
        # hidden width is requested.
        if feed_forward_hidden > 0:
            ff_module = nn.Sequential(
                nn.Linear(embed_dim, feed_forward_hidden),
                nn.ReLU(),
                nn.Linear(feed_forward_hidden, embed_dim)
            )
        else:
            ff_module = nn.Linear(embed_dim, embed_dim)
        self.ff = SkipConnection(ff_module, use_mask=False)
        self.norm2 = Normalization(embed_dim, normalization)

    def forward(self, input):
        state = self.attention(input)   # -> (h, mask)
        state = self.norm1(state)
        state = self.ff(state)
        state = self.norm2(state)
        return state
class Encoder(nn.Module):
    """Stack of MultiHeadAttentionLayer blocks with an optional initial
    linear embedding of raw node features.
    """

    def __init__(self, n_heads, embed_dim, n_layers, node_dim=None,
                 normalization='batch', feed_forward_hidden=200):
        super(Encoder, self).__init__()
        # Project raw node features into the embedding space only when a
        # feature dimension is given; otherwise inputs are used as-is.
        self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None
        self.layers = nn.ModuleList([
            MultiHeadAttentionLayer(
                n_heads, embed_dim,
                feed_forward_hidden=feed_forward_hidden,
                normalization=normalization
            ) for _ in range(n_layers)
        ])

    def forward(self, input, mask=None):
        batch_size, num_nodes = input.shape[0], input.shape[1]
        if mask is None:
            # Default: all-ones mask, i.e. fully-connected attention.
            mask = torch.ones(batch_size, num_nodes, num_nodes,
                              device=input.device).float()
        # Convert the 1/0 convention into a boolean "blocked" mask
        # (True where attention must be suppressed).
        # NOTE(review): source indentation was garbled in the paste; this
        # inversion is assumed to apply to caller-supplied masks too — confirm.
        mask = (mask == 0)
        h = self.init_embed(input) if self.init_embed is not None else input
        for layer in self.layers:
            h, mask = layer((h, mask))
        return h