This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e2d6511161 [Bugfix][Frontend][Keras] Fix a corner-case bug in the softmax converter of the Keras frontend (#15337)
e2d6511161 is described below
commit e2d65111616dfa95797c0dd7e082e4050b71701d
Author: Qingchao Shen <[email protected]>
AuthorDate: Tue Jul 18 13:02:34 2023 +0800
[Bugfix][Frontend][Keras] Fix a corner-case bug in the softmax converter of the Keras frontend (#15337)
* Fix the Keras softmax converter when input_shape is empty or missing (treat it as 0 dims)
* Add new test cases that capture the bug
* Update keras.py
---
python/tvm/relay/frontend/keras.py | 6 ++++--
tests/python/frontend/keras/test_forward.py | 7 +++++++
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py
b/python/tvm/relay/frontend/keras.py
index 1913d4a268..aba4160695 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -131,11 +131,13 @@ def _convert_advanced_activation(inexpr, keras_layer,
etab, data_layout, input_s
if act_type == "Softmax":
axis = keras_layer.axis
- dims = len(input_shape)
+ dims = len(input_shape) if input_shape else 0
if isinstance(axis, list):
raise tvm.error.OpAttributeUnImplemented(f"Softmax with axes
{axis} is not supported.")
if data_layout == "NCHW":
- if axis == -1:
+ if dims == 0:
+ axis = 0
+ elif axis == -1:
axis = 1
else:
axis = axis + 1 if axis < dims - 1 else 1
diff --git a/tests/python/frontend/keras/test_forward.py
b/tests/python/frontend/keras/test_forward.py
index 50a0e98505..53e2ca8dbe 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -229,6 +229,13 @@ class TestKeras:
keras_model = keras_mod.models.Model(data, x)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, need_transpose=False,
layout="NHWC")
+ # Test the input dimension = 1
+ data = keras_mod.layers.Input(shape=(11,))
+ act_func = keras_mod.layers.Softmax()
+ x = act_func(data)
+ keras_model = keras_mod.models.Model(data, x)
+ verify_keras_frontend(keras_model)
+ verify_keras_frontend(keras_model, need_transpose=False, layout="NHWC")
def test_forward_activations_except(self, keras_mod):
"""