This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e178375e1b [Frontend][Relay][Keras] Fix concatenate convert function
in axis parsing (#15175)
e178375e1b is described below
commit e178375e1bbaaa4255d59a23f66d0595fe445991
Author: Qingchao Shen <[email protected]>
AuthorDate: Sat Jul 1 21:16:03 2023 +0800
[Frontend][Relay][Keras] Fix concatenate convert function in axis parsing
(#15175)
---
python/tvm/relay/frontend/keras.py | 11 +++++++----
tests/python/frontend/keras/test_forward.py | 19 +++++++++++++++++++
2 files changed, 26 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 75ccec52f3..3dc0551084 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -955,10 +955,13 @@ def _convert_concat(
if input_shape is None:
input_shape = keras_layer.input_shape
- if data_layout == "NHWC" or len(input_shape[0]) < 4:
- axis = -1
- else:
- axis = 1
+ axis = keras_layer.axis
+ dims = len(input_shape[0])
+ if data_layout == "NCHW": # need_transpose
+ if axis == -1:
+ axis = 1
+ else:
+ axis = axis + 1 if axis < dims else 1
return _op.concatenate(_as_list(inexpr), axis=axis)
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 842d803b17..2584a36e32 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -159,6 +159,24 @@ class TestKeras:
keras_model = keras_mod.models.Model(data, out)
verify_keras_frontend(keras_model)
+ def test_forward_concatenate(self, keras_mod):
+ """Test Keras Concatenate conversion with an explicit and the default axis."""
+ data1 = keras_mod.layers.Input(shape=(1, 2, 2))
+ data2 = keras_mod.layers.Input(shape=(1, 1, 2))
+ merge_func = keras_mod.layers.Concatenate(axis=2)  # inputs differ only here (2 vs 1)
+ out = merge_func([data1, data2])
+ keras_model = keras_mod.models.Model([data1, data2], out)
+ verify_keras_frontend(keras_model, layout="NHWC")
+ verify_keras_frontend(keras_model, layout="NCHW")
+ # test default axis (-1); NOTE(review): explicit last axis (3) untested under NCHW
+ data1 = keras_mod.layers.Input(shape=(1, 2, 2))
+ data2 = keras_mod.layers.Input(shape=(1, 2, 3))
+ merge_func = keras_mod.layers.Concatenate()
+ out = merge_func([data1, data2])
+ keras_model = keras_mod.models.Model([data1, data2], out)
+ verify_keras_frontend(keras_model, layout="NHWC")
+ verify_keras_frontend(keras_model, layout="NCHW")
+
def test_forward_merge_dot(self, keras_mod):
"""test_forward_merge_dot"""
data1 = keras_mod.layers.Input(shape=(2, 2))
@@ -793,6 +811,7 @@ class TestKeras:
if __name__ == "__main__":
for k in [keras, tf_keras]:
sut = TestKeras()
+ sut.test_forward_concatenate(keras_mod=k)
sut.test_forward_merge_dot(keras_mod=k)
sut.test_forward_merge(keras_mod=k)
sut.test_forward_activations(keras_mod=k)