This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 28908998e0 [Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829)
28908998e0 is described below

commit 28908998e0c55025a89e8e2bd26a3fe3e6c84356
Author: Qingchao Shen <[email protected]>
AuthorDate: Fri Sep 29 15:54:23 2023 +0800

    [Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829)
    
    * fix bug in gru and simpleRNN about go_backwards
    
    * Update test_forward.py
    
    * Update keras.py
---
 python/tvm/relay/frontend/keras.py          |  4 ++++
 tests/python/frontend/keras/test_forward.py | 12 ++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 9e09cb400a..6c82ebb427 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -1062,6 +1062,8 @@ def _convert_simple_rnn(
         in_bias = etab.new_const(weightList[2])
     assert len(in_data.type_annotation.shape) == 3
     timeDim = in_data.type_annotation.shape[1].value
+    if keras_layer.go_backwards:
+        in_data = _op.reverse(in_data, axis=1)
     in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
     for i in range(len(in_data_split)):
         in_data_split_i = _op.nn.batch_flatten(in_data_split[i])
@@ -1090,6 +1092,8 @@ def _convert_gru(
     recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
     if keras_layer.use_bias:
         in_bias = etab.new_const(weightList[2])
+    if keras_layer.go_backwards:
+        in_data = _op.reverse(in_data, axis=1)
     units = list(weightList[0].shape)[1]
     assert units > 0, "The value of units must be a positive integer"
     in_data = _op.nn.batch_flatten(in_data)
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index ba3880e186..8c5b578060 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -568,12 +568,23 @@ class TestKeras:
             keras_mod.layers.SimpleRNN(
                 units=16, return_state=False, activation="tanh", use_bias=False
             ),
+            keras_mod.layers.SimpleRNN(
+                units=16, return_state=False, activation="tanh", go_backwards=True
+            ),
+            keras_mod.layers.GRU(
+                units=16,
+                return_state=False,
+                recurrent_activation="sigmoid",
+                activation="tanh",
+                reset_after=False,
+            ),
             keras_mod.layers.GRU(
                 units=16,
                 return_state=False,
                 recurrent_activation="sigmoid",
                 activation="tanh",
                 reset_after=False,
+                use_bias=False,
             ),
             keras_mod.layers.GRU(
                 units=16,
@@ -582,6 +593,7 @@ class TestKeras:
                 activation="tanh",
                 reset_after=False,
                 use_bias=False,
+                go_backwards=True,
             ),
         ]
         for rnn_func in rnn_funcs:

Reply via email to