This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ba80646639 [ONNX] Move relax related tests to the correct file (#17447)
ba80646639 is described below

commit ba80646639d863a07e360dc377d592d1469efb73
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Oct 7 21:38:44 2024 +0800

    [ONNX] Move relax related tests to the correct file (#17447)
    
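    There are a few Relax tests in `tests/python/frontend/onnx/test_forward.py`,
    which is used for the Relay frontend. This commit moves them to the correct
    file, `tests/python/relax/test_frontend_onnx.py`. It also updates the Relax
    ONNX Trilu converter to resolve `k` via `get_constant` and to raise a
    ValueError when `k` is not a constant.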
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 10 ++--
 tests/python/frontend/onnx/test_forward.py      | 62 -------------------------
 tests/python/relax/test_frontend_onnx.py        | 43 +++++++++++++++++
 3 files changed, 49 insertions(+), 66 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5777f51fe2..36a7823f86 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -740,10 +740,12 @@ class Trilu(OnnxOpConverter):
         x = inputs[0]
         k = inputs[1] if len(inputs) > 1 else 0
 
-        if isinstance(k, relax.Var) and k.name_hint in params:
-            k = get_constant(k, params)
-        elif isinstance(k, relax.Constant):
-            k = int(k.data.numpy()[0])
+        if len(inputs) > 1:
+            k = get_constant(inputs[1], params)
+            if isinstance(k, relax.Constant):
+                k = int(k.data.numpy()[0])
+            else:
+                raise ValueError("Currently only support constant k for Trilu 
op.")
         else:
             k = 0
 
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a5811d0dbd..a81352bb67 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -37,7 +37,6 @@ from tvm import relay
 from tvm.contrib import graph_executor, utils
 from tvm.relay.frontend.common import infer_type
 from tvm.relay.build_module import bind_params_by_name
-from tvm.relax.frontend.onnx import from_onnx
 from relay.utils.tag_span import _create_span, _set_span, 
_verify_structural_equal_with_span
 
 import onnx
@@ -5441,67 +5440,6 @@ def test_softplus(target, dev):
     verify_softplus(input_data)
 
 
-def test_load_cumsum():
-    """test_load_cumsum"""
-
-    def create_cumsum_model():
-        input_shape = [2, 3]
-
-        graph = helper.make_graph(
-            [
-                helper.make_node("CumSum", inputs=["X", "axis"], 
outputs=["Y"]),
-            ],
-            "cumsum_graph",
-            inputs=[
-                helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, 
input_shape),
-                helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, 
[1], "axis"),
-            ],
-            outputs=[helper.make_tensor_value_info("Y", 
onnx.TensorProto.DOUBLE, input_shape)],
-        )
-        return helper.make_model(graph)
-
-    from_onnx(create_cumsum_model())
-
-
-def test_load_trilu():
-    """test_load_trilu"""
-
-    def create_trilu_model():
-        input_shape = [2, 3, 3]
-
-        graph = helper.make_graph(
-            [
-                helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
-            ],
-            "trilu_graph",
-            inputs=[
-                helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, 
input_shape),
-                helper.make_tensor_value_info("k", onnx.TensorProto.INT32, 
[1], "k"),
-            ],
-            outputs=[helper.make_tensor_value_info("y", 
onnx.TensorProto.DOUBLE, input_shape)],
-        )
-        return helper.make_model(graph)
-
-    def create_trilu_model_const_k():
-        input_shape = [2, 3, 3]
-
-        graph = helper.make_graph(
-            [
-                make_constant_node("k", onnx.TensorProto.INT32, [1], [1]),
-                helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
-            ],
-            "trilu_graph",
-            inputs=[
-                helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, 
input_shape),
-            ],
-            outputs=[helper.make_tensor_value_info("y", 
onnx.TensorProto.DOUBLE, input_shape)],
-        )
-        return helper.make_model(graph)
-
-    from_onnx(create_trilu_model())
-    from_onnx(create_trilu_model_const_k())
-
-
 @tvm.testing.parametrize_targets
 def test_cumsum(target, dev):
     """test_cumsum"""
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 2837ad2185..f2bbd3f3f5 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -710,6 +710,28 @@ def test_trilu(upper: bool):
     verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper})
 
 
+@pytest.mark.parametrize("k_value", [-1, 0, 1])
+def test_trilu_with_const_k(k_value: int):
+    """test_trilu_with_const_k"""
+
+    input_shape = [2, 3, 3]
+
+    graph = helper.make_graph(
+        [
+            make_constant_node("k", onnx.TensorProto.INT64, [1], [k_value]),
+            helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
+        ],
+        "trilu_graph",
+        inputs=[
+            helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, 
input_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, 
input_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="trilu_graph")
+    check_correctness(model)
+
+
 def test_selu():
     verify_unary("Selu", [3, 32, 32])
     verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3})
@@ -859,6 +881,27 @@ def test_cumsum(reverse, exclusive):
     check_correctness(model)
 
 
+def test_cumsum1():
+    """test_cumsum1"""
+
+    input_shape = [2, 3]
+
+    graph = helper.make_graph(
+        [
+            helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+        ],
+        "cumsum_graph",
+        inputs=[
+            helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, 
input_shape),
+            helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], 
"axis"),
+        ],
+        outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, 
input_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="cumsum_graph")
+    check_correctness(model)
+
+
 @pytest.mark.parametrize("axis", [[0, 2], None])
 def test_squeeze(axis):
     if axis:

Reply via email to