layers.py
import tensorflow as tf

class MaxOverTimePoolLayer(tf.keras.layers.Layer):
    """Max-over-time pooling: keep the per-channel maximum over the time axis."""

    def __init__(self):
        super(MaxOverTimePoolLayer, self).__init__()

    def call(self, input_tensor):
        # (batch, time, channels) -> (batch, channels)
        return tf.math.reduce_max(input_tensor, axis=1)
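
# Shape sketch (an illustration with assumed sizes, not part of the original module):
# for a Conv1D output x of shape (num_words, conv_steps, num_filters),
# MaxOverTimePoolLayer()(x) returns a (num_words, num_filters) tensor.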

class CharacterEmbedding(tf.keras.layers.Layer):
    """Character-level word embedding: embed characters, run a 1-D convolution over
    the character positions, then max-over-time pool to one vector per word."""

    def __init__(self, vocab_size, conv_filters, filter_width=5, emb_dim=None):
        super(CharacterEmbedding, self).__init__()
        self.vocab_size = vocab_size + 1  # one extra embedding row (e.g. for padding/unknown characters)
        # character embedding size: half the (padded) vocabulary size unless given explicitly
        self.out_emb_dim = emb_dim if emb_dim is not None else (self.vocab_size - (self.vocab_size // 2))
        self.conv_filters = conv_filters
        self.filter_width = filter_width
        self.dropout_conv = tf.keras.layers.Dropout(.2)
        self.emb = tf.keras.layers.Embedding(self.vocab_size, self.out_emb_dim)
        self.conv = tf.keras.layers.Conv1D(self.conv_filters, self.filter_width,
                                           activation='relu', padding='valid')
        self.max_pool = MaxOverTimePoolLayer()

    def call(self, x, training=False):
        x = self.emb(x)                               # (words, chars) -> (words, chars, emb_dim)
        x = self.dropout_conv(x, training=training)
        x = self.conv(x)                              # (words, chars - filter_width + 1, conv_filters)
        x = self.max_pool(x)                          # (words, conv_filters)
        return x
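
# Usage sketch (hypothetical sizes, not from the original code): char_ids is an
# integer tensor of per-word character indices.
#   char_emb = CharacterEmbedding(vocab_size=70, conv_filters=100, filter_width=5)
#   char_ids = tf.random.uniform([32, 16], maxval=70, dtype=tf.int32)  # 32 words, 16 chars each
#   vecs = char_emb(char_ids, training=False)                          # -> (32, 100)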

class HighwayLayer(tf.keras.layers.Layer):
    """Highway layer: y = H(x) * T(x) + x * (1 - T(x)), where H is a ReLU transform
    and T is a sigmoid transform gate (1 - T acts as the carry gate)."""

    def __init__(self):
        super(HighwayLayer, self).__init__()

    def build(self, input_shape):
        dim = input_shape[-1]
        self.Wh = self.add_weight(name='Wh', shape=(dim, dim), trainable=True)
        self.Wt = self.add_weight(name='Wt', shape=(dim, dim), trainable=True)
        self.bh = self.add_weight(name='bh', shape=(dim,), trainable=True)
        self.bt = self.add_weight(name='bt', shape=(dim,), trainable=True)
        super(HighwayLayer, self).build(input_shape)

    def call(self, x):
        H = tf.nn.relu(tf.matmul(x, self.Wh) + self.bh, name='activation')      # transform
        T = tf.sigmoid(tf.matmul(x, self.Wt) + self.bt, name='transform_gate')  # gate
        y = tf.add(tf.multiply(H, T), tf.multiply(x, 1 - T), name='y')          # y = H*T + x*(1-T)
        return y
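
# Usage sketch (hypothetical sizes): the layer preserves its input shape.
#   hw = HighwayLayer()
#   x = tf.random.normal([20, 100])   # e.g. 20 tokens with 100-dim features
#   y = hw(x)                         # -> (20, 100)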

class AttentionLayer(tf.keras.layers.Layer):
    """Attention flow between a context matrix H (T x 2d) and a query matrix U (J x 2d),
    combining context-to-query (C2Q) and query-to-context (Q2C) attention into the
    query-aware context matrix G."""

    def __init__(self):
        super(AttentionLayer, self).__init__()

    def build(self, input_shape):
        # similarity weights: alpha(h, u) = Ws^T [h; u; h ∘ u], so Ws has shape (3 * 2d, 1)
        self.Ws = self.add_weight(name='Ws', shape=(input_shape[-1] * 3, 1), trainable=True)
        super(AttentionLayer, self).build(input_shape)

    def computeSimilarity(self, H, U):
        """
        Create similarity matrix S between context (H) and query (U).
        :param H: context matrix (T x 2d)
        :param U: query matrix (J x 2d)
        :return: S: similarity matrix (T x J)
        """
        # Build every (h_t, u_j) pair: repeat each context row J times and tile the query T times.
        duplicateH = tf.keras.backend.repeat_elements(H, rep=U.shape[-2], axis=0)  # (T*J, 2d)
        duplicateU = tf.tile(U, [H.shape[-2], 1])                                  # (T*J, 2d)
        C = tf.concat([duplicateH, duplicateU, tf.multiply(duplicateH, duplicateU)], axis=-1)
        # The same weight vector Ws (the alpha function) multiplies every row of C.
        S = tf.matmul(C, self.Ws)
        S = tf.reshape(S, [H.shape[-2], U.shape[-2]])
        return S

    def computeContext2QueryAttention(self, S, U):
        """
        Create C2Q attention matrix: which query words are most relevant to each context word.
        :param S: similarity matrix (T x J)
        :param U: query matrix (J x 2d)
        :return: attended_query: C2Q matrix (Ũ) (T x 2d)
        """
        C2Q = tf.nn.softmax(S, axis=-1)  # attention weights over the query words
        attended_query = tf.matmul(C2Q, U)
        return attended_query

    def computeQuery2ContextAttention(self, S, H):
        """
        Create Q2C attention matrix: which context words have the closest similarity to one of the
        query words and are therefore critical for answering the query.
        :param S: similarity matrix (T x J)
        :param H: context matrix (T x 2d)
        :return: attended_context: Q2C matrix (H̃) (T x 2d)
        """
        Q2C = tf.nn.softmax(tf.reduce_max(S, axis=-1))                   # (T,) weights over context words
        Q2C = tf.expand_dims(Q2C, -1)                                    # (T, 1)
        attended_context = tf.matmul(Q2C, H, transpose_a=True)          # (1, 2d)
        attended_context = tf.tile(attended_context, [H.shape[-2], 1])  # (T, 2d)
        return attended_context

    def merge(self, H, attended_query, attended_context):
        """
        Combine the information obtained by the C2Q and Q2C attentions.
        Each row of G can be considered as the query-aware representation of one context word.
        :param H: context matrix (T x 2d)
        :param attended_query: C2Q matrix (Ũ) (T x 2d)
        :param attended_context: Q2C matrix (H̃) (T x 2d), or None for the Q2C ablation
        :return: G: matrix (T x 8d), or (T x 6d) in the Q2C ablation
        """
        if attended_context is not None:
            G = tf.concat([H, attended_query, tf.multiply(H, attended_query),
                           tf.multiply(H, attended_context)], axis=-1)
        else:
            # Q2C ablation: note that OutputLayer's weights assume the full (T x 8d) G.
            G = tf.concat([H, attended_query, tf.multiply(H, attended_query)], axis=-1)
        return G

    def call(self, H, U, q2c_attention, c2q_attention):
        # Similarity matrix S has dimension T x J
        S = self.computeSimilarity(H, U)
        # C2Q attention
        if c2q_attention:
            C2Q = self.computeContext2QueryAttention(S, U)
        else:
            # C2Q ablation: as a simple placeholder, use the tiled mean of the query rows for Ũ
            C2Q = tf.tile(tf.reduce_mean(U, axis=-2, keepdims=True), [H.shape[-2], 1])
        # Q2C attention
        if q2c_attention:
            Q2C = self.computeQuery2ContextAttention(S, H)
        else:
            Q2C = None  # Q2C ablation
        # Merge C2Q (Ũ) and Q2C (H̃) to obtain G
        G = self.merge(H, C2Q, Q2C)
        return G
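
# Usage sketch (hypothetical sizes: T=20 context tokens, J=6 query tokens, d=50):
#   att = AttentionLayer()
#   H = tf.random.normal([20, 2 * 50])   # context encoding, (T x 2d)
#   U = tf.random.normal([6, 2 * 50])    # query encoding,   (J x 2d)
#   G = att(H, U, q2c_attention=True, c2q_attention=True)   # -> (20, 8 * 50)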

class OutputLayer(tf.keras.layers.Layer):
    """Predict the start and end probability distributions over context positions
    from G (T x 8d) and the modelling-layer output M (T x 2d)."""

    def __init__(self, lstm_units, dropout):
        super(OutputLayer, self).__init__()
        self.bi_lstm = tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(lstm_units, return_sequences=True, dropout=dropout),
            merge_mode='concat')
        self.dropout_out1 = tf.keras.layers.Dropout(.2)
        self.dropout_out2 = tf.keras.layers.Dropout(.2)

    def build(self, input_shape):
        # M has shape (T x 2d), so [G; M] has 8d + 2d = 10d features -> W shape (10d, 1)
        self.W_start = self.add_weight(name='W_start', shape=(input_shape[-1] * 5, 1), trainable=True)
        self.W_end = self.add_weight(name='W_end', shape=(input_shape[-1] * 5, 1), trainable=True)
        super(OutputLayer, self).build(input_shape)

    def call(self, M, G, training=False):
        # dropout before the pointer projections
        G = self.dropout_out1(G, training=training)
        M = self.dropout_out2(M, training=training)
        p_start = tf.matmul(tf.concat([G, M], axis=-1), self.W_start)
        p_start = tf.nn.softmax(p_start, axis=-2)
        # run a second BiLSTM pass over M before predicting the end index
        M = self.bi_lstm(tf.expand_dims(M, 0), training=training)
        M = tf.squeeze(M, [0])  # drop the singleton batch dimension -> (T, 2d)
        p_end = tf.matmul(tf.concat([G, M], axis=-1), self.W_end)
        p_end = tf.nn.softmax(p_end, axis=-2)
        return p_start, p_end
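
if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch with assumed sizes, not part of the
    # training pipeline): T=20 context tokens, J=6 query tokens, hidden size d=50.
    T, J, d = 20, 6, 50
    H = tf.random.normal([T, 2 * d])   # context encoding (T x 2d)
    U = tf.random.normal([J, 2 * d])   # query encoding   (J x 2d)
    G = AttentionLayer()(H, U, q2c_attention=True, c2q_attention=True)   # (T, 8d)
    # stand-in for the modelling layer: one BiLSTM over G, giving M of shape (T, 2d)
    M = tf.squeeze(
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(d, return_sequences=True))(
            tf.expand_dims(G, 0)), [0])
    p_start, p_end = OutputLayer(lstm_units=d, dropout=0.2)(M, G)
    print(G.shape, M.shape, p_start.shape, p_end.shape)   # (20, 400) (20, 100) (20, 1) (20, 1)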