# embeddings.py
from typing import Dict, List, Optional, Tuple

import torch.nn as nn

class TimeDistributedEmbeddingBag(nn.EmbeddingBag):
    """EmbeddingBag that is applied independently to every timestep of a sequence."""

    def __init__(self, *args, batch_first: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_first = batch_first

    def forward(self, x):
        if len(x.size()) <= 2:
            return super().forward(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(-1, x.size(-1))  # (samples * timesteps, input_size)

        y = super().forward(x_reshape)

        # Reshape y back to the original sequence layout
        if self.batch_first:
            y = y.contiguous().view(x.size(0), -1, y.size(-1))  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        return y
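
# A minimal usage sketch of the class above. The sizes and the helper name are
# illustrative assumptions, not part of the original module: with batch_first=True,
# each timestep's bag of indices is summed into a single embedding vector.
def _demo_time_distributed_bag():
    import torch

    emb = TimeDistributedEmbeddingBag(10, 4, mode="sum", batch_first=True)
    tokens = torch.randint(0, 10, (2, 5, 3))  # (samples, timesteps, bag indices)
    out = emb(tokens)
    assert out.shape == (2, 5, 4)  # one summed embedding per timestep
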
class MultiEmbedding(nn.Module):
    """Holds one embedding per categorical variable; grouped variables share an embedding bag."""

    def __init__(
        self,
        embedding_sizes: Dict[str, Tuple[int, int]],
        categorical_groups: Dict[str, List[str]],
        embedding_paddings: List[str],
        x_categoricals: List[str],
        max_embedding_size: Optional[int] = None,
    ):
        super().__init__()
        self.embedding_sizes = embedding_sizes
        self.categorical_groups = categorical_groups
        self.embedding_paddings = embedding_paddings
        self.max_embedding_size = max_embedding_size
        self.x_categoricals = x_categoricals
        self.init_embeddings()

    def init_embeddings(self):
        self.embeddings = nn.ModuleDict()
        for name in self.embedding_sizes.keys():
            embedding_size = self.embedding_sizes[name][1]
            if self.max_embedding_size is not None:
                embedding_size = min(embedding_size, self.max_embedding_size)
            # convert to list so the (possibly clipped) size can be written back
            self.embedding_sizes[name] = list(self.embedding_sizes[name])
            self.embedding_sizes[name][1] = embedding_size
            if name in self.categorical_groups:  # embedding bag if related embeddings
                self.embeddings[name] = TimeDistributedEmbeddingBag(
                    self.embedding_sizes[name][0], embedding_size, mode="sum", batch_first=True
                )
            else:
                if name in self.embedding_paddings:
                    padding_idx = 0
                else:
                    padding_idx = None
                self.embeddings[name] = nn.Embedding(
                    self.embedding_sizes[name][0],
                    embedding_size,
                    padding_idx=padding_idx,
                )

    def names(self):
        return list(self.keys())

    def items(self):
        return self.embeddings.items()

    def keys(self):
        return self.embeddings.keys()

    def values(self):
        return self.embeddings.values()

    def __getitem__(self, name: str):
        return self.embeddings[name]

    def forward(self, x):
        input_vectors = {}
        for name, emb in self.embeddings.items():
            if name in self.categorical_groups:
                # grouped variables: gather all member columns and sum their embeddings
                input_vectors[name] = emb(
                    x[
                        ...,
                        [self.x_categoricals.index(cat_name) for cat_name in self.categorical_groups[name]],
                    ]
                )
            else:
                input_vectors[name] = emb(x[..., self.x_categoricals.index(name)])
        return input_vectors
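
# A minimal end-to-end sketch (variable names, cardinalities, and dimensions are
# illustrative assumptions): two standalone categoricals plus one grouped pair
# whose member columns are summed by a shared TimeDistributedEmbeddingBag.
def _demo_multi_embedding():
    import torch

    embeddings = MultiEmbedding(
        embedding_sizes={"weekday": (7, 3), "month": (12, 4), "holidays": (30, 5)},
        categorical_groups={"holidays": ["holiday_a", "holiday_b"]},
        embedding_paddings=["month"],  # index 0 of "month" is treated as padding
        x_categoricals=["weekday", "month", "holiday_a", "holiday_b"],
    )
    x = torch.zeros(8, 20, 4, dtype=torch.long)  # (batch, timesteps, categoricals)
    vectors = embeddings(x)
    assert vectors["weekday"].shape == (8, 20, 3)
    assert vectors["holidays"].shape == (8, 20, 5)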