This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ba80646639 [ONNX] Move relax related tests to the correct file (#17447)
ba80646639 is described below
commit ba80646639d863a07e360dc377d592d1469efb73
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Oct 7 21:38:44 2024 +0800
[ONNX] Move relax related tests to the correct file (#17447)
There are a few Relax tests in `tests/python/frontend/onnx/test_forward.py`,
which is used for the Relay frontend. This commit moves them to the correct
file.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 10 ++--
tests/python/frontend/onnx/test_forward.py | 62 -------------------------
tests/python/relax/test_frontend_onnx.py | 43 +++++++++++++++++
3 files changed, 49 insertions(+), 66 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5777f51fe2..36a7823f86 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -740,10 +740,12 @@ class Trilu(OnnxOpConverter):
x = inputs[0]
k = inputs[1] if len(inputs) > 1 else 0
- if isinstance(k, relax.Var) and k.name_hint in params:
- k = get_constant(k, params)
- elif isinstance(k, relax.Constant):
- k = int(k.data.numpy()[0])
+ if len(inputs) > 1:
+ k = get_constant(inputs[1], params)
+ if isinstance(k, relax.Constant):
+ k = int(k.data.numpy()[0])
+ else:
+            raise ValueError("Currently only support constant k for Trilu op.")
else:
k = 0
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a5811d0dbd..a81352bb67 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -37,7 +37,6 @@ from tvm import relay
from tvm.contrib import graph_executor, utils
from tvm.relay.frontend.common import infer_type
from tvm.relay.build_module import bind_params_by_name
-from tvm.relax.frontend.onnx import from_onnx
from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
import onnx
@@ -5441,67 +5440,6 @@ def test_softplus(target, dev):
verify_softplus(input_data)
-def test_load_cumsum():
- """test_load_cumsum"""
-
- def create_cumsum_model():
- input_shape = [2, 3]
-
- graph = helper.make_graph(
- [
-            helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
- ],
- "cumsum_graph",
- inputs=[
-            helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape),
-            helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"),
- ],
-        outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)],
- )
- return helper.make_model(graph)
-
- from_onnx(create_cumsum_model())
-
-
-def test_load_trilu():
- """test_load_trilu"""
-
- def create_trilu_model():
- input_shape = [2, 3, 3]
-
- graph = helper.make_graph(
- [
- helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
- ],
- "trilu_graph",
- inputs=[
-            helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
-            helper.make_tensor_value_info("k", onnx.TensorProto.INT32, [1], "k"),
- ],
-        outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
- )
- return helper.make_model(graph)
-
- def create_trilu_model_const_k():
- input_shape = [2, 3, 3]
-
- graph = helper.make_graph(
- [
- make_constant_node("k", onnx.TensorProto.INT32, [1], [1]),
- helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
- ],
- "trilu_graph",
- inputs=[
-            helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
- ],
-        outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
- )
- return helper.make_model(graph)
-
- from_onnx(create_trilu_model())
- from_onnx(create_trilu_model_const_k())
-
-
@tvm.testing.parametrize_targets
def test_cumsum(target, dev):
"""test_cumsum"""
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 2837ad2185..f2bbd3f3f5 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -710,6 +710,28 @@ def test_trilu(upper: bool):
verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper})
[email protected]("k_value", [-1, 0, 1])
+def test_trilu_with_const_k(k_value: int):
+ """test_trilu_with_const_k"""
+
+ input_shape = [2, 3, 3]
+
+ graph = helper.make_graph(
+ [
+ make_constant_node("k", onnx.TensorProto.INT64, [1], [k_value]),
+ helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
+ ],
+ "trilu_graph",
+ inputs=[
+            helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
+ ],
+        outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="trilu_graph")
+ check_correctness(model)
+
+
def test_selu():
verify_unary("Selu", [3, 32, 32])
verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3})
@@ -859,6 +881,27 @@ def test_cumsum(reverse, exclusive):
check_correctness(model)
+def test_cumsum1():
+ """test_cumsum1"""
+
+ input_shape = [2, 3]
+
+ graph = helper.make_graph(
+ [
+ helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+ ],
+ "cumsum_graph",
+ inputs=[
+            helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape),
+            helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"),
+ ],
+        outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="cumsum_graph")
+ check_correctness(model)
+
+
@pytest.mark.parametrize("axis", [[0, 2], None])
def test_squeeze(axis):
if axis: