This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7fd4704003 fix _convert_simple_rnn (#15723)
7fd4704003 is described below

commit 7fd4704003dec853abbfc15a47a0d07d941b7a8a
Author: Haoyang <[email protected]>
AuthorDate: Wed Sep 13 12:53:41 2023 +0800

    fix _convert_simple_rnn (#15723)
    
    * fix _convert_simple_rnn
    
    * fix _convert_simple_rnn
    
    * fix errors in the last pr
---
 python/tvm/relay/frontend/keras.py          | 25 ++++++++++++++-----------
 tests/python/frontend/keras/test_forward.py | 11 +++++++++++
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 205b2be490..9e09cb400a 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -1052,23 +1052,26 @@ def _convert_simple_rnn(
         inexpr = [inexpr, prev_op]
     in_data = inexpr[0]
     prev_op = inexpr[1]
+    prev_op = _op.nn.batch_flatten(prev_op)
     weightList = keras_layer.get_weights()
     kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
     recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
-    if keras_layer.use_bias:
-        in_bias = etab.new_const(weightList[2])
     units = list(weightList[0].shape)[1]
     assert units > 0, "The value of units must be a positive integer"
-    in_data = _op.nn.batch_flatten(in_data)
-    ixh = _op.nn.dense(in_data, kernel_weight, units=units)
     if keras_layer.use_bias:
-        ixh = _op.nn.bias_add(ixh, bias=in_bias)
-    prev_op = _op.nn.batch_flatten(prev_op)
-    ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
-    output = ixh + ixh2
-    output = _convert_activation(output, keras_layer, etab, data_layout)
-    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
-    output = _op.reshape(output, newshape=out_shape)
+        in_bias = etab.new_const(weightList[2])
+    assert len(in_data.type_annotation.shape) == 3
+    timeDim = in_data.type_annotation.shape[1].value
+    in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
+    for i in range(len(in_data_split)):
+        in_data_split_i = _op.nn.batch_flatten(in_data_split[i])
+        ixh = _op.nn.dense(in_data_split_i, kernel_weight, units=units)
+        if keras_layer.use_bias:
+            ixh = _op.nn.bias_add(ixh, bias=in_bias)
+        ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
+        output = ixh + ixh2
+        output = _convert_activation(output, keras_layer, etab, data_layout)
+        prev_op = output
     return [output, output]
 
 
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 80460f6063..9d33b15a91 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -825,6 +825,16 @@ class TestKeras:
         )
         verify_keras_frontend(dense_model, need_transpose=False)
 
+    def test_simplernn_with_infertype(self, keras_mod):
+        """This test case is from https://github.com/apache/tvm/issues/14868"""
+        input_shape = (2, 2, 2)  # (batch, timesteps, features); batch is dropped for Input
+        x = keras_mod.layers.Input(shape=input_shape[1:], dtype="float32")
+        layer = keras_mod.layers.SimpleRNN(units=4)
+        y = layer(x)
+        model = keras_mod.models.Model(x, y)
+        mod, _ = relay.frontend.from_keras(model, {model.input_names[0]: input_shape})
+        relay.transform.InferType()(mod)  # test passes iff type inference raises no error
+
 
 if __name__ == "__main__":
     for k in [keras, tf_keras]:
@@ -867,3 +877,4 @@ if __name__ == "__main__":
         sut.test_forward_repeat_vector(keras_mod=k)
         sut.test_forward_l2_normalize(keras_mod=k)
         sut.test_forward_time_distributed(keras_mod=k)
+        sut.test_simplernn_with_infertype(keras_mod=k)

Reply via email to