This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1e789f3f08 [Fix][Unity] Fix TVMError when loading ONNX model with CumSum operator (#15804)
1e789f3f08 is described below

commit 1e789f3f088103652c1796e984f71469c9b60bb6
Author: dmilosevic252 <[email protected]>
AuthorDate: Thu Oct 12 12:02:07 2023 +0200

    [Fix][Unity] Fix TVMError when loading ONNX model with CumSum operator (#15804)
    
    * [Unity] Fix TVMError when loading ONNX model with CumSum operator
    
    * Add regression test for loading ONNX model with CumSum operator
    
    * Fix formatting
    
    * Fix spacing errors
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  2 ++
 tests/python/frontend/onnx/test_forward.py      | 23 +++++++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5333812c05..f0d0c00333 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -729,6 +729,8 @@ class CumSum(OnnxOpConverter):
 
         if isinstance(axis, relax.Constant):
             axis = int(axis.data.numpy())
+        elif isinstance(axis, relax.Var):
+            axis = 0
         data = relax.op.cumsum(data, axis)
         if attr.get("reverse", 0) != 0:
             data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 9c9362aaf1..b2132f3b81 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -37,6 +37,7 @@ from tvm import relay
 from tvm.contrib import graph_executor, utils
 from tvm.relay.frontend.common import infer_type
 from tvm.relay.build_module import bind_params_by_name
+from tvm.relax.frontend.onnx import from_onnx
 from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
 
 import onnx
@@ -5386,6 +5387,28 @@ def test_softplus(target, dev):
     verify_softplus(input_data)
 
 
+def test_load_cumsum():
+    """test_load_cumsum"""
+
+    def create_cumsum_model():
+        input_shape = [2, 3]
+
+        graph = helper.make_graph(
+            [
+                helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+            ],
+            "cumsum_graph",
+            inputs=[
+                helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape),
+                helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"),
+            ],
+            outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)],
+        )
+        return helper.make_model(graph)
+
+    from_onnx(create_cumsum_model())
+
+
 @tvm.testing.parametrize_targets
 def test_cumsum(target, dev):
     """test_cumsum"""

Reply via email to