This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 28908998e0 [Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829)
28908998e0 is described below
commit 28908998e0c55025a89e8e2bd26a3fe3e6c84356
Author: Qingchao Shen <[email protected]>
AuthorDate: Fri Sep 29 15:54:23 2023 +0800
[Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829)
* fix bug in gru and simpleRNN about go_backwards
* Update test_forward.py
* Update keras.py
---
python/tvm/relay/frontend/keras.py | 4 ++++
tests/python/frontend/keras/test_forward.py | 12 ++++++++++++
2 files changed, 16 insertions(+)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 9e09cb400a..6c82ebb427 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -1062,6 +1062,8 @@ def _convert_simple_rnn(
in_bias = etab.new_const(weightList[2])
assert len(in_data.type_annotation.shape) == 3
timeDim = in_data.type_annotation.shape[1].value
+ if keras_layer.go_backwards:
+ in_data = _op.reverse(in_data, axis=1)
in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1)
for i in range(len(in_data_split)):
in_data_split_i = _op.nn.batch_flatten(in_data_split[i])
@@ -1090,6 +1092,8 @@ def _convert_gru(
recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
if keras_layer.use_bias:
in_bias = etab.new_const(weightList[2])
+ if keras_layer.go_backwards:
+ in_data = _op.reverse(in_data, axis=1)
units = list(weightList[0].shape)[1]
assert units > 0, "The value of units must be a positive integer"
in_data = _op.nn.batch_flatten(in_data)
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index ba3880e186..8c5b578060 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -568,12 +568,23 @@ class TestKeras:
keras_mod.layers.SimpleRNN(
units=16, return_state=False, activation="tanh", use_bias=False
),
+            keras_mod.layers.SimpleRNN(
+                units=16, return_state=False, activation="tanh", go_backwards=True
+            ),
+ keras_mod.layers.GRU(
+ units=16,
+ return_state=False,
+ recurrent_activation="sigmoid",
+ activation="tanh",
+ reset_after=False,
+ ),
keras_mod.layers.GRU(
units=16,
return_state=False,
recurrent_activation="sigmoid",
activation="tanh",
reset_after=False,
+ use_bias=False,
),
keras_mod.layers.GRU(
units=16,
@@ -582,6 +593,7 @@ class TestKeras:
activation="tanh",
reset_after=False,
use_bias=False,
+ go_backwards=True,
),
]
for rnn_func in rnn_funcs: