This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1e789f3f08 [Fix][Unity] Fix TVMError when loading ONNX model with CumSum operator (#15804)
1e789f3f08 is described below
commit 1e789f3f088103652c1796e984f71469c9b60bb6
Author: dmilosevic252 <[email protected]>
AuthorDate: Thu Oct 12 12:02:07 2023 +0200
[Fix][Unity] Fix TVMError when loading ONNX model with CumSum operator (#15804)
* [Unity] Fix TVMError when loading ONNX model with CumSum operator
* Add regression test for loading ONNX model with CumSum operator
* Fix formatting
* Fix spacing errors
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 ++
tests/python/frontend/onnx/test_forward.py | 23 +++++++++++++++++++++++
2 files changed, 25 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5333812c05..f0d0c00333 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -729,6 +729,8 @@ class CumSum(OnnxOpConverter):
if isinstance(axis, relax.Constant):
axis = int(axis.data.numpy())
+ elif isinstance(axis, relax.Var):
+ axis = 0
data = relax.op.cumsum(data, axis)
if attr.get("reverse", 0) != 0:
data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 9c9362aaf1..b2132f3b81 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -37,6 +37,7 @@ from tvm import relay
from tvm.contrib import graph_executor, utils
from tvm.relay.frontend.common import infer_type
from tvm.relay.build_module import bind_params_by_name
+from tvm.relax.frontend.onnx import from_onnx
from relay.utils.tag_span import _create_span, _set_span,
_verify_structural_equal_with_span
import onnx
@@ -5386,6 +5387,28 @@ def test_softplus(target, dev):
verify_softplus(input_data)
+def test_load_cumsum():
+ """test_load_cumsum"""
+
+ def create_cumsum_model():
+ input_shape = [2, 3]
+
+ graph = helper.make_graph(
+ [
+ helper.make_node("CumSum", inputs=["X", "axis"],
outputs=["Y"]),
+ ],
+ "cumsum_graph",
+ inputs=[
+ helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE,
input_shape),
+ helper.make_tensor_value_info("axis", onnx.TensorProto.INT32,
[1], "axis"),
+ ],
+ outputs=[helper.make_tensor_value_info("Y",
onnx.TensorProto.DOUBLE, input_shape)],
+ )
+ return helper.make_model(graph)
+
+ from_onnx(create_cumsum_model())
+
+
@tvm.testing.parametrize_targets
def test_cumsum(target, dev):
"""test_cumsum"""