Spaces:
Runtime error
Runtime error
Update nets/projections.py
Browse files- nets/projections.py +12 -13
nets/projections.py
CHANGED
|
@@ -13,7 +13,7 @@ class Projections(nn.Module):
|
|
| 13 |
|
| 14 |
self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
| 15 |
self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
| 16 |
-
self.W_output = nn.Parameter(torch.Tensor(
|
| 17 |
|
| 18 |
self.init_parameters()
|
| 19 |
|
|
@@ -24,23 +24,22 @@ class Projections(nn.Module):
|
|
| 24 |
|
| 25 |
def forward(self, h):
|
| 26 |
"""
|
| 27 |
-
:param h: (batch_size, graph_size, embed_dim)
|
| 28 |
-
:return: dict with keys K, V, V_output
|
| 29 |
"""
|
| 30 |
batch_size, graph_size, input_dim = h.size()
|
| 31 |
-
hflat = h.view(-1, input_dim) # (batch_size * graph_size, embed_dim)
|
| 32 |
|
|
|
|
| 33 |
shp = (self.n_heads, batch_size, graph_size, self.val_dim)
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
|
| 37 |
-
V = torch.matmul(hflat, self.W_val).view(shp) # (n_heads, batch_size, graph_size, val_dim)
|
| 38 |
-
|
| 39 |
-
# Output projection
|
| 40 |
-
V_output = torch.bmm(h, self.W_output.repeat(batch_size, 1, 1)) # (batch_size, graph_size, embed_dim)
|
| 41 |
|
| 42 |
return {
|
| 43 |
-
'K': K,
|
| 44 |
-
'V': V,
|
| 45 |
-
'V_output': V_output
|
| 46 |
}
|
|
|
|
| 13 |
|
| 14 |
self.W_key = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
| 15 |
self.W_val = nn.Parameter(torch.Tensor(n_heads, embed_dim, self.val_dim))
|
| 16 |
+
self.W_output = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
|
| 17 |
|
| 18 |
self.init_parameters()
|
| 19 |
|
|
|
|
| 24 |
|
| 25 |
def forward(self, h):
|
| 26 |
"""
|
| 27 |
+
:param h: Tensor of shape (batch_size, graph_size, embed_dim)
|
| 28 |
+
:return: dict with keys: K, V, V_output
|
| 29 |
"""
|
| 30 |
batch_size, graph_size, input_dim = h.size()
|
| 31 |
+
hflat = h.contiguous().view(-1, input_dim) # (batch_size * graph_size, embed_dim)
|
| 32 |
|
| 33 |
+
# Compute Keys and Values per head
|
| 34 |
shp = (self.n_heads, batch_size, graph_size, self.val_dim)
|
| 35 |
+
K = torch.matmul(hflat, self.W_key).view(shp)
|
| 36 |
+
V = torch.matmul(hflat, self.W_val).view(shp)
|
| 37 |
|
| 38 |
+
# Compute output projection: (batch_size, graph_size, embed_dim)
|
| 39 |
+
V_output = torch.matmul(h, self.W_output.expand_as(self.W_output))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
return {
|
| 42 |
+
'K': K, # (n_heads, batch_size, graph_size, val_dim)
|
| 43 |
+
'V': V, # (n_heads, batch_size, graph_size, val_dim)
|
| 44 |
+
'V_output': V_output # (batch_size, graph_size, embed_dim)
|
| 45 |
}
|