This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7fd4704003 fix _convert_simple_rnn (#15723)
7fd4704003 is described below
commit 7fd4704003dec853abbfc15a47a0d07d941b7a8a
Author: Haoyang <[email protected]>
AuthorDate: Wed Sep 13 12:53:41 2023 +0800
fix _convert_simple_rnn (#15723)
* fix _convert_simple_rnn
* fix _convert_simple_rnn
* fix errors in the last PR
---
python/tvm/relay/frontend/keras.py | 25 ++++++++++++++-----------
tests/python/frontend/keras/test_forward.py | 11 +++++++++++
2 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 205b2be490..9e09cb400a 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -1052,23 +1052,26 @@ def _convert_simple_rnn(
inexpr = [inexpr, prev_op]
in_data = inexpr[0]
prev_op = inexpr[1]
+ prev_op = _op.nn.batch_flatten(prev_op)
weightList = keras_layer.get_weights()
kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
- if keras_layer.use_bias:
- in_bias = etab.new_const(weightList[2])
units = list(weightList[0].shape)[1]
assert units > 0, "The value of units must be a positive integer"
- in_data = _op.nn.batch_flatten(in_data)
- ixh = _op.nn.dense(in_data, kernel_weight, units=units)
if keras_layer.use_bias:
- ixh = _op.nn.bias_add(ixh, bias=in_bias)
- prev_op = _op.nn.batch_flatten(prev_op)
- ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
- output = ixh + ixh2
- output = _convert_activation(output, keras_layer, etab, data_layout)
- out_shape = tuple(dim if dim else 1 for dim in
_as_list(keras_layer.output_shape)[0])
- output = _op.reshape(output, newshape=out_shape)
+ in_bias = etab.new_const(weightList[2])
+ assert len(in_data.type_annotation.shape) == 3
+ timeDim = in_data.type_annotation.shape[1].value
+ in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
+ for i in range(len(in_data_split)):
+ in_data_split_i = _op.nn.batch_flatten(in_data_split[i])
+ ixh = _op.nn.dense(in_data_split_i, kernel_weight, units=units)
+ if keras_layer.use_bias:
+ ixh = _op.nn.bias_add(ixh, bias=in_bias)
+ ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
+ output = ixh + ixh2
+ output = _convert_activation(output, keras_layer, etab, data_layout)
+ prev_op = output
return [output, output]
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 80460f6063..9d33b15a91 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -825,6 +825,16 @@ class TestKeras:
)
verify_keras_frontend(dense_model, need_transpose=False)
+ def test_simplernn_with_infertype(self, keras_mod):
+ """This test case is from https://github.com/apache/tvm/issues/14868"""
+ input_shape = (2, 2, 2)
+ x = keras_mod.layers.Input(shape=input_shape[1:], dtype="float32")
+ layer = keras_mod.layers.SimpleRNN(units=4)
+ y = layer(x)
+ model = keras_mod.models.Model(x, y)
+ mod, _ = relay.frontend.from_keras(model, {model.input_names[0]:
input_shape})
+ relay.transform.InferType()(mod)
+
if __name__ == "__main__":
for k in [keras, tf_keras]:
@@ -867,3 +877,4 @@ if __name__ == "__main__":
sut.test_forward_repeat_vector(keras_mod=k)
sut.test_forward_l2_normalize(keras_mod=k)
sut.test_forward_time_distributed(keras_mod=k)
+ sut.test_simplernn_with_infertype(keras_mod=k)