forked from dabsdamoon/tacotron_tensorflow2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathattention.py
298 lines (212 loc) · 11.6 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from functools import partial
import math
import numpy as np
import tensorflow as tf
##### Define class to get alignments and attention
class BahdanauAttention(tf.keras.layers.Layer):
    """
    Additive (Bahdanau) attention layer whose call() returns alignments instead of the context vector.

    Heavily inspired by two sources:
        https://www.tensorflow.org/tutorials/text/nmt_with_attention
        https://github.com/tensorflow/addons/tree/v0.7.1/tensorflow_addons/seq2seq
    Note that in nmt_with_attention, the return value of the call function is directly the attention
    (context_vector). This differs from the attention classes in tensorflow_addons, which first return
    the alignments and then compute the attention outside of the class.
    Based on Kyubyong's and the tensorflow_addons code, the attention at timestep t (a_t) appears to be
    computed using the attention at timestep t-1 (a_t-1), so the attention class from
    nmt_with_attention was changed accordingly: call() yields the alignments and
    _compute_attention() turns them into the context vector.
    """

    def __init__(self,
                 units,
                 normalize=False,
                 **kwargs):
        """
        Args:
            units: Number of units in dense layers in attention.
            normalize: Whether to normalize the score. If true, additional weights are used to compute normalized scores.
            **kwargs: Dictionary that contains other common arguments for layer creation.
        """
        super(BahdanauAttention, self).__init__(**kwargs)
        self.units = units
        self.normalize = normalize
        # Bias-free projections of the memory (values) and of the query.
        self.W1 = tf.keras.layers.Dense(self.units, name="W_value", use_bias=False)
        self.W2 = tf.keras.layers.Dense(self.units, name="W_query", use_bias=False)

    def build(self, input_shape):
        """
        Create the score weights.

        Args:
            input_shape: Shape handed over by Keras; only forwarded to the parent build().
        """
        super(BahdanauAttention, self).build(input_shape)
        # Keyword arguments are used on purpose: the positional order of add_weight
        # differs between Keras versions (name-first vs. shape-first).
        self.attention_v = self.add_weight(name="attention_v",
                                           shape=[self.units],
                                           dtype=self.dtype,
                                           initializer="glorot_uniform")
        if self.normalize:
            # Scalar gain, initialised to sqrt(1 / units).
            self.attention_g = self.add_weight(name="attention_g",
                                               shape=(),
                                               dtype=self.dtype,
                                               initializer=tf.constant_initializer(
                                                   math.sqrt((1. / self.units))))
            # Bias added inside the tanh, initialised to zeros.
            self.attention_b = self.add_weight(name="attention_b",
                                               shape=[self.units],
                                               dtype=self.dtype,
                                               initializer=tf.zeros_initializer())

    def call(self,
             query,
             values,
             memory_mask=None):
        """
        Args:
            query: Input tensor to compute attention at current timestep. In this code, it's cell output values generated by final output tensor of previous time and previous attention tensor.
            values: Memory value computed before. Often, it's the final full-sequence (return_sequences = True) output value.
                Presumably shaped (batch, memory_time, depth) -- TODO confirm against the caller.
            memory_mask: Mask value for memory. Only applied if mask tensor is given.
        Returns:
            alignments: Probability value indicating in which memory timestep it's paying attention to (softmax over axis 1 of the score).
            next_state: Next state value. Same as alignments in this case.
        """
        # Insert an axis at position 1 so the projected query broadcasts over the memory steps.
        hidden_with_time_axis = tf.expand_dims(query, 1)
        projected_sum = self.W1(values) + self.W2(hidden_with_time_axis)

        if self.normalize:
            # Rescale attention_v to norm attention_g (weight-normalised score).
            normed_v = self.attention_g * self.attention_v * tf.math.rsqrt(
                tf.reduce_sum(tf.square(self.attention_v)))
            score = tf.reduce_sum(normed_v * tf.tanh(projected_sum + self.attention_b), [2])
        else:
            score = tf.reduce_sum(self.attention_v * tf.tanh(projected_sum), [2])

        ##### Masking process
        if memory_mask is not None:
            score = self._maybe_mask_score(score, memory_mask)

        alignments = tf.nn.softmax(score, axis=1)  # alignments
        # This is part of _calculate_attention function of attention mechanisms in tfa.seq2seq.attention_wrapper
        next_state = alignments
        return alignments, next_state

    ##### Function to mask score
    def _maybe_mask_score(self, score, memory_mask):
        """
        Args:
            score: Score to be masked.
            memory_mask: Mask tensor. Positions that evaluate to True keep their score.
        Returns:
            outputs: Output tensor with mask applied; masked positions hold the minimum value of the score dtype.
        """
        # tf.where needs a boolean condition; the cast is a no-op for bool masks
        # and lets 0/1 numeric masks work as well.
        keep = tf.cast(memory_mask, tf.bool)
        score_mask_value = score.dtype.min
        score_mask_values = score_mask_value * tf.ones_like(score)
        outputs = tf.where(keep, score, score_mask_values)
        return outputs

    ##### Function to compute attention
    def _compute_attention(self, alignments, values):
        """
        Args:
            alignments: Alignment value computed at call() function.
            values: Memory value computed before. Often, it's the final full-sequence (return_sequences = True) output value.
        Returns:
            attention: Attention value computed (alignment-weighted sum of values).
            alignments: Alignments given
        """
        # Add an axis at position 1, batch-matmul with the memory, then drop that axis again.
        expanded_alignments = tf.expand_dims(alignments, 1)
        context_vector = tf.matmul(expanded_alignments, values)
        context_vector = tf.squeeze(context_vector, [1])
        # No specific attention layer given
        attention = context_vector
        return attention, alignments
##### Define class to get Bahdanau Monotonic alignments and attention
class BahdanauMonotonicAttention(tf.keras.layers.Layer):
    """
    Bahdanau-style attention whose alignments come from the monotonic attention
    probability function (sigmoid "choosing" probabilities combined with the previous
    alignments) instead of a softmax.

    The probability code is adapted from tensorflow_addons' seq2seq attention_wrapper;
    only the "parallel" mode is implemented here.
    """

    def __init__(self,
                 units,
                 sigmoid_noise=0.0,
                 normalize=False,
                 **kwargs):
        """
        Args:
            units: Number of units in dense layers in attention.
            sigmoid_noise: Scale of the standard-normal noise added to the scores before the sigmoid. No noise is added unless it is > 0.
            normalize: Whether to normalize the score. If true, additional weights are used to compute normalized scores.
            **kwargs: Dictionary that contains other common arguments for layer creation.
        """
        # Initialise the Keras base class before assigning attributes or sub-layers.
        super(BahdanauMonotonicAttention, self).__init__(**kwargs)
        self.units = units
        self.sigmoid_noise = sigmoid_noise
        self.normalize = normalize
        # Bias-free projections of the memory (values) and of the query.
        self.W1 = tf.keras.layers.Dense(self.units, name="W_value", use_bias=False)
        self.W2 = tf.keras.layers.Dense(self.units, name="W_query", use_bias=False)
        self.probability_fn = partial(self._monotonic_probability_fn,
                                      sigmoid_noise=self.sigmoid_noise)

    def build(self, input_shape):
        """
        Create the score weights.

        Args:
            input_shape: Shape handed over by Keras; only forwarded to the parent build().
        """
        # Keyword arguments are used on purpose: the positional order of add_weight
        # differs between Keras versions (name-first vs. shape-first).
        # Scalar bias added to every score, initialised to 0.
        self.attention_score_bias = self.add_weight(name="attention_score_bias",
                                                    shape=(),
                                                    dtype=self.dtype,
                                                    initializer=tf.constant_initializer(0.0))
        self.attention_v = self.add_weight(name="attention_v",
                                           shape=[self.units],
                                           dtype=self.dtype,
                                           initializer="glorot_uniform")
        if self.normalize:
            # Scalar gain, initialised to sqrt(1 / units).
            self.attention_g = self.add_weight(name="attention_g",
                                               shape=(),
                                               dtype=self.dtype,
                                               initializer=tf.constant_initializer(
                                                   math.sqrt((1. / self.units))))
            # Bias added inside the tanh, initialised to zeros.
            self.attention_b = self.add_weight(name="attention_b",
                                               shape=[self.units],
                                               dtype=self.dtype,
                                               initializer=tf.zeros_initializer())
        super(BahdanauMonotonicAttention, self).build(input_shape)

    def call(self,
             query,
             values,
             previous_alignments,
             memory_mask=None):
        """
        Args:
            query: Input tensor to compute attention at current timestep.
            values: Memory value computed before. Presumably shaped (batch, memory_time, depth) -- TODO confirm against the caller.
            previous_alignments: Alignments produced at the previous timestep; same layout as the score.
            memory_mask: Mask value for memory. Only applied if mask tensor is given.
        Returns:
            alignments: Monotonic attention alignments for the current timestep.
            next_attention_state: Next state value. Same as alignments in this case.
        """
        # Insert an axis at position 1 so the projected query broadcasts over the memory steps.
        hidden_with_time_axis = tf.expand_dims(query, 1)  # Processed query
        projected_sum = self.W1(values) + self.W2(hidden_with_time_axis)

        if self.normalize:
            # Rescale attention_v to norm attention_g (weight-normalised score).
            normed_v = self.attention_g * self.attention_v * tf.math.rsqrt(
                tf.reduce_sum(tf.square(self.attention_v)))
            score = tf.reduce_sum(normed_v * tf.tanh(projected_sum + self.attention_b), [2])
        else:
            score = tf.reduce_sum(self.attention_v * tf.tanh(projected_sum), [2])

        ##### Masking process
        if memory_mask is not None:
            score = self._maybe_mask_score(score, memory_mask)

        # The learned scalar bias is added after masking.
        score += self.attention_score_bias
        alignments = self.probability_fn(score, previous_alignments)
        next_attention_state = alignments
        return alignments, next_attention_state

    ##### Define function to mask score
    def _maybe_mask_score(self, score, memory_mask):
        """
        Args:
            score: Score to be masked.
            memory_mask: Mask tensor. Positions that evaluate to True keep their score.
        Returns:
            Score tensor in which masked positions hold the minimum value of the score dtype.
        """
        # tf.where needs a boolean condition; the cast is a no-op for bool masks
        # and lets 0/1 numeric masks work as well.
        keep = tf.cast(memory_mask, tf.bool)
        score_mask_value = score.dtype.min
        score_mask_values = score_mask_value * tf.ones_like(score)
        return tf.where(keep, score, score_mask_values)

    ##### Define function to get monotonic probability
    def _monotonic_probability_fn(self,
                                  score,
                                  previous_alignments,
                                  sigmoid_noise):
        """
        Attention probability function for monotonic attention obtained from
        (https://github.com/tensorflow/addons/blob/1af92905ed03f05fcf6f4918783c3d151a8b8350/tensorflow_addons/seq2seq/attention_wrapper.py#L810)

        Args:
            score: Unnormalized attention scores.
            previous_alignments: Alignments from the previous timestep.
            sigmoid_noise: Scale of the standard-normal noise added to the scores; skipped unless > 0.
        Returns:
            Alignments computed by _monotonic_attention from the sigmoid of the (noisy) scores.
        """
        # Optionally add pre-sigmoid noise to the scores
        if sigmoid_noise > 0:
            noise = tf.random.normal(tf.shape(score), dtype=score.dtype)
            score += sigmoid_noise * noise
        # Compute "choosing" probabilities from the attention scores
        p_choose_i = tf.sigmoid(score)  # probability of selecting i
        return self._monotonic_attention(p_choose_i, previous_alignments)

    ##### Define function to get monotonic attention
    def _monotonic_attention(self, p_choose_i, previous_alignments):
        """
        Function for computing monotonic attention brought from tensorflow.addon github
        (https://github.com/tensorflow/addons/blob/1af92905ed03f05fcf6f4918783c3d151a8b8350/tensorflow_addons/seq2seq/attention_wrapper.py#L927).
        In this case, we're only dealing with the "parallel" mode.

        Args:
            p_choose_i: Probability of choosing each memory position at this timestep.
            previous_alignments: Alignments from the previous timestep.
        Returns:
            attention: Alignments for the current timestep.
        """
        # Force things to be tensors
        p_choose_i = tf.convert_to_tensor(p_choose_i, name="p_choose_i")
        previous_alignments = tf.convert_to_tensor(
            previous_alignments, name="previous_attention")
        # safe_cumprod computes cumprod in logspace with numeric checks
        cumprod_1mp_choose_i = self._safe_cumprod(
            1 - p_choose_i, axis=1, exclusive=True)
        # Clip cumprod_1mp to avoid divide-by-zero
        clipped_cumprod = tf.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.)
        # Compute recurrence relation solution
        attention = p_choose_i * cumprod_1mp_choose_i * tf.cumsum(
            previous_alignments / clipped_cumprod, axis=1)
        return attention

    ##### Define function to be used for monotonic attention
    def _safe_cumprod(self, x, *args, **kwargs):
        """
        Cumulative product computed as exp(cumsum(log(x))).

        Args:
            x: Tensor to take the cumulative product of; values are clipped to [tiny, 1] before the log.
            *args, **kwargs: Forwarded unchanged to tf.cumsum (e.g. axis, exclusive).
        Returns:
            Cumulative product of the clipped x.
        """
        x = tf.convert_to_tensor(x, name="x")
        # Smallest positive normal number of the dtype, so log() never receives 0.
        tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
        return tf.exp(
            tf.cumsum(
                tf.math.log(tf.clip_by_value(x, tiny, 1)), *args, **kwargs))

    def _compute_attention(self, alignments, values):
        """
        Args:
            alignments: Alignment value computed at call() function.
            values: Memory value computed before.
        Returns:
            attention: Attention value computed (alignment-weighted sum of values).
            alignments: Alignments given
        """
        # Add an axis at position 1, batch-matmul with the memory, then drop that axis again.
        expanded_alignments = tf.expand_dims(alignments, 1)
        context_vector = tf.matmul(expanded_alignments, values)
        context_vector = tf.squeeze(context_vector, [1])
        # No specific attention layer given
        attention = context_vector
        return attention, alignments