This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c87ebefcfd Concatenation corner case fix. (#11907)
c87ebefcfd is described below
commit c87ebefcfd756c33c08393b19a081a6fd4b657e7
Author: Sergey <[email protected]>
AuthorDate: Tue Jun 28 12:22:07 2022 +0300
Concatenation corner case fix. (#11907)
* Concatenation corner case fix.
* lint fixes.
---
python/tvm/topi/x86/concat.py | 1 -
tests/python/relay/test_op_level1.py | 14 ++++++++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/python/tvm/topi/x86/concat.py b/python/tvm/topi/x86/concat.py
index 28f650bca9..435dd1636c 100644
--- a/python/tvm/topi/x86/concat.py
+++ b/python/tvm/topi/x86/concat.py
@@ -83,7 +83,6 @@ def concatenate(data: tvm.te.Tensor, axis: Optional[int] = 0):
if (
len(data[0].shape) == 1
- or right_val == 1
or (left_val == 1 and axis == len(data[0].shape) - 1)
or (left_val == 1 and right_val == 1)
):
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index f4afc9e905..44df40d3b0 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -528,6 +528,20 @@ def test_concatenate3(target, dev):
do_concat_test(shapes, t_shape, dtype, axis, dev, target)
+@tvm.testing.parametrize_targets("llvm")
+def test_concatenate4(target, dev):
+ np.random.seed(7)
+ x_shape = (2, 1)
+ x = relay.var("x", shape=x_shape, dtype="int64")
+ concat = relay.concatenate([x], axis=1)
+ f = relay.Function([x], concat)
+ x_val = np.array([[33], [13]], dtype="int64")
+ graph = relay.create_executor("graph", device=tvm.cpu(), target="llvm")
+ op_res = graph.evaluate(f)(x_val)
+ ref_res = np.concatenate([x_val], axis=1)
+ tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.000001)
+
+
def test_batch_norm_fold_const():
axis = 1
dtype = "float32"