This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b8a1d63665 [Fix][Unity] Fix TVMError when using relax to load model with Trilu operator (#15924)
b8a1d63665 is described below
commit b8a1d63665625d0b904bdde200c8150ec9a19df7
Author: dmilosevic252 <[email protected]>
AuthorDate: Sun Oct 22 09:37:21 2023 +0200
[Fix][Unity] Fix TVMError when using relax to load model with Trilu operator (#15924)
* [Fix][Unity] Fix TVMError when using relax to load model with Trilu operator
* Fix formatting
* Add test with constant k
* Fix formatting
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 7 +++++
tests/python/frontend/onnx/test_forward.py | 39 +++++++++++++++++++++++++
2 files changed, 46 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index c4d38ae1bb..0b5aa4f7ec 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -645,6 +645,13 @@ class Trilu(OnnxOpConverter):
x = inputs[0]
k = inputs[1] if len(inputs) > 1 else 0
+ if isinstance(k, relax.Var) and k.name_hint in params:
+ k = get_constant(k, params)
+ elif isinstance(k, relax.Constant):
+ k = int(k.data.numpy()[0])
+ else:
+ k = 0
+
if upper:
return relax.op.triu(x, k)
else:
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index b2132f3b81..51748462d0 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5409,6 +5409,45 @@ def test_load_cumsum():
from_onnx(create_cumsum_model())
+def test_load_trilu():
+ """test_load_trilu"""
+
+ def create_trilu_model():
+ input_shape = [2, 3, 3]
+
+ graph = helper.make_graph(
+ [
+ helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
+ ],
+ "trilu_graph",
+ inputs=[
+ helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE,
input_shape),
+ helper.make_tensor_value_info("k", onnx.TensorProto.INT32,
[1], "k"),
+ ],
+ outputs=[helper.make_tensor_value_info("y",
onnx.TensorProto.DOUBLE, input_shape)],
+ )
+ return helper.make_model(graph)
+
+ def create_trilu_model_const_k():
+ input_shape = [2, 3, 3]
+
+ graph = helper.make_graph(
+ [
+ make_constant_node("k", onnx.TensorProto.INT32, [1], [1]),
+ helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
+ ],
+ "trilu_graph",
+ inputs=[
+ helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE,
input_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("y",
onnx.TensorProto.DOUBLE, input_shape)],
+ )
+ return helper.make_model(graph)
+
+ from_onnx(create_trilu_model())
+ from_onnx(create_trilu_model_const_k())
+
+
@tvm.testing.parametrize_targets
def test_cumsum(target, dev):
"""test_cumsum"""