This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 63b6b49f03 fix expand onnx conversion (#11278)
63b6b49f03 is described below
commit 63b6b49f030c93dd55dee8b3fd5760638b35782b
Author: Jiawei Liu <[email protected]>
AuthorDate: Wed May 11 15:38:21 2022 -0500
fix expand onnx conversion (#11278)
---
python/tvm/relay/frontend/onnx.py | 2 +-
tests/python/frontend/onnx/test_forward.py | 11 +++++++++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index d27ff00a01..036b5a9146 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2590,7 +2590,7 @@ class Expand(OnnxOpConverter):
],
axis=0,
)
- elif new_dims > in_dims:
+ elif new_dims < in_dims:
shape = _op.concatenate(
[
_expr.const(
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 904a33fae9..03f0cb3bad 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -402,8 +402,15 @@ def test_expand(target, dev):
shape = (2, 1, 6)
data = np.random.uniform(size=in_shape).astype(np.float32)
ref_data = data * np.ones(shape, dtype=np.float32)
- _test_expand("expand_with_dim_changed_test", data, shape, ref_data,
"int32")
- _test_expand("expand_with_dim_changed_test", data, shape, ref_data,
"int64")
+ _test_expand("expand_larger_target_shape_test", data, shape, ref_data,
"int32")
+ _test_expand("expand_larger_target_shape_test", data, shape, ref_data,
"int64")
+
+ in_shape = (1, 1)
+ shape = (3,)
+ data = np.random.uniform(size=in_shape).astype(np.float32)
+ ref_data = data * np.ones(shape, dtype=np.float32)
+ _test_expand("expand_smaller_target_shape_test", data, shape, ref_data,
"int32")
+ _test_expand("expand_smaller_target_shape_test", data, shape, ref_data,
"int64")
@tvm.testing.parametrize_targets