-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path rcnn.py
135 lines (96 loc) · 4.71 KB
/
rcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import tensorflow as tf
from collections import OrderedDict
from arenets.context.architectures.base.fc_single import FullyConnectedLayer
from arenets.context.configurations.rcnn import RCNNConfig
from arenets.sample import InputSample
from arenets.arekit.common.data_type import DataType
from arenets.tf_helpers import sequence
class RCNN(FullyConnectedLayer):
    """
    Copyright (c) Joohong Lee
    Title: Recurrent Convolutional Neural Networks for Text Classification
    Paper: https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9745
    Source: https://github.com/roomylee/rcnn-text-classification

    Context encoder: a bi-directional RNN is run over the embedded terms;
    its forward/backward outputs, shifted by one time step, serve as
    left/right contexts which are concatenated with each term embedding,
    projected through a tanh layer (W_text/b_text) and max-pooled over the
    time axis.  The fully-connected classification head comes from the
    FullyConnectedLayer base class.
    """
    # Keys under which this network's extra trainable parameters are
    # registered in the hidden-parameters dictionary (see
    # init_body_dependent_hidden_states / iter_hidden_parameters).
    H_W_text = "W_text"  # projection weights, shape [text_emb_size, HiddenSize]
    H_b_text = "b_text"  # projection bias, shape [HiddenSize]

    def __init__(self):
        super(RCNN, self).__init__()
        # Extra trainable parameters of this network beyond those of the
        # base layer; OrderedDict keeps a stable iteration order.
        self.__hidden = OrderedDict()
        # Dropout keep-probability input for the RNN cells; normally a
        # placeholder created in init_input(), but may be overridden via
        # set_input_rnn_keep_prob().
        self.__dropout_rnn_keep_prob = None

    # region properties

    @property
    def ContextEmbeddingSize(self):
        # Width of the context embedding produced by
        # init_context_embedding(): the output dim of the W_text projection.
        return self.Config.HiddenSize

    # endregion

    # region public 'set' methods

    def set_input_rnn_keep_prob(self, value):
        # Replace the dropout keep-prob input (e.g. with a constant tensor)
        # instead of feeding the placeholder created in init_input().
        self.__dropout_rnn_keep_prob = value

    # endregion

    # region public 'init' methods

    def init_input(self):
        """ Creates the network inputs: everything the base layer declares,
            plus a float32 placeholder for the RNN dropout keep probability
            (fed per data-type in create_feed_dict()).
        """
        super(RCNN, self).init_input()
        self.__dropout_rnn_keep_prob = tf.compat.v1.placeholder(dtype=tf.float32,
                                                                name="ctx_dropout_rnn_keep_prob")

    def modify_rnn_outputs_optional(self, output_fw, output_bw):
        """ Extension point for subclasses: post-process the bi-RNN outputs
            before they become left/right contexts.  Base implementation is
            the identity.
        """
        # Nothing modifies
        return output_fw, output_bw

    def init_context_embedding(self, embedded_terms):
        """ Builds the RCNN text representation.

            embedded_terms: term-embedding tensor; assumed shape
                [batch, time, emb] — TODO confirm against caller.
            Returns: [batch, HiddenSize] tensor, max-pooled over time.
        """
        assert(isinstance(self.Config, RCNNConfig))
        # Actual (unpadded) sequence lengths, derived from the term indices.
        text_length = sequence.calculate_sequence_length(self.get_input_parameter(InputSample.I_X_INDS))
        with tf.name_scope("bi-rnn"):
            # One cell per direction, each of the "one side context" size.
            fw_cell = sequence.get_cell(hidden_size=self.Config.SurroundingOneSideContextEmbeddingSize,
                                        cell_type=self.Config.CellType,
                                        dropout_rnn_keep_prob=self.__dropout_rnn_keep_prob)
            bw_cell = sequence.get_cell(hidden_size=self.Config.SurroundingOneSideContextEmbeddingSize,
                                        cell_type=self.Config.CellType,
                                        dropout_rnn_keep_prob=self.__dropout_rnn_keep_prob)
            (output_fw, output_bw), states = sequence.bidirectional_rnn(
                cell_fw=fw_cell,
                cell_bw=bw_cell,
                inputs=embedded_terms,
                sequence_length=text_length,
                dtype=tf.float32)
            # Subclass hook; identity in this class.
            output_fw, output_bw = self.modify_rnn_outputs_optional(output_fw, output_bw)
        with tf.name_scope("ctx"):
            # Shape of a single zero time step: [batch, 1, rnn_out_dim].
            shape = [tf.shape(output_fw)[0], 1, tf.shape(output_fw)[2]]
            # Left context of term t is the forward output at t-1 (zeros for
            # t=0); right context is the backward output at t+1 (zeros at end).
            c_left = tf.concat([tf.zeros(shape), output_fw[:, :-1]], axis=1, name="context_left")
            c_right = tf.concat([output_bw[:, 1:], tf.zeros(shape)], axis=1, name="context_right")
        with tf.name_scope("word-representation"):
            # Per-term representation: [left ctx; term embedding; right ctx],
            # concatenated along the feature axis.
            merged = tf.concat([c_left, embedded_terms, c_right], axis=2, name="merged")
        with tf.name_scope("text-representation"):
            # Batched affine projection + tanh: einsum applies W_text to each
            # time step of `merged`.
            y2 = tf.tanh(tf.einsum('aij,jk->aik', merged, self.__hidden[self.H_W_text]) + self.__hidden[self.H_b_text])
        with tf.name_scope("max-pooling"):
            # Element-wise max over the time axis -> [batch, HiddenSize].
            y3 = tf.reduce_max(y2, axis=1)
        return y3

    def init_body_dependent_hidden_states(self):
        """ Registers the trainable projection parameters (W_text, b_text)
            used in init_context_embedding(), with regularizer/initializer
            taken from the configuration.
        """
        assert(isinstance(self.Config, RCNNConfig))
        self.__hidden[self.H_W_text] = tf.compat.v1.get_variable(
            name=self.H_W_text,
            shape=[self.__text_embedding_size(), self.Config.HiddenSize],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.WeightInitializer)
        self.__hidden[self.H_b_text] = tf.compat.v1.get_variable(
            name=self.H_b_text,
            shape=[self.Config.HiddenSize],
            regularizer=self.Config.BiasInitializer if False else self.Config.HiddenSize) if False else tf.compat.v1.get_variable(
            name=self.H_b_text,
            shape=[self.Config.HiddenSize],
            regularizer=self.Config.LayerRegularizer,
            initializer=self.Config.BiasInitializer)

    # endregion

    # region public 'iter' methods

    def iter_hidden_parameters(self):
        """ Yields (key, tensor) pairs: the base layer's hidden parameters
            followed by this network's own (W_text, b_text).
        """
        for key, value in super(RCNN, self).iter_hidden_parameters():
            yield key, value
        for key, value in self.__hidden.items():
            yield key, value

    # endregion

    # region public 'create' methods

    def create_feed_dict(self, input, data_type):
        """ Extends the base feed dict with the RNN dropout keep-prob:
            the configured value when training, 1.0 (no dropout) otherwise.
        """
        feed_dict = super(RCNN, self).create_feed_dict(input=input, data_type=data_type)
        feed_dict[self.__dropout_rnn_keep_prob] = self.Config.DropoutRNNKeepProb if data_type == DataType.Train else 1.0
        return feed_dict

    # endregion

    # region private methods

    def __text_embedding_size(self):
        # Input width of the W_text projection: term embedding plus the
        # left and right RNN contexts (one "one side" size each).
        return self.TermEmbeddingSize + \
               2 * self.Config.SurroundingOneSideContextEmbeddingSize

    # endregion