This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 583f5ae37c [ONNX] Support Bitwise operations (#13888)
583f5ae37c is described below
commit 583f5ae37c3b34ac8e1ce36fccf22c1e7eb89ab1
Author: Valery Chernov <[email protected]>
AuthorDate: Fri Feb 3 14:30:55 2023 +0400
[ONNX] Support Bitwise operations (#13888)
* add base class for bitwise operations. BitwiseAnd, BitwiseNot, BitwiseOr
and BitwiseXor were implemented
* add test for BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor operations to
ONNX front-end
* add test of BitShift for ONNX front-end
* fix dtype for test
* skip test due to old version of ORT
---------
Co-authored-by: Valery Chernov <[email protected]>
---
python/tvm/relay/frontend/onnx.py | 83 ++++++++++++++--
tests/python/frontend/onnx/test_forward.py | 150 +++++++++++++++++++++++++++++
2 files changed, 225 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 8b4a0cc5e8..8de5e0e08b 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -5578,13 +5578,31 @@ class ConvInteger(OnnxOpConverter):
)
-class BitShift(OnnxOpConverter):
- """Operator converter for NonZero"""
+class BitwiseBase(OnnxOpConverter):
+ """Base class of operator converter for Bitwise operations"""
+
+ name = ""
+
+ @classmethod
+ def check_inputs(cls, inputs, num=2, use_int=True):
+        assert len(inputs) == num, "{} takes {} inputs, {} given".format(cls.name, num, len(inputs))
+
+ valid_types = ["uint8", "uint16", "uint32", "uint64"]
+ if use_int:
+ valid_types += ["int8", "int16", "int32", "int64"]
+ for i in range(num):
+ in_dtype = infer_type(inputs[i]).checked_type.dtype
+            assert in_dtype in valid_types, "Wrong dtype of the {}-th input: {}".format(i, in_dtype)
+
+
+class BitShift(BitwiseBase):
+ """Operator converter for BitShift"""
+
+ name = "BitShift"
@classmethod
def _impl_v11(cls, inputs, attr, params):
- if len(inputs) != 2:
- raise ValueError("Bitshift expects 2 inputs")
+ cls.check_inputs(inputs, use_int=False)
direction = attr.get("direction", "LEFT").decode("ascii")
if direction == "LEFT":
@@ -5596,6 +5614,54 @@ class BitShift(OnnxOpConverter):
return out
+class BitwiseAnd(BitwiseBase):
+ """Operator converter for BitwiseAnd"""
+
+ name = "BitwiseAnd"
+
+ @classmethod
+ def _impl_v18(cls, inputs, attr, params):
+ cls.check_inputs(inputs)
+
+ return _op.bitwise_and(*inputs)
+
+
+class BitwiseNot(BitwiseBase):
+ """Operator converter for BitwiseNot"""
+
+ name = "BitwiseNot"
+
+ @classmethod
+ def _impl_v18(cls, inputs, attr, params):
+ cls.check_inputs(inputs, num=1)
+
+ return _op.bitwise_not(*inputs)
+
+
+class BitwiseOr(BitwiseBase):
+ """Operator converter for BitwiseOr"""
+
+ name = "BitwiseOr"
+
+ @classmethod
+ def _impl_v18(cls, inputs, attr, params):
+ cls.check_inputs(inputs)
+
+ return _op.bitwise_or(*inputs)
+
+
+class BitwiseXor(BitwiseBase):
+ """Operator converter for BitwiseXor"""
+
+ name = "BitwiseXor"
+
+ @classmethod
+ def _impl_v18(cls, inputs, attr, params):
+ cls.check_inputs(inputs)
+
+ return _op.bitwise_xor(*inputs)
+
+
class Unique(OnnxOpConverter):
"""Operator converter for unique"""
@@ -6319,7 +6385,12 @@ def _get_convert_map(opset):
"OptionalHasElement": OptionalHasElement.get_converter(opset),
"OptionalGetElement": OptionalGetElement.get_converter(opset),
"Affine": Affine.get_converter(opset),
+ # Bitwise operators
"BitShift": BitShift.get_converter(opset),
+ "BitwiseAnd": BitwiseAnd.get_converter(opset),
+ "BitwiseNot": BitwiseNot.get_converter(opset),
+ "BitwiseOr": BitwiseOr.get_converter(opset),
+ "BitwiseXor": BitwiseXor.get_converter(opset),
"ThresholdedRelu": ThresholdedRelu.get_converter(opset),
"ScaledTanh": ScaledTanh.get_converter(opset),
"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
@@ -6337,10 +6408,6 @@ def _get_convert_map(opset):
"Upsample": Upsample.get_converter(opset),
"SpatialBN": BatchNorm.get_converter(opset),
# defs/generator
- # 'RandomUniform'
- # 'RandomNormal'
- # 'RandomUniformLike'
- # 'RandomNormalLike'
# defs/logical
# defs/math
"Add": Add.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dd172d1dde..0a03284326 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -7506,6 +7506,156 @@ def test_convinteger(target, dev):
)
+@tvm.testing.parametrize_targets
+def test_bitshift(target, dev):
+ """test_bitshift"""
+
+    def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="uint64"):
+ in_shape = list(in_shape)
+ shift_shape = list(shift_shape)
+
+ # Create an input for each tensor.
+ tensor_values = [
+ np.random.randint(high, size=in_shape).astype(in_dtype),
+ np.random.randint(16, size=shift_shape).astype(in_dtype),
+ np.random.randint(16, size=shift_shape).astype(in_dtype),
+ ]
+
+ bitshift_left_node = helper.make_node(
+ "BitShift",
+ inputs=["input", "shift_left"],
+ outputs=["shifted"],
+ direction="LEFT",
+ )
+
+ bitshift_right_node = helper.make_node(
+ "BitShift",
+ inputs=["shifted", "shift_right"],
+ outputs=["output"],
+ direction="RIGHT",
+ )
+
+ # Create input and output tensors.
+ proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+ graph_inputs = [
+ helper.make_tensor_value_info("input", proto_type, in_shape),
+            helper.make_tensor_value_info("shift_left", proto_type, shift_shape),
+            helper.make_tensor_value_info("shift_right", proto_type, shift_shape),
+ ]
+
+        graph_outputs = [helper.make_tensor_value_info("output", proto_type, in_shape)]
+
+ graph_nodes = [bitshift_left_node, bitshift_right_node]
+
+ graph = helper.make_graph(
+ graph_nodes,
+ "BitShift_test",
+ inputs=graph_inputs,
+ outputs=graph_outputs,
+ )
+ model = helper.make_model(
+ graph,
+ producer_name="BitShift_test",
+ )
+
+        verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev)
+
+ shape = (100, 4, 2)
+ broadcast_shape = (100, 1, 1)
+ # Common bitwise test
+ verify_bitshift(shape, shape)
+ # Bitwise test with broadcasting
+ verify_bitshift(shape, broadcast_shape)
+
+
+# TODO(vvchernov): bring the test back when ONNX Runtime in CI supports domain version 18
+@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18")
+@tvm.testing.parametrize_targets
+def test_bitwise(target, dev):
+ """test_bitwise"""
+
+    def verify_bitwise_ops(A_shape, B_shape, C_shape, D_shape, high=128, in_dtype="int32"):
+ A_shape = list(A_shape)
+ B_shape = list(B_shape)
+ C_shape = list(C_shape)
+ D_shape = list(D_shape)
+
+ # Create an input for each tensor.
+ tensor_values = [
+ np.random.randint(high, size=A_shape).astype(in_dtype),
+ np.random.randint(high, size=B_shape).astype(in_dtype),
+ np.random.randint(high, size=C_shape).astype(in_dtype),
+ np.random.randint(high, size=D_shape).astype(in_dtype),
+ ]
+
+ or_node = helper.make_node(
+ "BitwiseOr",
+ inputs=["A", "B"],
+ outputs=["OR"],
+ )
+
+ and_node = helper.make_node(
+ "BitwiseAnd",
+ inputs=["OR", "C"],
+ outputs=["AND"],
+ )
+
+ xor_node = helper.make_node(
+ "BitwiseXor",
+ inputs=["AND", "D"],
+ outputs=["XOR"],
+ )
+
+ not_node = helper.make_node(
+ "BitwiseNot",
+ inputs=["XOR"],
+ outputs=["output"],
+ )
+
+ # Create input and output tensors.
+ proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
+ graph_inputs = [
+ helper.make_tensor_value_info("A", proto_type, A_shape),
+ helper.make_tensor_value_info("B", proto_type, B_shape),
+ helper.make_tensor_value_info("C", proto_type, C_shape),
+ helper.make_tensor_value_info("D", proto_type, D_shape),
+ ]
+
+ graph_outputs = [
+ helper.make_tensor_value_info("output", proto_type, A_shape),
+ ]
+
+ graph_nodes = [
+ or_node,
+ and_node,
+ xor_node,
+ not_node,
+ ]
+
+ graph = helper.make_graph(
+ graph_nodes,
+ "Bitwise_test",
+ inputs=graph_inputs,
+ outputs=graph_outputs,
+ )
+ model = helper.make_model(
+ graph,
+ producer_name="Bitwise_test",
+ )
+
+        verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev)
+
+ shape = (100, 4, 2)
+ broadcast_shape = (100, 1, 1)
+ dtypes = ["int8", "uint8", "int32", "uint32"]
+ high_vals = [128, 128, 2147483648, 2147483648]
+ for high, dtype in zip(high_vals, dtypes):
+ # Common bitwise test
+ verify_bitwise_ops(shape, shape, shape, shape, high, dtype)
+ # Bitwise test with broadcasting
+            verify_bitwise_ops(shape, broadcast_shape, broadcast_shape, broadcast_shape, high, dtype)
+
+
@tvm.testing.parametrize_targets
def test_scan(target, dev):
"""test_scan"""