Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import math | |
class Projections(nn.Module):
    """Precomputes per-head key/value projections and a full-width output projection.

    Given node embeddings ``h`` of shape (batch_size, graph_size, embed_dim),
    produces:
      - ``K``, ``V``: (n_heads, batch_size, graph_size, val_dim) where
        ``val_dim = embed_dim // n_heads``
      - ``V_output``: (batch_size, graph_size, embed_dim)

    Typical use is as the cacheable precomputation step of multi-head
    (cross-)attention, so K/V need not be recomputed per query.
    """

    def __init__(self, n_heads, embed_dim):
        """
        :param n_heads: number of attention heads; must divide embed_dim evenly
        :param embed_dim: dimensionality of the input node embeddings
        :raises ValueError: if embed_dim is not divisible by n_heads
        """
        super(Projections, self).__init__()
        # Guard against silent truncation: embed_dim // n_heads with a
        # remainder would give per-head dims that no longer sum to embed_dim.
        if embed_dim % n_heads != 0:
            raise ValueError(
                "embed_dim ({}) must be divisible by n_heads ({})".format(
                    embed_dim, n_heads
                )
            )
        self.n_heads = n_heads
        self.embed_dim = embed_dim
        self.val_dim = embed_dim // n_heads
        # One (embed_dim x val_dim) projection matrix per head.
        self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
        self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
        # Full-width projection applied to h directly (not per head).
        self.W_output = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
        self.init_parameters()

    def init_parameters(self):
        """Uniform init in [-1/sqrt(d), 1/sqrt(d)], d = each parameter's last dim."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, h):
        """
        :param h: Tensor of shape (batch_size, graph_size, embed_dim)
        :return: dict with keys: K, V, V_output
        """
        batch_size, graph_size, input_dim = h.size()
        hflat = h.contiguous().view(-1, input_dim)  # (batch_size * graph_size, embed_dim)

        # (B*G, E) matmul (H, E, D) broadcasts over the head dim -> (H, B*G, D),
        # then reshaped to split batch and graph dims per head.
        shp = (self.n_heads, batch_size, graph_size, self.val_dim)
        K = torch.matmul(hflat, self.W_key).view(shp)
        V = torch.matmul(hflat, self.W_val).view(shp)

        # Output projection over the full embedding. (The original applied
        # W_output.expand_as(W_output), a no-op; removed.)
        V_output = torch.matmul(h, self.W_output)

        return {
            'K': K,  # (n_heads, batch_size, graph_size, val_dim)
            'V': V,  # (n_heads, batch_size, graph_size, val_dim)
            'V_output': V_output  # (batch_size, graph_size, embed_dim)
        }