Commit

Add masked LSTM support
Signed-off-by: Yu Cong <congyc@amazon.com>
Yu Cong committed Aug 27, 2022
1 parent bc677a1 commit d4e16b5
Showing 3 changed files with 260 additions and 25 deletions.
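For context, a minimal usage sketch of what this change enables (not part of the commit; model shape, seeds, and opset are illustrative assumptions). A Keras LSTM fed by a mask-producing layer such as Embedding(mask_zero=True) can now be converted with tf2onnx, with the caveat noted in the tests that the converted model only handles post-padded input:

import numpy as np
import tensorflow as tf
import tf2onnx

# Hypothetical masked-LSTM model matching the pattern exercised by the new tests.
model_in = tf.keras.layers.Input((4,), dtype="int32")
x = tf.keras.layers.Embedding(input_dim=10, output_dim=5, mask_zero=True)(model_in)
outputs = tf.keras.layers.LSTM(units=5, return_state=True)(x)
model = tf.keras.models.Model(inputs=model_in, outputs=outputs)

# Convert; the mask is folded into the ONNX LSTM's sequence_lens input.
spec = (tf.TensorSpec((None, 4), tf.int32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13)

# Post-padded input, i.e. sequence_lens = [4, 2, 0], is the supported layout.
x_val = np.array([[1, 2, 3, 4], [5, 6, 0, 0], [0, 0, 0, 0]], dtype=np.int32)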
159 changes: 159 additions & 0 deletions tests/test_lstm.py
@@ -793,5 +793,164 @@ def func(x):
return tf.identity(y[0], name="output")
self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_embedding_unidirectional(self):
for go_backwards in [True, False]:
timesteps = 4
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val = np.array([
[1, 2, 3, 4],
[5, 6, 0, 0],
[0, 0, 0, 0]
], dtype=np.int32)

model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x_embedding = tf.keras.layers.Embedding(
input_dim=10,
output_dim=5,
mask_zero=True,
embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
)(model_in)

# RNN layer inherits the mask propagated from the embedding layer above
model_out = tf.keras.layers.LSTM(
units=5,
go_backwards=go_backwards,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)(x_embedding)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))

output_list = ["output_yh:0", "output_yc:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_embedding_bidirectional(self):
timesteps = 4
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val = np.array([
[1, 2, 3, 4],
[5, 6, 0, 0],
[0, 0, 0, 0]
], dtype=np.int32)

model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
x_embedding = tf.keras.layers.Embedding(
input_dim=10,
output_dim=5,
mask_zero=True,
embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
)(model_in)

# RNN layer inherits the mask propagated from the embedding layer above
lstm_layer = tf.keras.layers.LSTM(
units=5,
go_backwards=False,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_embedding)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))

output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
require_lstm_count=2)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_unidirectional(self):
for go_backwards in [True, False]:
batch_size, timesteps, feat = 3, 4, 5
in_shape = (timesteps, feat)
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)

# RNN layer inherits the mask propagated from the Masking layer above
model_out = tf.keras.layers.LSTM(
units=5,
go_backwards=go_backwards,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)(x_masked)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh"),
tf.identity(y[2], name="output_yc"))

output_list = ["output_yh:0", "output_yc:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)

@check_tf_min_version("2.0")
@skip_tf_versions("2.1", "Bug in TF 2.1")
def test_keras_masked_lstm_bidirectional(self):
batch_size, timesteps, feat = 3, 4, 5
in_shape = (timesteps, feat)
x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
# Note: masked LSTM only supports post-padded input after conversion
# test case sequence_lens = [4, 2, 0]
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)

# RNN layer inherits the mask propagated from the Masking layer above
lstm_layer = tf.keras.layers.LSTM(
units=5,
go_backwards=False,
return_state=True,
kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
)
model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_masked)
model = tf.keras.models.Model(inputs=model_in, outputs=model_out)

def func(x):
y = model(x)
# skipping output Y: https://github.com/microsoft/onnxruntime/issues/12492
return (tf.identity(y[1], name="output_yh_f"),
tf.identity(y[2], name="output_yc_f"),
tf.identity(y[3], name="output_yh_r"),
tf.identity(y[4], name="output_yc_r"))

output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
require_lstm_count=2)


if __name__ == '__main__':
unittest_main()
27 changes: 18 additions & 9 deletions tf2onnx/onnx_opset/tensor.py
@@ -2260,15 +2260,24 @@ def version_10(cls, ctx, node, **kwargs):
const_axis_name = utils.make_name(f'const_{axis}')
const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))

# Add a Constant node (seq_len) for ReverseSequence.
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
# Add sequence_lens as ReverseSequence input
has_sequence_lens = node.get_attr_value("has_sequence_lens", False)
if not has_sequence_lens:
# Add a Constant node (seq_len) for ReverseSequence.
# Index 1 for the shape should not return 0, since rank(input) >=2
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
op_name_scope=rv2_node_name)
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
op_name_scope=rv2_node_name)
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
inputs.append(seq_array.output[0])
else:
# masked backward LSTM:
# sequence_lens is appended to ReverseV2's input by lstm_tf2_rewriter
# to keep the tensor post-padded after the reverse
seq_lens_casted = ctx.make_node("Cast", [node.input[-1]], attr={'to': TensorProto.INT64}).output[0]
inputs.append(seq_lens_casted)

# Add a ReverseSequence node.
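For reference, a small numpy emulation (not part of the diff; shapes and values are illustrative) of why the masked backward LSTM needs ReverseSequence with sequence_lens rather than a plain reverse: only the first sequence_lens[b] timesteps of each batch entry are flipped, so the input stays post-padded.

import numpy as np

# Time-major toy input (timesteps, batch, feat); batch 0 has 4 valid steps, batch 1 has 2.
x = np.array([[[1.], [5.]],
              [[2.], [6.]],
              [[3.], [0.]],
              [[4.], [0.]]], dtype=np.float32)
seq_lens = [4, 2]

def reverse_sequence(x, seq_lens):
    # Emulates ONNX ReverseSequence with time_axis=0, batch_axis=1.
    out = x.copy()
    for b, n in enumerate(seq_lens):
        out[:n, b] = x[:n, b][::-1]
    return out

print(reverse_sequence(x, seq_lens)[:, 1, 0])  # [6. 5. 0. 0] -> still post-padded
print(x[::-1][:, 1, 0])                        # plain ReverseV2 over time: [0. 0. 6. 5] -> pre-padded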

99 changes: 83 additions & 16 deletions tf2onnx/rewriter/lstm_tf2_rewriter.py
@@ -4,8 +4,10 @@
"""
tf2onnx.rewriter.lstm_tf2_rewriter - Rewrites LSTM pattern used by tf2.
"""

import logging
import numpy as np
from onnx import onnx_pb

from tf2onnx.graph_matcher import GraphMatcher
from tf2onnx.rewriter.rnn_utils import make_lstm_pattern
from tf2onnx.tf_loader import find_function
@@ -79,21 +81,35 @@ def rewriter_lstm_tf2(g, ops):
# extract output h_t
ht_mul = match_result.get_op("ht")
final_consumers = g.find_output_consumers(ht_mul.output[0])
select_ops = [n for n in final_consumers if n.type == "Select"]
select_ops = [n for n in final_consumers if n.type == "Select" or n.type == "SelectV2"]
def has_tensor_list_consumer(n):
return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]

# extract sequence length
seq_len_idx, mask_idx = None, None
if len(select_ops) == 1:
greater_eq = select_ops[0].inputs[0]
if greater_eq.type != "GreaterEqual":
continue
seq_len = greater_eq.inputs[1]
if not seq_len.is_graph_input():
select_op_condition = select_ops[0].inputs[0]
while select_op_condition.type == "Identity":
select_op_condition = select_op_condition.inputs[0]

# skip timesteps based on a specific sequence length
if select_op_condition.type == "GreaterEqual":
seq_len = select_op_condition.inputs[1]
if not seq_len.is_graph_input():
continue
seq_len_idx = g.input_names.index(seq_len.output[0])

# masked LSTM: skip timesteps based on a dynamically computed boolean mask tensor
elif select_op_condition.type == "TensorListGetItem":
mask = select_op_condition.inputs[0]
if not mask.is_graph_input():
continue
mask_idx = g.input_names.index(mask.output[0])
else:
continue
seq_len_idx = g.input_names.index(seq_len.output[0])

final_consumers = g.find_output_consumers(select_ops[0].output[0])
else:
seq_len_idx = None

tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
if len(tensor_set_items) != 1:
@@ -209,6 +225,7 @@ def has_tensor_list_consumer(n):
# Keras
"w_idx": gk_idx,
"r_idx": hk_idx,
"mask_idx": mask_idx,
}

for op in ops:
@@ -276,15 +293,63 @@ def has_tensor_list_consumer(n):
tensor_array_inp = op.inputs[body_context["x_idx"]]
if not tensor_array_inp.type == "TensorListFromTensor":
continue
context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]

final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
# parse sequence length
seq_len_idx = body_context["seq_len_idx"]
mask_idx = body_context["mask_idx"]
if seq_len_idx:
context.onnx_input_ids[0]["sequence_lens"] = op.input[seq_len_idx]
elif mask_idx:
logging.warning(
"Found mask-enabled LSTM. Converted ONNX model will only support post-padded LSTM input. "
"If input is pre- or randomly-padded, masked timesteps will not be correctly skipped.")

# derive sequence_lens from the boolean mask tensor
tensor_array_mask = op.inputs[body_context["mask_idx"]]
if not tensor_array_mask.type == "TensorListFromTensor":
continue
mask_mat = tensor_array_mask.input[0]
mask_mat_node = g.get_node_by_output(mask_mat)
is_mask_reverse = mask_mat_node.type == "ReverseV2"
# no need to reverse the mask sequence
# the positions of skipped timesteps per batch are irrelevant assuming post-padded input
if is_mask_reverse:
mask_mat = mask_mat_node.input[0]

# reduce mask tensor to sequence_lens assuming post-padded input
# transpose (1,0,2) -> boolean mask tensor (N, timesteps, 1)
# squeeze on dim(-1) -> boolean mask matrix (N, timesteps)
# reduceSum on dim(-1) -> sequence_lens (N)
mask_transpose_node = g.make_node(op_type="Transpose", inputs=[mask_mat], attr={"perm": [1, 0, 2]})
mask_squeeze = GraphBuilder(g).make_squeeze({"data": mask_transpose_node.output[0], "axes": [-1]})
mask_cast_node = g.make_node(op_type="Cast", inputs=[mask_squeeze],
attr={"to": onnx_pb.TensorProto.INT32})
sequence_lens = GraphBuilder(g).make_reduce_sum({"data": mask_cast_node.output[0],
"axes": [-1], "keepdims": 0})
context.onnx_input_ids[0]["sequence_lens"] = sequence_lens

# handle backward LSTM
tensor_array_inp_producer = tensor_array_inp.inputs[0]
is_input_reverse = tensor_array_inp_producer.type == "ReverseV2"
# backward LSTM is identified by the reverses of both input and mask tensors pre-LSTM
if is_mask_reverse != is_input_reverse:
continue
if is_input_reverse:
# TF uses a simple "ReverseV2" to reverse the input tensor with no assumption about padding position,
# because the reversed mask of shape (batch_size, timesteps) is explicit per timestep.
# ONNX requires "ReverseSequence" to keep the reversed input tensor post-padded, because the mask
# is implied by sequence_lens. This requires passing sequence_lens to that "ReverseSequence" op.

# Note: tensor op conversions run after rewriters. Appending sequence_lens as a "ReverseV2" input
# signals alternative behavior in the "ReverseV2" conversion in onnx_opset/tensor.py.
tensor_array_inp_producer.set_attr("has_sequence_lens", True)
inp_reverse_inputs = tensor_array_inp_producer.input
inp_reverse_inputs.append(sequence_lens)

context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
if body_context["seq_len_idx"] is None:
context.onnx_input_ids[0]["sequence_lens"] = ""
else:
context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
context.onnx_input_ids[0]["sequence_lens"] = ""

context.onnx_input_ids[0]["initial_c"] = initial_c
context.onnx_input_ids[0]["initial_h"] = initial_h

@@ -295,6 +360,8 @@ def has_tensor_list_consumer(n):
lstm_node = lstm_rewriter.create_rnn_node(context)[0]

squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]
for output in output_ys:
g.replace_all_inputs(output, squeeze_output)
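To make the mask reduction in this rewriter concrete, here is a small numpy illustration (not part of the commit; values assume the time-major boolean mask and post-padded input described in the comments). It mirrors the Transpose -> Squeeze -> Cast -> ReduceSum chain the rewriter emits:

import numpy as np

# Boolean mask as propagated by Keras masking, time-major: (timesteps, batch, 1).
# The three batch entries have 4, 2 and 0 valid (post-padded) timesteps.
mask = np.array([[[True], [True],  [False]],
                 [[True], [True],  [False]],
                 [[True], [False], [False]],
                 [[True], [False], [False]]])

# Same steps as the ONNX ops above: Transpose (1,0,2), Squeeze(-1), Cast, ReduceSum(-1).
seq_lens = np.transpose(mask, (1, 0, 2)).squeeze(-1).astype(np.int32).sum(axis=-1)
print(seq_lens)  # [4 2 0]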

