This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d006ecac35 [Relax] [ONNX] Add support for Sign and Not (#17167)
d006ecac35 is described below
commit d006ecac35fd3100ee547d2d0356e21245a93ed0
Author: tsu-bin <[email protected]>
AuthorDate: Thu Jul 18 21:50:14 2024 +0800
[Relax] [ONNX] Add support for Sign and Not (#17167)
Co-authored-by: tsu-bin <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 18 ++++++++++++++++++
tests/python/relax/test_frontend_onnx.py | 8 ++++++++
2 files changed, 26 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 3a70cd090a..85d4402d66 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1948,6 +1948,22 @@ class HardSwish(OnnxOpConverter):
)
+class Sign(OnnxOpConverter):
+    """Converts an onnx Sign node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        # Only the node's first input is used; bb, attr and params are
+        # accepted for the converter interface but not read here.
+        return relax.op.sign(inputs[0])
+
+
+class Not(OnnxOpConverter):
+    """Converts an onnx Not node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        # ONNX Not is lowered to Relax's logical_not on the first input;
+        # bb, attr and params are part of the converter interface but unused.
+        return relax.op.logical_not(inputs[0])
+
+
def _get_convert_map():
return {
"MatMul": MatMul,
@@ -2030,6 +2046,8 @@ def _get_convert_map():
"Elu": Elu,
"HardSigmoid": HardSigmoid,
"HardSwish": HardSwish,
+ "Sign": Sign,
+ "Not": Not,
}
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 0fc7ec0644..05316f2699 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -600,6 +600,14 @@ def test_hardswish():
verify_unary("HardSwish", [32, 32])
+def test_sign():
+    """Check the ONNX Sign converter on a 32x32 input."""
+    shape = [32, 32]
+    verify_unary("Sign", shape)
+
+
+def test_not():
+    """Check the ONNX Not converter on a 32x32 boolean input."""
+    shape = [32, 32]
+    verify_unary("Not", shape, dtype=TensorProto.BOOL)
+
+
def test_conv():
def _verify_conv(input_shape, weight_shape, output_shape):
bias_shape = [output_shape[1]]