This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c1aab7d185 [Unity][Relax] Add bitwise and logical ops (AND, NOT, OR,
XOR) (#15075)
c1aab7d185 is described below
commit c1aab7d185217a612795709033bc5b7b4dadc007
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jun 13 04:43:12 2023 +0400
[Unity][Relax] Add bitwise and logical ops (AND, NOT, OR, XOR) (#15075)
* add bitwise and logical binary ops (and, or, xor) to relax
* add bitwise and logical NOT to unary relax ops
* legalize bitwise and logical NOT
* extend IR with new ops
* add bitwise and logical ops to headers
* add bitwise and logical not to native code
* add test
* fix lint
---------
Co-authored-by: Valery Chernov <[email protected]>
---
python/tvm/relax/op/binary.py | 102 ++++++++++++++++++++++
python/tvm/relax/op/unary.py | 32 +++++++
python/tvm/relax/transform/legalize_ops/binary.py | 10 +++
python/tvm/relax/transform/legalize_ops/unary.py | 2 +
python/tvm/script/ir_builder/relax/ir.py | 16 ++++
src/relax/op/tensor/binary.cc | 12 +++
src/relax/op/tensor/binary.h | 22 +++++
src/relax/op/tensor/unary.cc | 2 +
src/relax/op/tensor/unary.h | 6 ++
tests/python/relax/test_op_binary.py | 12 +++
tests/python/relax/test_op_unary.py | 6 ++
11 files changed, 222 insertions(+)
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index e3664fd8d8..982b3a24f2 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -284,3 +284,105 @@ def minimum(x1: Expr, x2: Expr) -> Expr:
The computed result.
"""
return _ffi_api.minimum(x1, x2)
+
+
+###################### Logical operators ######################
+
+
+def logical_and(x1: Expr, x2: Expr) -> Expr:
+ """Logical AND
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.logical_and(x1, x2)
+
+
+def logical_or(x1: Expr, x2: Expr) -> Expr:
+ """Logical OR
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.logical_or(x1, x2)
+
+
+def logical_xor(x1: Expr, x2: Expr) -> Expr:
+ """Logical XOR
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.logical_xor(x1, x2)
+
+
+###################### Bitwise operators ######################
+
+
+def bitwise_and(x1: Expr, x2: Expr) -> Expr:
+ """Bitwise AND
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.bitwise_and(x1, x2)
+
+
+def bitwise_or(x1: Expr, x2: Expr) -> Expr:
+ """Bitwise OR
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.bitwise_or(x1, x2)
+
+
+def bitwise_xor(x1: Expr, x2: Expr) -> Expr:
+ """Bitwise XOR
+ Parameters
+ ----------
+ x1 : relax.Expr
+ The first input tensor.
+ x2 : relax.Expr
+ The second input tensor.
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.bitwise_xor(x1, x2)
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index f885c103e7..78051452e2 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -159,6 +159,22 @@ def atanh(x: Expr) -> Expr:
return _ffi_api.atanh(x) # type: ignore
+def bitwise_not(x: Expr) -> Expr:
+ """Compute bitwise NOT of the input data.
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The input data
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.bitwise_not(x) # type: ignore
+
+
def ceil(x: Expr) -> Expr:
"""Take ceil of input data.
@@ -271,6 +287,22 @@ def log(x: Expr) -> Expr:
return _ffi_api.log(x) # type: ignore
+def logical_not(x: Expr) -> Expr:
+ """Compute logical NOT of the input data.
+
+ Parameters
+ ----------
+ x : relax.Expr
+ The input data
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ return _ffi_api.logical_not(x) # type: ignore
+
+
def negative(x: Expr) -> Expr:
"""Compute element-wise negative of the input data.
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py
b/python/tvm/relax/transform/legalize_ops/binary.py
index 897b676518..16d6c02696 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -57,3 +57,13 @@ register_legalize("relax.not_equal", _binary(topi.not_equal))
register_legalize("relax.maximum", _binary(topi.maximum))
register_legalize("relax.minimum", _binary(topi.minimum))
+
+# bitwise
+register_legalize("relax.bitwise_and", _binary(topi.bitwise_and))
+register_legalize("relax.bitwise_or", _binary(topi.bitwise_or))
+register_legalize("relax.bitwise_xor", _binary(topi.bitwise_xor))
+
+# logical
+register_legalize("relax.logical_and", _binary(topi.logical_and))
+register_legalize("relax.logical_or", _binary(topi.logical_or))
+register_legalize("relax.logical_xor", _binary(topi.logical_xor))
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py
b/python/tvm/relax/transform/legalize_ops/unary.py
index 104a874679..f948f18dd3 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -27,12 +27,14 @@ register_legalize("relax.asin",
_call_topi_without_attr(topi.asin, "tir_asin"))
register_legalize("relax.asinh", _call_topi_without_attr(topi.asinh,
"tir_asinh"))
register_legalize("relax.atan", _call_topi_without_attr(topi.atan, "tir_atan"))
register_legalize("relax.atanh", _call_topi_without_attr(topi.atanh,
"tir_atanh"))
+register_legalize("relax.bitwise_not",
_call_topi_without_attr(topi.bitwise_not, "tir_bitwise_not"))
register_legalize("relax.ceil", _call_topi_without_attr(topi.ceil, "tir_ceil"))
register_legalize("relax.cos", _call_topi_without_attr(topi.cos, "tir_cos"))
register_legalize("relax.cosh", _call_topi_without_attr(topi.cosh, "tir_cosh"))
register_legalize("relax.exp", _call_topi_without_attr(topi.exp, "tir_exp"))
register_legalize("relax.floor", _call_topi_without_attr(topi.floor,
"tir_floor"))
register_legalize("relax.log", _call_topi_without_attr(topi.log, "tir_log"))
+register_legalize("relax.logical_not",
_call_topi_without_attr(topi.logical_not, "tir_logical_not"))
register_legalize("relax.negative", _call_topi_without_attr(topi.negative,
"tir_negative"))
register_legalize("relax.round", _call_topi_without_attr(topi.round,
"tir_round"))
register_legalize("relax.rsqrt", _call_topi_without_attr(topi.rsqrt,
"tir_rsqrt"))
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 7a1ecca4d8..73606a8924 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -45,6 +45,10 @@ from tvm.relax.op import (
argmin,
assert_op,
astype,
+ bitwise_and,
+ bitwise_not,
+ bitwise_or,
+ bitwise_xor,
broadcast_to,
builtin,
call_builtin_with_ctx,
@@ -86,6 +90,10 @@ from tvm.relax.op import (
less_equal,
linear,
log,
+ logical_and,
+ logical_not,
+ logical_or,
+ logical_xor,
make_closure,
matmul,
max,
@@ -576,6 +584,10 @@ __all__ = [
"argmin",
"assert_op",
"astype",
+ "bitwise_and",
+ "bitwise_not",
+ "bitwise_or",
+ "bitwise_xor",
"broadcast_to",
"builtin",
"call_packed",
@@ -632,6 +644,10 @@ __all__ = [
"less_equal",
"linear",
"log",
+ "logical_and",
+ "logical_not",
+ "logical_or",
+ "logical_xor",
"make_closure",
"matmul",
"max",
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index d44b6da629..6483806182 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -132,5 +132,17 @@ RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(minimum);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(maximum);
+/***************** Logical operators *****************/
+
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_and);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_or);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(logical_xor);
+
+/***************** Bitwise operators *****************/
+
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_and);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_or);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(bitwise_xor);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index 06f3944d85..b28a6c3369 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -107,6 +107,28 @@ Expr minimum(Expr x1, Expr x2);
/*! \brief Element-wise maximum */
Expr maximum(Expr x1, Expr x2);
+/***************** Logical operators *****************/
+
+/*! \brief Broadcasted element-wise logical and */
+Expr logical_and(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise logical or */
+Expr logical_or(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise logical xor */
+Expr logical_xor(Expr x1, Expr x2);
+
+/***************** Bitwise operators *****************/
+
+/*! \brief Broadcasted element-wise bitwise and */
+Expr bitwise_and(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise bitwise or */
+Expr bitwise_or(Expr x1, Expr x2);
+
+/*! \brief Broadcasted element-wise bitwise xor */
+Expr bitwise_xor(Expr x1, Expr x2);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
index 6713c4e31a..6eef44821d 100644
--- a/src/relax/op/tensor/unary.cc
+++ b/src/relax/op/tensor/unary.cc
@@ -43,12 +43,14 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asin,
/*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asinh, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atan, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(bitwise_not,
/*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(ceil, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cos, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cosh, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(exp, /*require_float_dtype=*/true);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(floor, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(log, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(logical_not,
/*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(negative,
/*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(round, /*require_float_dtype=*/false);
RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(rsqrt, /*require_float_dtype=*/true);
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
index 60dd5148e2..5f92ed0b0c 100644
--- a/src/relax/op/tensor/unary.h
+++ b/src/relax/op/tensor/unary.h
@@ -76,6 +76,9 @@ Expr atan(Expr x);
/*! \brief Compute element-wise arc tanh of the input data. */
Expr atanh(Expr x);
+/*! \brief Compute element-wise bitwise not */
+Expr bitwise_not(Expr x);
+
/*! \brief Take ceil of input data. */
Expr ceil(Expr x);
@@ -94,6 +97,9 @@ Expr floor(Expr x);
/*! \brief Compute element-wise natural logarithm of data. */
Expr log(Expr x);
+/*! \brief Compute element-wise logical not */
+Expr logical_not(Expr x);
+
/*! \brief Compute element-wise negative value of data. */
Expr negative(Expr x);
diff --git a/tests/python/relax/test_op_binary.py
b/tests/python/relax/test_op_binary.py
index 809fe7e98f..ce9e5d507e 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -41,6 +41,18 @@ def test_op_correctness():
assert relax.op.less_equal(x, y).op == Op.get("relax.less_equal")
assert relax.op.not_equal(x, y).op == Op.get("relax.not_equal")
+ x = relax.Var("x", R.Tensor((2, 3), "int32"))
+ y = relax.Var("y", R.Tensor((2, 3), "int32"))
+ assert relax.op.bitwise_and(x, y).op == Op.get("relax.bitwise_and")
+ assert relax.op.bitwise_or(x, y).op == Op.get("relax.bitwise_or")
+ assert relax.op.bitwise_xor(x, y).op == Op.get("relax.bitwise_xor")
+
+ x = relax.Var("x", R.Tensor((2, 3), "bool"))
+ y = relax.Var("y", R.Tensor((2, 3), "bool"))
+ assert relax.op.logical_and(x, y).op == Op.get("relax.logical_and")
+ assert relax.op.logical_or(x, y).op == Op.get("relax.logical_or")
+ assert relax.op.logical_xor(x, y).op == Op.get("relax.logical_xor")
+
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
ret = bb.normalize(call)
diff --git a/tests/python/relax/test_op_unary.py
b/tests/python/relax/test_op_unary.py
index c8751dbaca..9bfb8612ef 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -54,6 +54,12 @@ def test_op_correctness():
assert relax.op.tanh(x).op == Op.get("relax.tanh")
assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip")
+ x = relax.Var("x", R.Tensor((2, 3), "int32"))
+ assert relax.op.bitwise_not(x).op == Op.get("relax.bitwise_not")
+
+ x = relax.Var("x", R.Tensor((2, 3), "bool"))
+ assert relax.op.logical_not(x).op == Op.get("relax.logical_not")
+
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
ret = bb.normalize(call)