This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c1aab7d185 [Unity][Relax] Add bitwise and logical ops (AND, NOT, OR, XOR) (#15075)
c1aab7d185 is described below

commit c1aab7d185217a612795709033bc5b7b4dadc007
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jun 13 04:43:12 2023 +0400

    [Unity][Relax] Add bitwise and logical ops (AND, NOT, OR, XOR) (#15075)
    
    * add bitwise and logical binary ops (and, or, xor) to relax
    
    * add bitwise and logical NOT to unary relax ops
    
    * legalize bitwise and logical NOT
    
    * extend the Relax IR builder with the new ops
    
    * add bitwise and logical ops to headers
    
    * add bitwise and logical not to native code
    
    * add test
    
    * fix lint
    
    ---------
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 python/tvm/relax/op/binary.py                     | 102 ++++++++++++++++++++++
 python/tvm/relax/op/unary.py                      |  32 +++++++
 python/tvm/relax/transform/legalize_ops/binary.py |  10 +++
 python/tvm/relax/transform/legalize_ops/unary.py  |   2 +
 python/tvm/script/ir_builder/relax/ir.py          |  16 ++++
 src/relax/op/tensor/binary.cc                     |  12 +++
 src/relax/op/tensor/binary.h                      |  22 +++++
 src/relax/op/tensor/unary.cc                      |   2 +
 src/relax/op/tensor/unary.h                       |   6 ++
 tests/python/relax/test_op_binary.py              |  12 +++
 tests/python/relax/test_op_unary.py               |   6 ++
 11 files changed, 222 insertions(+)

diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index e3664fd8d8..982b3a24f2 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -284,3 +284,105 @@ def minimum(x1: Expr, x2: Expr) -> Expr:
         The computed result.
     """
     return _ffi_api.minimum(x1, x2)
+
+
+###################### Logical operators ######################
+
+
+def logical_and(x1: Expr, x2: Expr) -> Expr:
+    """Logical AND
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.logical_and(x1, x2)
+
+
+def logical_or(x1: Expr, x2: Expr) -> Expr:
+    """Logical OR
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.logical_or(x1, x2)
+
+
+def logical_xor(x1: Expr, x2: Expr) -> Expr:
+    """Logical XOR
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.logical_xor(x1, x2)
+
+
+###################### Bitwise operators ######################
+
+
+def bitwise_and(x1: Expr, x2: Expr) -> Expr:
+    """Bitwise AND
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.bitwise_and(x1, x2)
+
+
+def bitwise_or(x1: Expr, x2: Expr) -> Expr:
+    """Bitwise OR
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.bitwise_or(x1, x2)
+
+
+def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
+    """Bitwise XOR
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+    x2 : relax.Expr
+        The second input tensor.
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.bitwise_xor(x1, x2)
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index f885c103e7..78051452e2 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -159,6 +159,22 @@ def atanh(x: Expr) -> Expr:
     return _ffi_api.atanh(x)  # type: ignore
 
 
+def bitwise_not(x: Expr) -> Expr:
+    """Compute bitwise NOT of the input data.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.bitwise_not(x)  # type: ignore
+
+
 def ceil(x: Expr) -> Expr:
     """Take ceil of input data.
 
@@ -271,6 +287,22 @@ def log(x: Expr) -> Expr:
     return _ffi_api.log(x)  # type: ignore
 
 
+def logical_not(x: Expr) -> Expr:
+    """Compute logical NOT of the input data.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.logical_not(x)  # type: ignore
+
+
 def negative(x: Expr) -> Expr:
     """Compute element-wise negative of the input data.
 
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py
index 897b676518..16d6c02696 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -57,3 +57,13 @@ register_legalize("relax.not_equal", _binary(topi.not_equal))
 
 register_legalize("relax.maximum", _binary(topi.maximum))
 register_legalize("relax.minimum", _binary(topi.minimum))
+
+# bitwise
+register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
+register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
+register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
+
+# logical
+register_legalize("relax.logical_and", _binary(topi.logical_and))
+register_legalize("relax.logical_or", _binary(topi.logical_or))
+register_legalize("relax.logical_xor", _binary(topi.logical_xor))
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py
index 104a874679..f948f18dd3 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -27,12 +27,14 @@ register_legalize("relax.asin", 
_call_topi_without_attr(topi.asin, "tir_asin"))
 register_legalize("relax.asinh", _call_topi_without_attr(topi.asinh, 
"tir_asinh"))
 register_legalize("relax.atan", _call_topi_without_attr(topi.atan, "tir_atan"))
 register_legalize("relax.atanh", _call_topi_without_attr(topi.atanh, 
"tir_atanh"))
+register_legalize("relax.bitwise_not", 
_call_topi_without_attr(topi.bitwise_not, "tir_bitwise_not"))
 register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil"))
 register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos"))
 register_legalize("relax.cosh", _call_topi_without_attr(topi.cosh, "tir_cosh"))
 register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp"))
 register_legalize("relax.floor", _call_topi_without_attr(topi.floor, 
"tir_floor"))
 register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
+register_legalize("relax.logical_not", 
_call_topi_without_attr(topi.logical_not, "tir_logical_not"))
 register_legalize("relax.negative", _call_topi_without_attr(topi.negative, 
"tir_negative"))
 register_legalize("relax.round", _call_topi_without_attr(topi.round, 
"tir_round"))
 register_legalize("relax.rsqrt", _call_topi_without_attr(topi.rsqrt, 
"tir_rsqrt"))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 7a1ecca4d8..73606a8924 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -45,6 +45,10 @@ from tvm.relax.op import (
     argmin,
     assert_op,
     astype,
+    bitwise_and,
+    bitwise_not,
+    bitwise_or,
+    bitwise_xor,
     broadcast_to,
     builtin,
     call_builtin_with_ctx,
@@ -86,6 +90,10 @@ from tvm.relax.op import (
     less_equal,
     linear,
     log,
+    logical_and,
+    logical_not,
+    logical_or,
+    logical_xor,
     make_closure,
     matmul,
     max,
@@ -576,6 +584,10 @@ __all__ = [
     "argmin",
     "assert_op",
     "astype",
+    "bitwise_and",
+    "bitwise_not",
+    "bitwise_or",
+    "bitwise_xor",
     "broadcast_to",
     "builtin",
     "call_packed",
@@ -632,6 +644,10 @@ __all__ = [
     "less_equal",
     "linear",
     "log",
+    "logical_and",
+    "logical_not",
+    "logical_or",
+    "logical_xor",
     "make_closure",
     "matmul",
     "max",
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index d44b6da629..6483806182 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -132,5 +132,17 @@ RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(minimum);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(maximum);
 
+/***************** Logical operators *****************/
+
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_and);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_or);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
+
+/***************** Bitwise operators *****************/
+
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 06f3944d85..b28a6c3369 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -107,6 +107,28 @@ Expr minimum(Expr x1, Expr x2);
 /*! \brief Element-wise maximum */
 Expr maximum(Expr x1, Expr x2);
 
+/***************** Logical operators *****************/
+
+/*! \brief Broadcasted element-wise logical and */
+Expr logical_and(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise logical or */
+Expr logical_or(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise logical xor */
+Expr logical_xor(Expr x1, Expr x2);
+
+/***************** Bitwise operators *****************/
+
+/*! \brief Broadcasted element-wise bitwise and */
+Expr bitwise_and(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise bitwise or */
+Expr bitwise_or(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise bitwise xor */
+Expr bitwise_xor(Expr x1, Expr x2);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
index 6713c4e31a..6eef44821d 100644
--- a/src/relax/op/tensor/unary.cc
+++ b/src/relax/op/tensor/unary.cc
@@ -43,12 +43,14 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asin, 
/*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asinh, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atan, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(bitwise_not, 
/*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(ceil, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cos, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cosh, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(exp, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(floor, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(log, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(logical_not, 
/*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(negative, 
/*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(round, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(rsqrt, /*require_float_dtype=*/true);
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
index 60dd5148e2..5f92ed0b0c 100644
--- a/src/relax/op/tensor/unary.h
+++ b/src/relax/op/tensor/unary.h
@@ -76,6 +76,9 @@ Expr atan(Expr x);
 /*! \brief Compute element-wise arc tanh of the input data. */
 Expr atanh(Expr x);
 
+/*! \brief Compute element-wise bitwise not */
+Expr bitwise_not(Expr x);
+
 /*! \brief Take ceil of input data. */
 Expr ceil(Expr x);
 
@@ -94,6 +97,9 @@ Expr floor(Expr x);
 /*! \brief Compute element-wise natural logarithm of data. */
 Expr log(Expr x);
 
+/*! \brief Compute element-wise logical not */
+Expr logical_not(Expr x);
+
 /*! \brief Compute element-wise negative value of data. */
 Expr negative(Expr x);
 
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index 809fe7e98f..ce9e5d507e 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -41,6 +41,18 @@ def test_op_correctness():
     assert relax.op.less_equal(x, y).op == Op.get("relax.less_equal")
     assert relax.op.not_equal(x, y).op == Op.get("relax.not_equal")
 
+    x = relax.Var("x", R.Tensor((2, 3), "int32"))
+    y = relax.Var("y", R.Tensor((2, 3), "int32"))
+    assert relax.op.bitwise_and(x, y).op == Op.get("relax.bitwise_and")
+    assert relax.op.bitwise_or(x, y).op == Op.get("relax.bitwise_or")
+    assert relax.op.bitwise_xor(x, y).op == Op.get("relax.bitwise_xor")
+
+    x = relax.Var("x", R.Tensor((2, 3), "bool"))
+    y = relax.Var("y", R.Tensor((2, 3), "bool"))
+    assert relax.op.logical_and(x, y).op == Op.get("relax.logical_and")
+    assert relax.op.logical_or(x, y).op == Op.get("relax.logical_or")
+    assert relax.op.logical_xor(x, y).op == Op.get("relax.logical_xor")
+
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
     ret = bb.normalize(call)
diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py
index c8751dbaca..9bfb8612ef 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -54,6 +54,12 @@ def test_op_correctness():
     assert relax.op.tanh(x).op == Op.get("relax.tanh")
     assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip")
 
+    x = relax.Var("x", R.Tensor((2, 3), "int32"))
+    assert relax.op.bitwise_not(x).op == Op.get("relax.bitwise_not")
+
+    x = relax.Var("x", R.Tensor((2, 3), "bool"))
+    assert relax.op.logical_not(x).op == Op.get("relax.logical_not")
+
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
     ret = bb.normalize(call)

Reply via email to