This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 583f5ae37c [ONNX] Support Bitwise operations (#13888)
583f5ae37c is described below

commit 583f5ae37c3b34ac8e1ce36fccf22c1e7eb89ab1
Author: Valery Chernov <[email protected]>
AuthorDate: Fri Feb 3 14:30:55 2023 +0400

    [ONNX] Support Bitwise operations (#13888)
    
    * add a base class for bitwise operations and implement BitwiseAnd,
      BitwiseNot, BitwiseOr and BitwiseXor on top of it
    
    * add tests for the BitwiseAnd, BitwiseNot, BitwiseOr and BitwiseXor
      operations to the ONNX front-end
    
    * add test of BitShift for ONNX front-end
    
    * fix dtype for test
    
    * skip test due to old version of ORT
    
    ---------
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 python/tvm/relay/frontend/onnx.py          |  83 ++++++++++++++--
 tests/python/frontend/onnx/test_forward.py | 150 +++++++++++++++++++++++++++++
 2 files changed, 225 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 8b4a0cc5e8..8de5e0e08b 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -5578,13 +5578,31 @@ class ConvInteger(OnnxOpConverter):
         )
 
 
-class BitShift(OnnxOpConverter):
-    """Operator converter for NonZero"""
+class BitwiseBase(OnnxOpConverter):
+    """Base converter for ONNX bitwise ops; check_inputs validates input arity and dtypes."""
+
+    name = ""
+
+    @classmethod
+    def check_inputs(cls, inputs, num=2, use_int=True):
+        if len(inputs) != num:
+            raise ValueError("{} takes {} inputs, {} given".format(cls.name, num, len(inputs)))
+        valid_types = ["uint8", "uint16", "uint32", "uint64"]
+        if use_int:
+            valid_types += ["int8", "int16", "int32", "int64"]
+        for i, in_dtype in enumerate(infer_type(inp).checked_type.dtype for inp in inputs):
+            if in_dtype not in valid_types:
+                raise TypeError("Wrong dtype of the {}-th input: {}".format(i, in_dtype))
+
+
+class BitShift(BitwiseBase):
+    """Operator converter for BitShift"""
+
+    name = "BitShift"
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-        if len(inputs) != 2:
-            raise ValueError("Bitshift expects 2 inputs")
+        cls.check_inputs(inputs, use_int=False)
 
         direction = attr.get("direction", "LEFT").decode("ascii")
         if direction == "LEFT":
@@ -5596,6 +5614,54 @@ class BitShift(OnnxOpConverter):
         return out
 
 
+class BitwiseAnd(BitwiseBase):
+    """Convert ONNX BitwiseAnd (opset 18) to relay bitwise_and on two integer inputs"""
+
+    name = "BitwiseAnd"
+
+    @classmethod
+    def _impl_v18(cls, inputs, attr, params):
+        cls.check_inputs(inputs, num=2)
+        # Both operands are forwarded to relay unchanged.
+        return _op.bitwise_and(*inputs)
+
+
+class BitwiseNot(BitwiseBase):
+    """Convert ONNX BitwiseNot (opset 18) to relay bitwise_not on one integer input"""
+
+    name = "BitwiseNot"
+
+    @classmethod
+    def _impl_v18(cls, inputs, attr, params):
+        cls.check_inputs(inputs, num=1)
+        # Unary op: the single operand is forwarded to relay unchanged.
+        return _op.bitwise_not(*inputs)
+
+
+class BitwiseOr(BitwiseBase):
+    """Convert ONNX BitwiseOr (opset 18) to relay bitwise_or on two integer inputs"""
+
+    name = "BitwiseOr"
+
+    @classmethod
+    def _impl_v18(cls, inputs, attr, params):
+        cls.check_inputs(inputs, num=2)
+        # Both operands are forwarded to relay unchanged.
+        return _op.bitwise_or(*inputs)
+
+
+class BitwiseXor(BitwiseBase):
+    """Convert ONNX BitwiseXor (opset 18) to relay bitwise_xor on two integer inputs"""
+
+    name = "BitwiseXor"
+
+    @classmethod
+    def _impl_v18(cls, inputs, attr, params):
+        cls.check_inputs(inputs, num=2)
+        # Both operands are forwarded to relay unchanged.
+        return _op.bitwise_xor(*inputs)
+
+
 class Unique(OnnxOpConverter):
     """Operator converter for unique"""
 
@@ -6319,7 +6385,12 @@ def _get_convert_map(opset):
         "OptionalHasElement": OptionalHasElement.get_converter(opset),
         "OptionalGetElement": OptionalGetElement.get_converter(opset),
         "Affine": Affine.get_converter(opset),
+        # Bitwise operators (BitShift: _impl_v11; BitwiseAnd/Not/Or/Xor: _impl_v18)
         "BitShift": BitShift.get_converter(opset),
+        "BitwiseAnd": BitwiseAnd.get_converter(opset),
+        "BitwiseNot": BitwiseNot.get_converter(opset),
+        "BitwiseOr": BitwiseOr.get_converter(opset),
+        "BitwiseXor": BitwiseXor.get_converter(opset),
         "ThresholdedRelu": ThresholdedRelu.get_converter(opset),
         "ScaledTanh": ScaledTanh.get_converter(opset),
         "ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
@@ -6337,10 +6408,6 @@ def _get_convert_map(opset):
         "Upsample": Upsample.get_converter(opset),
         "SpatialBN": BatchNorm.get_converter(opset),
         # defs/generator
-        # 'RandomUniform'
-        # 'RandomNormal'
-        # 'RandomUniformLike'
-        # 'RandomNormalLike'
         # defs/logical
         # defs/math
         "Add": Add.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dd172d1dde..0a03284326 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -7506,6 +7506,156 @@ def test_convinteger(target, dev):
     )
 
 
+@tvm.testing.parametrize_targets
+def test_bitshift(target, dev):
+    """test_bitshift"""
+
+    def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="uint64"):
+        in_shape = list(in_shape)
+        shift_shape = list(shift_shape)
+
+        # Create an input for each tensor.
+        tensor_values = [
+            np.random.randint(high, size=in_shape).astype(in_dtype),
+            np.random.randint(16, size=shift_shape).astype(in_dtype),
+            np.random.randint(16, size=shift_shape).astype(in_dtype),
+        ]
+
+        bitshift_left_node = helper.make_node(
+            "BitShift",
+            inputs=["input", "shift_left"],
+            outputs=["shifted"],
+            direction="LEFT",
+        )
+
+        bitshift_right_node = helper.make_node(
+            "BitShift",
+            inputs=["shifted", "shift_right"],
+            outputs=["output"],
+            direction="RIGHT",
+        )
+
+        # Create input and output tensors.
+        proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+        graph_inputs = [
+            helper.make_tensor_value_info("input", proto_type, in_shape),
+            helper.make_tensor_value_info("shift_left", proto_type, shift_shape),
+            helper.make_tensor_value_info("shift_right", proto_type, shift_shape),
+        ]
+
+        graph_outputs = [helper.make_tensor_value_info("output", proto_type, in_shape)]
+
+        graph_nodes = [bitshift_left_node, bitshift_right_node]
+
+        graph = helper.make_graph(
+            graph_nodes,
+            "BitShift_test",
+            inputs=graph_inputs,
+            outputs=graph_outputs,
+        )
+        model = helper.make_model(
+            graph,
+            producer_name="BitShift_test",
+        )
+
+        verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev)
+
+    shape = (100, 4, 2)
+    broadcast_shape = (100, 1, 1)
+    # Common bitshift test
+    verify_bitshift(shape, shape)
+    # Bitshift test with broadcasting
+    verify_bitshift(shape, broadcast_shape)
+
+
+# TODO(vvchernov): re-enable this test when ONNX Runtime in CI supports domain version 18
+@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18")
+@tvm.testing.parametrize_targets
+def test_bitwise(target, dev):
+    """test_bitwise"""
+
+    def verify_bitwise_ops(A_shape, B_shape, C_shape, D_shape, high=128, in_dtype="int32"):
+        A_shape = list(A_shape)
+        B_shape = list(B_shape)
+        C_shape = list(C_shape)
+        D_shape = list(D_shape)
+
+        # Create an input for each tensor (int64 draw: high may exceed the platform C long).
+        tensor_values = [
+            np.random.randint(high, size=A_shape, dtype="int64").astype(in_dtype),
+            np.random.randint(high, size=B_shape, dtype="int64").astype(in_dtype),
+            np.random.randint(high, size=C_shape, dtype="int64").astype(in_dtype),
+            np.random.randint(high, size=D_shape, dtype="int64").astype(in_dtype),
+        ]
+
+        or_node = helper.make_node(
+            "BitwiseOr",
+            inputs=["A", "B"],
+            outputs=["OR"],
+        )
+
+        and_node = helper.make_node(
+            "BitwiseAnd",
+            inputs=["OR", "C"],
+            outputs=["AND"],
+        )
+
+        xor_node = helper.make_node(
+            "BitwiseXor",
+            inputs=["AND", "D"],
+            outputs=["XOR"],
+        )
+
+        not_node = helper.make_node(
+            "BitwiseNot",
+            inputs=["XOR"],
+            outputs=["output"],
+        )
+
+        # Create input and output tensors.
+        proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+        graph_inputs = [
+            helper.make_tensor_value_info("A", proto_type, A_shape),
+            helper.make_tensor_value_info("B", proto_type, B_shape),
+            helper.make_tensor_value_info("C", proto_type, C_shape),
+            helper.make_tensor_value_info("D", proto_type, D_shape),
+        ]
+
+        graph_outputs = [
+            helper.make_tensor_value_info("output", proto_type, A_shape),
+        ]
+
+        graph_nodes = [
+            or_node,
+            and_node,
+            xor_node,
+            not_node,
+        ]
+
+        graph = helper.make_graph(
+            graph_nodes,
+            "Bitwise_test",
+            inputs=graph_inputs,
+            outputs=graph_outputs,
+        )
+        model = helper.make_model(
+            graph,
+            producer_name="Bitwise_test",
+        )
+
+        verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev)
+
+    shape = (100, 4, 2)
+    broadcast_shape = (100, 1, 1)
+    dtypes = ["int8", "uint8", "int32", "uint32"]
+    high_vals = [128, 128, 2147483648, 2147483648]
+    for high, dtype in zip(high_vals, dtypes):
+        # Common bitwise test
+        verify_bitwise_ops(shape, shape, shape, shape, high, dtype)
+        # Bitwise test with broadcasting
+        verify_bitwise_ops(shape, broadcast_shape, broadcast_shape, broadcast_shape, high, dtype)
+
+
 @tvm.testing.parametrize_targets
 def test_scan(target, dev):
     """test_scan"""

Reply via email to