Commit 5fa0b17

Make embed_dim required and fix size mismatch bug when key_dim and val_dim are different
1 parent b8c9b8a commit 5fa0b17

File tree

1 file changed: +2 -4

nets/graph_encoder.py (+2 -4)
@@ -19,14 +19,13 @@ def __init__(
             self,
             n_heads,
             input_dim,
-            embed_dim=None,
+            embed_dim,
             val_dim=None,
             key_dim=None
     ):
         super(MultiHeadAttention, self).__init__()
 
         if val_dim is None:
-            assert embed_dim is not None, "Provide either embed_dim or val_dim"
             val_dim = embed_dim // n_heads
         if key_dim is None:
             key_dim = val_dim
@@ -43,8 +42,7 @@ def __init__(
         self.W_key = nn.Parameter(torch.Tensor(n_heads, input_dim, key_dim))
         self.W_val = nn.Parameter(torch.Tensor(n_heads, input_dim, val_dim))
 
-        if embed_dim is not None:
-            self.W_out = nn.Parameter(torch.Tensor(n_heads, key_dim, embed_dim))
+        self.W_out = nn.Parameter(torch.Tensor(n_heads, val_dim, embed_dim))
 
         self.init_parameters()
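Why the W_out shape matters: the per-head attention output is a weighted sum of the value projections, so its last dimension is val_dim, not key_dim. Below is a minimal sketch of the shape algebra only, with hypothetical sizes and tensor names and a simplified forward pass (not the repository's exact implementation), assuming key_dim != val_dim so the old shape would fail.

import torch

# Hypothetical sizes; key_dim and val_dim deliberately differ to expose the bug.
n_heads, batch_size, n_nodes = 8, 2, 5
input_dim, embed_dim = 128, 128
key_dim, val_dim = 16, 32

x = torch.randn(batch_size, n_nodes, input_dim)
W_query = torch.randn(n_heads, input_dim, key_dim)
W_key = torch.randn(n_heads, input_dim, key_dim)
W_val = torch.randn(n_heads, input_dim, val_dim)
W_out = torch.randn(n_heads, val_dim, embed_dim)   # the fixed shape; key_dim here breaks the projection below

flat = x.reshape(-1, input_dim)                                      # (batch * nodes, input_dim)
Q = torch.matmul(flat, W_query).view(n_heads, batch_size, n_nodes, key_dim)
K = torch.matmul(flat, W_key).view(n_heads, batch_size, n_nodes, key_dim)
V = torch.matmul(flat, W_val).view(n_heads, batch_size, n_nodes, val_dim)

attn = torch.softmax(Q @ K.transpose(-2, -1) / key_dim ** 0.5, dim=-1)
heads = attn @ V                                                     # (n_heads, batch, nodes, val_dim)

# Output projection: the last dim of heads is val_dim, so W_out's middle dim must be val_dim.
out = torch.einsum('hbnv,hve->bne', heads, W_out)
print(out.shape)                                                     # torch.Size([2, 5, 128])

With key_dim=16 and val_dim=32, an output matrix shaped (n_heads, key_dim, embed_dim) cannot be contracted against heads and raises a size mismatch, which is what this commit fixes; and since W_out is now always created, embed_dim becomes a required argument.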