-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathmodels.py
365 lines (277 loc) · 14.2 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
from embeddings import PositionalEncoding
from utils.pad import pad_masking, subsequent_masking
import torch
from torch import nn
import numpy as np
from collections import defaultdict
# Token id assumed to denote padding. Not referenced by the code shown in this
# module (padding masks come from utils.pad) — presumably exported for other
# modules; confirm before changing.
PAD_TOKEN_ID = 0
def build_model(config, source_vocabulary_size, target_vocabulary_size):
    """Assemble an encoder-decoder ``Transformer`` from a configuration mapping.

    Args:
        config: mapping read for the keys ``'positional_encoding'``,
            ``'d_model'``, ``'layers_count'``, ``'heads_count'``, ``'d_ff'``
            and ``'dropout_prob'``.
        source_vocabulary_size: number of rows in the source embedding table.
        target_vocabulary_size: number of rows in the target embedding table.

    Returns:
        A ``Transformer`` whose encoder and decoder share the same layer
        hyper-parameters. Embeddings are ``PositionalEncoding`` when
        ``config['positional_encoding']`` is truthy, otherwise plain
        ``nn.Embedding``.
    """
    use_positional_encoding = config['positional_encoding']
    d_model = config['d_model']

    def make_embedding(vocabulary_size):
        # One token embedding table of width d_model.
        if use_positional_encoding:
            # NOTE(review): PositionalEncoding receives d_model for both
            # embedding_dim and dim; the reason for two arguments is not
            # visible in this module.
            return PositionalEncoding(
                num_embeddings=vocabulary_size,
                embedding_dim=d_model,
                dim=d_model,
            )
        return nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=d_model)

    # Source first, then target, so parameter initialisation draws from the
    # RNG in a fixed order.
    source_embedding = make_embedding(source_vocabulary_size)
    target_embedding = make_embedding(target_vocabulary_size)

    stack_options = {
        'layers_count': config['layers_count'],
        'd_model': d_model,
        'heads_count': config['heads_count'],
        'd_ff': config['d_ff'],
        'dropout_prob': config['dropout_prob'],
    }
    encoder = TransformerEncoder(embedding=source_embedding, **stack_options)
    decoder = TransformerDecoder(embedding=target_embedding, **stack_options)
    return Transformer(encoder, decoder)
class Transformer(nn.Module):
    """Encoder-decoder wrapper that builds attention masks and returns decoder output."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, sources, inputs):
        """Encode ``sources`` and decode ``inputs`` against the encoded memory.

        Args:
            sources: (batch_size, sources_len) source token indices.
            inputs: (batch_size, targets_len - 1) decoder input token indices.

        Returns:
            The first value returned by ``self.decoder``; the decoder state it
            also returns is discarded.
        """
        _, sources_len = sources.size()
        _, inputs_len = inputs.size()

        # Padding masks come from utils.pad; the decoder self-attention mask is
        # additionally OR-ed with the subsequent-position mask.
        sources_mask = pad_masking(sources, sources_len)
        memory_mask = pad_masking(sources, inputs_len)
        inputs_mask = subsequent_masking(inputs) | pad_masking(inputs, inputs_len)

        memory = self.encoder(sources, sources_mask)
        outputs, _state = self.decoder(inputs, memory, memory_mask, inputs_mask)
        return outputs
class TransformerEncoder(nn.Module):
    """Embed source token ids and pass them through a stack of encoder layers."""

    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
        super().__init__()
        self.d_model = d_model
        self.embedding = embedding
        stack = [
            TransformerEncoderLayer(d_model, heads_count, d_ff, dropout_prob)
            for _ in range(layers_count)
        ]
        self.encoder_layers = nn.ModuleList(stack)

    def forward(self, sources, mask):
        """Run the encoder.

        Args:
            sources: token ids given to ``self.embedding`` (not an already
                embedded sequence); (batch_size, seq_len) when called from
                ``Transformer.forward``.
            mask: passed unchanged to every encoder layer.

        Returns:
            The output of the last encoder layer.
        """
        hidden = self.embedding(sources)
        for layer in self.encoder_layers:
            hidden = layer(hidden, mask)
        return hidden
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: residual self-attention, dropout, residual feed-forward.

    NOTE(review): this layer applies ``self.dropout`` to the output of the
    (already layer-normalised) attention sublayer, which
    ``TransformerDecoderLayer`` does not do — confirm this is intended.
    """

    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super().__init__()
        attention = MultiHeadAttention(heads_count, d_model, dropout_prob)
        self.self_attention_layer = Sublayer(attention, d_model)
        feed_forward = PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob)
        self.pointwise_feedforward_layer = Sublayer(feed_forward, d_model)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, sources, sources_mask):
        """Apply the layer to ``sources`` (presumably (batch, seq_len, d_model))."""
        attended = self.self_attention_layer(sources, sources, sources, sources_mask)
        return self.pointwise_feedforward_layer(self.dropout(attended))
class TransformerDecoder(nn.Module):
    """Embed target ids, run the decoder layer stack and project to vocabulary scores.

    The output projection ``generator`` shares its weight with ``embedding``,
    so ``embedding`` must expose ``embedding_dim``, ``num_embeddings`` and
    ``weight``.
    """

    def __init__(self, layers_count, d_model, heads_count, d_ff, dropout_prob, embedding):
        super().__init__()
        self.d_model = d_model
        self.embedding = embedding
        stack = [
            TransformerDecoderLayer(d_model, heads_count, d_ff, dropout_prob)
            for _ in range(layers_count)
        ]
        self.decoder_layers = nn.ModuleList(stack)
        vocabulary_size = embedding.num_embeddings
        self.generator = nn.Linear(embedding.embedding_dim, vocabulary_size)
        # Weight tying: the output projection reuses the embedding table.
        self.generator.weight = self.embedding.weight

    def forward(self, inputs, memory, memory_mask, inputs_mask=None, state=None):
        """Decode ``inputs`` against the encoder ``memory``.

        Args:
            inputs: token ids given to ``self.embedding`` (not embedded vectors).
            memory: encoder output; presumably (batch_size, src_len, d_model).
            memory_mask: mask for attention over ``memory``.
            inputs_mask: mask for decoder self-attention, or None.
            state: optional ``DecoderState``. When given, each layer receives
                ``state.layer_caches[layer_index]`` and, afterwards, the key/value
                projections of both of its attention sublayers are written back.

        Returns:
            ``(generated, state)`` where ``generated`` is the unnormalised
            ``generator`` output and ``state`` is the argument (possibly updated).
        """
        hidden = self.embedding(inputs)
        for layer_index, decoder_layer in enumerate(self.decoder_layers):
            if state is None:
                hidden = decoder_layer(hidden, memory, memory_mask, inputs_mask)
                continue
            layer_cache = state.layer_caches[layer_index]
            hidden = decoder_layer(hidden, memory, memory_mask, inputs_mask, layer_cache)
            attention_sublayers = (
                ('self-attention', decoder_layer.self_attention_layer),
                ('memory-attention', decoder_layer.memory_attention_layer),
            )
            for layer_mode, wrapped in attention_sublayers:
                attention = wrapped.sublayer
                state.update_state(
                    layer_index=layer_index,
                    layer_mode=layer_mode,
                    key_projected=attention.key_projected,
                    value_projected=attention.value_projected,
                )
        generated = self.generator(hidden)
        return generated, state

    def init_decoder_state(self, **args):
        """Return a fresh, empty ``DecoderState``; keyword arguments are ignored."""
        return DecoderState()
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over encoder memory, feed-forward.

    Each step is wrapped in a residual ``Sublayer``.
    """

    def __init__(self, d_model, heads_count, d_ff, dropout_prob):
        super().__init__()
        self_attention = MultiHeadAttention(
            heads_count, d_model, dropout_prob, mode='self-attention')
        self.self_attention_layer = Sublayer(self_attention, d_model)
        memory_attention = MultiHeadAttention(
            heads_count, d_model, dropout_prob, mode='memory-attention')
        self.memory_attention_layer = Sublayer(memory_attention, d_model)
        feed_forward = PointwiseFeedForwardNetwork(d_ff, d_model, dropout_prob)
        self.pointwise_feedforward_layer = Sublayer(feed_forward, d_model)

    def forward(self, inputs, memory, memory_mask, inputs_mask, layer_cache=None):
        """Apply the layer; ``layer_cache`` is forwarded to both attention sublayers."""
        hidden = self.self_attention_layer(inputs, inputs, inputs, inputs_mask, layer_cache)
        hidden = self.memory_attention_layer(hidden, memory, memory, memory_mask, layer_cache)
        return self.pointwise_feedforward_layer(hidden)
class Sublayer(nn.Module):
    """Residual wrapper computing ``layer_normalization(sublayer(*args) + args[0])``.

    The first positional argument is both forwarded to ``sublayer`` and used as
    the residual, so it must be addable to the wrapped module's output.
    """

    def __init__(self, sublayer, d_model):
        super().__init__()
        self.sublayer = sublayer
        self.layer_normalization = LayerNormalization(d_model)

    def forward(self, *args):
        residual = args[0]
        combined = self.sublayer(*args) + residual
        return self.layer_normalization(combined)
class LayerNormalization(nn.Module):
    """Normalise over the last dimension, then scale by ``gain`` and shift by ``bias``.

    Uses ``Tensor.std`` (unbiased by default) and adds ``epsilon`` to the
    standard deviation rather than to the variance, so results differ slightly
    from ``torch.nn.LayerNorm``.
    """

    def __init__(self, features_count, epsilon=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(features_count))
        self.bias = nn.Parameter(torch.zeros(features_count))
        self.epsilon = epsilon

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        denominator = x.std(dim=-1, keepdim=True) + self.epsilon
        return self.gain * centered / denominator + self.bias
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional projection caching.

    ``mode`` decides how a non-empty ``layer_cache[mode]`` entry is used by
    ``forward``:

    * ``'self-attention'``: the current ``key``/``value`` are projected and
      concatenated along dim 1 in front of the cached projections.
    * ``'memory-attention'``: the cached projections are reused and the
      ``key``/``value`` arguments are not read.

    Side effects of ``forward``: the key/value projections attended over are
    stored on ``self.key_projected`` / ``self.value_projected``, and the
    softmax weights (before dropout) on ``self.attention``.
    """

    def __init__(self, heads_count, d_model, dropout_prob, mode='self-attention'):
        super().__init__()
        assert d_model % heads_count == 0
        assert mode in ('self-attention', 'memory-attention')
        self.d_head = d_model // heads_count
        self.heads_count = heads_count
        self.mode = mode
        projected_size = heads_count * self.d_head  # equals d_model
        # Creation order (query, key, value, final) fixes the order in which
        # parameter initialisation draws from the RNG.
        self.query_projection = nn.Linear(d_model, projected_size)
        self.key_projection = nn.Linear(d_model, projected_size)
        self.value_projection = nn.Linear(d_model, projected_size)
        self.final_projection = nn.Linear(d_model, projected_size)
        self.dropout = nn.Dropout(dropout_prob)
        self.softmax = nn.Softmax(dim=3)
        self.attention = None  # softmax weights from the latest forward call
        # Key/value projections from the latest forward call; the decoder reads
        # them back to refresh its cache.
        self.key_projected = None
        self.value_projected = None

    def forward(self, query, key, value, mask=None, layer_cache=None):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: (batch_size, query_len, model_dim)
            key: (batch_size, key_len, model_dim); not read when a
                memory-attention cache entry is present.
            value: (batch_size, value_len, model_dim); same caveat as ``key``.
            mask: optional (batch_size, query_len, key_len) tensor (or one that
                ``expand_as`` can broadcast); positions where it is True are
                filled with -1e18 before the softmax.
            layer_cache: optional dict keyed by mode, as kept by ``DecoderState``.

        Returns:
            (batch_size, query_len, d_model) output of ``final_projection``.
        """
        batch_size, query_len, d_model = query.size()
        d_head = d_model // self.heads_count

        query_projected = self.query_projection(query)
        key_projected, value_projected = self._project_key_value(key, value, layer_cache)
        self.key_projected = key_projected
        self.value_projected = value_projected

        # batch_size and d_model used below are taken from the value projection.
        _, key_len, _ = key_projected.size()
        batch_size, value_len, d_model = value_projected.size()

        query_heads = self._split_heads(query_projected, batch_size, query_len, d_head)
        key_heads = self._split_heads(key_projected, batch_size, key_len, d_head)
        value_heads = self._split_heads(value_projected, batch_size, value_len, d_head)

        # (batch_size, heads_count, query_len, key_len)
        scores = self.scaled_dot_product(query_heads, key_heads)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1).expand_as(scores), -1e18)
        self.attention = self.softmax(scores)

        # (batch_size, heads_count, query_len, d_head)
        context_heads = torch.matmul(self.dropout(self.attention), value_heads)
        merged = context_heads.transpose(1, 2).contiguous()
        context = merged.view(batch_size, query_len, d_model)
        return self.final_projection(context)

    def _split_heads(self, projected, batch_size, length, d_head):
        """Reshape (batch, length, heads * d_head) into (batch, heads, length, d_head)."""
        return projected.view(batch_size, length, self.heads_count, d_head).transpose(1, 2)

    def _project_key_value(self, key, value, layer_cache):
        """Return the key/value projections to attend over, honouring ``layer_cache``."""
        cached = None if layer_cache is None else layer_cache[self.mode]
        if cached is None:
            return self.key_projection(key), self.value_projection(value)
        if self.mode == 'memory-attention':
            return cached['key_projected'], cached['value_projected']
        # NOTE(review): the new step is placed *before* the cached projections,
        # so the cache grows newest-first. Attention is order-independent, but a
        # mask passed together with a cache must follow the same ordering —
        # confirm with the caller.
        fresh_keys = self.key_projection(key)
        fresh_values = self.value_projection(value)
        return (
            torch.cat([fresh_keys, cached['key_projected']], dim=1),
            torch.cat([fresh_values, cached['value_projected']], dim=1),
        )

    def scaled_dot_product(self, query_heads, key_heads):
        """Return ``query_heads @ key_heads^T / sqrt(d_head)``.

        Args:
            query_heads: (batch_size, heads_count, query_len, d_head)
            key_heads: (batch_size, heads_count, key_len, d_head)

        Returns:
            (batch_size, heads_count, query_len, key_len) attention scores.
        """
        scores = torch.matmul(query_heads, key_heads.transpose(2, 3))
        return scores / np.sqrt(self.d_head)
class PointwiseFeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: Linear -> Dropout -> ReLU -> Linear -> Dropout.

    The layers live in ``self.feed_forward`` at indices 0-4; their order also
    determines the parameter names in the module's state dict.
    """

    def __init__(self, d_ff, d_model, dropout_prob):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_ff),     # expand to d_ff
            nn.Dropout(dropout_prob),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),     # project back to d_model
            nn.Dropout(dropout_prob),
        ]
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch_size, seq_len, d_model)."""
        return self.feed_forward(x)
class DecoderState:
    """Per-layer cache of attention key/value projections for incremental decoding.

    ``layer_caches[layer_index]`` is a dict with the keys ``'self-attention'``
    and ``'memory-attention'``. Each value is ``None`` until ``update_state``
    stores a ``{'key_projected': ..., 'value_projected': ...}`` dict for it.
    """

    def __init__(self):
        # Not read by the active code in this module.
        self.previous_inputs = torch.tensor([])
        self.layer_caches = defaultdict(
            lambda: dict.fromkeys(('self-attention', 'memory-attention'))
        )

    def update_state(self, layer_index, layer_mode, key_projected, value_projected):
        """Replace the cached projections stored for ``layer_mode`` of ``layer_index``."""
        layer_cache = self.layer_caches[layer_index]
        layer_cache[layer_mode] = {
            'key_projected': key_projected,
            'value_projected': value_projected,
        }

    # Design note: a `repeat_beam_size_times(beam_size)` helper was considered.
    # Only the encoder memory would need repeating, and it was decided not to
    # keep the memory in this state object.

    def beam_update(self, positions):
        """Reorder every cached projection along dim 0 by ``positions``, in place.

        The rows picked by ``index_select(0, positions)`` are copied back into
        the existing tensors, so the selection must be broadcastable to each
        cache's shape (in practice ``len(positions)`` equals its dim-0 size).
        ``None`` entries are skipped.
        """
        for layer_cache in self.layer_caches.values():
            for mode in ('self-attention', 'memory-attention'):
                entry = layer_cache[mode]
                if entry is None:
                    continue
                for cache in entry.values():
                    if cache is None:
                        continue
                    cache.data.copy_(cache.data.index_select(0, positions))