This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b8a1d63665 [Fix][Unity] Fix TVMError when using relax to load model with Trilu operator (#15924)
b8a1d63665 is described below

commit b8a1d63665625d0b904bdde200c8150ec9a19df7
Author: dmilosevic252 <[email protected]>
AuthorDate: Sun Oct 22 09:37:21 2023 +0200

    [Fix][Unity] Fix TVMError when using relax to load model with Trilu operator (#15924)
    
    * [Fix][Unity] Fix TVMError when using relax to load model with Trilu operator
    
    * Fix formatting
    
    * Add test with constant k
    
    * Fix formatting
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  7 +++++
 tests/python/frontend/onnx/test_forward.py      | 39 +++++++++++++++++++++++++
 2 files changed, 46 insertions(+)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index c4d38ae1bb..0b5aa4f7ec 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -645,6 +645,13 @@ class Trilu(OnnxOpConverter):
         x = inputs[0]
         k = inputs[1] if len(inputs) > 1 else 0
 
+        if isinstance(k, relax.Var) and k.name_hint in params:
+            k = get_constant(k, params)
+        elif isinstance(k, relax.Constant):
+            k = int(k.data.numpy()[0])
+        else:
+            k = 0
+
         if upper:
             return relax.op.triu(x, k)
         else:
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index b2132f3b81..51748462d0 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5409,6 +5409,45 @@ def test_load_cumsum():
     from_onnx(create_cumsum_model())
 
 
+def test_load_trilu():
+    """test_load_trilu"""
+
+    def create_trilu_model():
+        input_shape = [2, 3, 3]
+
+        graph = helper.make_graph(
+            [
+                helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
+            ],
+            "trilu_graph",
+            inputs=[
+                helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
+                helper.make_tensor_value_info("k", onnx.TensorProto.INT32, [1], "k"),
+            ],
+            outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
+        )
+        return helper.make_model(graph)
+
+    def create_trilu_model_const_k():
+        input_shape = [2, 3, 3]
+
+        graph = helper.make_graph(
+            [
+                make_constant_node("k", onnx.TensorProto.INT32, [1], [1]),
+                helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
+            ],
+            "trilu_graph",
+            inputs=[
+                helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
+            ],
+            outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
+        )
+        return helper.make_model(graph)
+
+    from_onnx(create_trilu_model())
+    from_onnx(create_trilu_model_const_k())
+
+
 @tvm.testing.parametrize_targets
 def test_cumsum(target, dev):
     """test_cumsum"""

Reply via email to