This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 63b6b49f03 fix expand onnx conversion (#11278)
63b6b49f03 is described below

commit 63b6b49f030c93dd55dee8b3fd5760638b35782b
Author: Jiawei Liu <[email protected]>
AuthorDate: Wed May 11 15:38:21 2022 -0500

    fix expand onnx conversion (#11278)
---
 python/tvm/relay/frontend/onnx.py          |  2 +-
 tests/python/frontend/onnx/test_forward.py | 11 +++++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index d27ff00a01..036b5a9146 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2590,7 +2590,7 @@ class Expand(OnnxOpConverter):
                     ],
                     axis=0,
                 )
-            elif new_dims > in_dims:
+            elif new_dims < in_dims:
                 shape = _op.concatenate(
                     [
                         _expr.const(
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 904a33fae9..03f0cb3bad 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -402,8 +402,15 @@ def test_expand(target, dev):
     shape = (2, 1, 6)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = data * np.ones(shape, dtype=np.float32)
-    _test_expand("expand_with_dim_changed_test", data, shape, ref_data, 
"int32")
-    _test_expand("expand_with_dim_changed_test", data, shape, ref_data, 
"int64")
+    _test_expand("expand_larger_target_shape_test", data, shape, ref_data, 
"int32")
+    _test_expand("expand_larger_target_shape_test", data, shape, ref_data, 
"int64")
+
+    in_shape = (1, 1)
+    shape = (3,)
+    data = np.random.uniform(size=in_shape).astype(np.float32)
+    ref_data = data * np.ones(shape, dtype=np.float32)
+    _test_expand("expand_smaller_target_shape_test", data, shape, ref_data, 
"int32")
+    _test_expand("expand_smaller_target_shape_test", data, shape, ref_data, 
"int64")
 
 
 @tvm.testing.parametrize_targets

Reply via email to