This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c87ebefcfd Concatenation corner case fix. (#11907)
c87ebefcfd is described below

commit c87ebefcfd756c33c08393b19a081a6fd4b657e7
Author: Sergey <[email protected]>
AuthorDate: Tue Jun 28 12:22:07 2022 +0300

    Concatenation corner case fix. (#11907)
    
    * Concatenation corner case fix.
    
    * lint fixes.
---
 python/tvm/topi/x86/concat.py        |  1 -
 tests/python/relay/test_op_level1.py | 14 ++++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py
index 28f650bca9..435dd1636c 100644
--- a/python/tvm/topi/x86/concat.py
+++ b/python/tvm/topi/x86/concat.py
@@ -83,7 +83,6 @@ def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
 
     if (
         len(data[0].shape) == 1
-        or right_val == 1
         or (left_val == 1 and axis == len(data[0].shape) - 1)
         or (left_val == 1 and right_val == 1)
     ):
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index f4afc9e905..44df40d3b0 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -528,6 +528,20 @@ def test_concatenate3(target, dev):
         do_concat_test(shapes, t_shape, dtype, axis, dev, target)
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_concatenate4(target, dev):
+    np.random.seed(7)
+    x_shape = (2, 1)
+    x = relay.var("x", shape=x_shape, dtype="int64")
+    concat = relay.concatenate([x], axis=1)
+    f = relay.Function([x], concat)
+    x_val = np.array([[33], [13]], dtype="int64")
+    graph = relay.create_executor("graph", device=tvm.cpu(), target="llvm")
+    op_res = graph.evaluate(f)(x_val)
+    ref_res = np.concatenate([x_val], axis=1)
+    tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.000001)
+
+
 def test_batch_norm_fold_const():
     axis = 1
     dtype = "float32"

Reply via email to