This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 31a4267 [ONNX] fix reduce crash on scalar inputs (#10780)
31a4267 is described below
commit 31a4267a1960754a91e2a5189a516598029b26b8
Author: Jiawei Liu <[email protected]>
AuthorDate: Fri Mar 25 14:51:11 2022 -0500
[ONNX] fix reduce crash on scalar inputs (#10780)
* Fix reduce crash on scalar inputs.
* Fix uncovered cases.
* Apply the fix to a different opset version to pass CI.
---
python/tvm/relay/frontend/onnx.py | 18 ++++++++++++++++++
tests/python/frontend/onnx/test_forward.py | 2 ++
2 files changed, 20 insertions(+)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index eea5008..04fb17a 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1875,6 +1875,9 @@ class Reduce(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if not infer_shape(inputs[0]): # promote scalar to 1-D tensor
+ inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
if "axes" in attr:
axis = attr.get("axes", 0)
else:
@@ -1885,6 +1888,9 @@ class Reduce(OnnxOpConverter):
@classmethod
def _impl_v12(cls, inputs, attr, params):
+ if not infer_shape(inputs[0]): # promote scalar to 1-D tensor
+ inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
if len(inputs) == 2:
if isinstance(inputs[1], _expr.Constant):
# Get axis and unpack scalar
@@ -1937,6 +1943,9 @@ class ReduceSumSquare(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if not infer_shape(inputs[0]): # promote scalar to 1-D tensor
+ inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
if "axes" in attr:
axis = attr.get("axes", 0)
else:
@@ -1953,6 +1962,9 @@ class ReduceL1(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if not infer_shape(inputs[0]): # promote scalar to 1-D tensor
+ inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
if "axes" in attr:
axis = attr.get("axes", 0)
else:
@@ -1969,6 +1981,9 @@ class ReduceL2(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if not infer_shape(inputs[0]): # promote scalar to 1-D tensor
+ inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
if "axes" in attr:
axis = attr.get("axes", 0)
else:
@@ -1986,6 +2001,9 @@ class ReduceLogSum(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
+ if not infer_shape(inputs[0]): # promote scalar to 1-D tensor
+ inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
if "axes" in attr:
axis = attr.get("axes", 0)
else:
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a526da5..91775d2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1934,6 +1934,8 @@ def test_all_reduce_funcs(target, dev):
]
for func in funcs:
+ verify_reduce_func(func, np.array(1.0).astype(np.float32), axis=None,
keepdims=False)
+
for keepdims in [True, False]:
verify_reduce_func(
func, np.random.randn(3, 2, 2).astype(np.float32), axis=None,
keepdims=keepdims