This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 31a4267  [ONNX] fix reduce crash on scalar inputs (#10780)
31a4267 is described below

commit 31a4267a1960754a91e2a5189a516598029b26b8
Author: Jiawei Liu <[email protected]>
AuthorDate: Fri Mar 25 14:51:11 2022 -0500

    [ONNX] fix reduce crash on scalar inputs (#10780)
    
    * fix reduce crash on scalar inputs
    
    * fix uncovered cases.
    
    * fix on different opset to pass ci
---
 python/tvm/relay/frontend/onnx.py          | 18 ++++++++++++++++++
 tests/python/frontend/onnx/test_forward.py |  2 ++
 2 files changed, 20 insertions(+)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index eea5008..04fb17a 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1875,6 +1875,9 @@ class Reduce(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):  # promote scalar to 1-D tensor
+            inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
         if "axes" in attr:
             axis = attr.get("axes", 0)
         else:
@@ -1885,6 +1888,9 @@ class Reduce(OnnxOpConverter):
 
     @classmethod
     def _impl_v12(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):  # promote scalar to 1-D tensor
+            inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
         if len(inputs) == 2:
             if isinstance(inputs[1], _expr.Constant):
                 # Get axis and unpack scalar
@@ -1937,6 +1943,9 @@ class ReduceSumSquare(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):  # promote scalar to 1-D tensor
+            inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
         if "axes" in attr:
             axis = attr.get("axes", 0)
         else:
@@ -1953,6 +1962,9 @@ class ReduceL1(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):  # promote scalar to 1-D tensor
+            inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
         if "axes" in attr:
             axis = attr.get("axes", 0)
         else:
@@ -1969,6 +1981,9 @@ class ReduceL2(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):  # promote scalar to 1-D tensor
+            inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
         if "axes" in attr:
             axis = attr.get("axes", 0)
         else:
@@ -1986,6 +2001,9 @@ class ReduceLogSum(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        if not infer_shape(inputs[0]):  # promote scalar to 1-D tensor
+            inputs[0] = _op.expand_dims(inputs[0], axis=0)
+
         if "axes" in attr:
             axis = attr.get("axes", 0)
         else:
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a526da5..91775d2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1934,6 +1934,8 @@ def test_all_reduce_funcs(target, dev):
     ]
 
     for func in funcs:
+        verify_reduce_func(func, np.array(1.0).astype(np.float32), axis=None, keepdims=False)
+
         for keepdims in [True, False]:
             verify_reduce_func(
                 func, np.random.randn(3, 2, 2).astype(np.float32), axis=None, keepdims=keepdims

Reply via email to