This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e178375e1b [Frontend][Relay][Keras] Fix concatenate convert function in axis parsing (#15175)
e178375e1b is described below

commit e178375e1bbaaa4255d59a23f66d0595fe445991
Author: Qingchao Shen <[email protected]>
AuthorDate: Sat Jul 1 21:16:03 2023 +0800

    [Frontend][Relay][Keras] Fix concatenate convert function in axis parsing (#15175)
---
 python/tvm/relay/frontend/keras.py          | 11 +++++++----
 tests/python/frontend/keras/test_forward.py | 19 +++++++++++++++++++
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 75ccec52f3..3dc0551084 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -955,10 +955,13 @@ def _convert_concat(
     if input_shape is None:
         input_shape = keras_layer.input_shape
 
-    if data_layout == "NHWC" or len(input_shape[0]) < 4:
-        axis = -1
-    else:
-        axis = 1
+    axis = keras_layer.axis
+    dims = len(input_shape[0])
+    if data_layout == "NCHW":  # need_transpose
+        if axis == -1:
+            axis = 1
+        else:
+            axis = axis + 1 if axis < dims else 1
     return _op.concatenate(_as_list(inexpr), axis=axis)
 
 
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 842d803b17..2584a36e32 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -159,6 +159,24 @@ class TestKeras:
             keras_model = keras_mod.models.Model(data, out)
             verify_keras_frontend(keras_model)
 
+    def test_forward_concatenate(self, keras_mod):
+        """test_forward_concatenate"""
+        data1 = keras_mod.layers.Input(shape=(1, 2, 2))
+        data2 = keras_mod.layers.Input(shape=(1, 1, 2))
+        merge_func = keras_mod.layers.Concatenate(axis=2)
+        out = merge_func([data1, data2])
+        keras_model = keras_mod.models.Model([data1, data2], out)
+        verify_keras_frontend(keras_model, layout="NHWC")
+        verify_keras_frontend(keras_model, layout="NCHW")
+        # test default axis (e.g., -1)
+        data1 = keras_mod.layers.Input(shape=(1, 2, 2))
+        data2 = keras_mod.layers.Input(shape=(1, 2, 3))
+        merge_func = keras_mod.layers.Concatenate()
+        out = merge_func([data1, data2])
+        keras_model = keras_mod.models.Model([data1, data2], out)
+        verify_keras_frontend(keras_model, layout="NHWC")
+        verify_keras_frontend(keras_model, layout="NCHW")
+
     def test_forward_merge_dot(self, keras_mod):
         """test_forward_merge_dot"""
         data1 = keras_mod.layers.Input(shape=(2, 2))
@@ -793,6 +811,7 @@ class TestKeras:
 if __name__ == "__main__":
     for k in [keras, tf_keras]:
         sut = TestKeras()
+        sut.test_forward_concatenate(keras_mod=k)
         sut.test_forward_merge_dot(keras_mod=k)
         sut.test_forward_merge(keras_mod=k)
         sut.test_forward_activations(keras_mod=k)

Reply via email to