This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d006ecac35 [Relax] [ONNX] Add support for Sign and Not (#17167)
d006ecac35 is described below

commit d006ecac35fd3100ee547d2d0356e21245a93ed0
Author: tsu-bin <[email protected]>
AuthorDate: Thu Jul 18 21:50:14 2024 +0800

    [Relax] [ONNX] Add support for Sign and Not (#17167)
    
    Co-authored-by: tsu-bin <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 18 ++++++++++++++++++
 tests/python/relax/test_frontend_onnx.py        |  8 ++++++++
 2 files changed, 26 insertions(+)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 3a70cd090a..85d4402d66 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1948,6 +1948,22 @@ class HardSwish(OnnxOpConverter):
         )
 
 
+class Sign(OnnxOpConverter):
+    """Converts an onnx Sign node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v9(cls, bb, inputs, attr, params):
+        return relax.op.sign(inputs[0])
+
+
+class Not(OnnxOpConverter):
+    """Converts an onnx Not node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        return relax.op.logical_not(inputs[0])
+
+
 def _get_convert_map():
     return {
         "MatMul": MatMul,
@@ -2030,6 +2046,8 @@ def _get_convert_map():
         "Elu": Elu,
         "HardSigmoid": HardSigmoid,
         "HardSwish": HardSwish,
+        "Sign": Sign,
+        "Not": Not,
     }
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 0fc7ec0644..05316f2699 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -600,6 +600,14 @@ def test_hardswish():
     verify_unary("HardSwish", [32, 32])
 
 
+def test_sign():
+    verify_unary("Sign", [32, 32])
+
+
+def test_not():
+    verify_unary("Not", [32, 32], dtype=TensorProto.BOOL)
+
+
 def test_conv():
     def _verify_conv(input_shape, weight_shape, output_shape):
         bias_shape = [output_shape[1]]

Reply via email to