NicolaLancellotti commented on a change in pull request #9442:
URL: https://github.com/apache/tvm/pull/9442#discussion_r743013392
##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -458,6 +458,348 @@ def qnn_avgpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern
class BinaryElementwiseParams:
    """
    This class will parse a call to a ethosu.binary_elementwise composite function
    and extract the parameter information.

    After construction the instance exposes:
      * ifm, ifm2, ofm -- TensorParams for the two inputs and the output, all tagged
        with the "NHWC" layout. If only the reversed operand order is
        broadcast-compatible, ifm and ifm2 are swapped so that ifm2 is always the
        operand whose size-1 axes get broadcast.
      * activation -- the ``clip`` call wrapping the binary op, or None.
      * operator_type -- the operator string passed to the constructor (e.g. "ADD").
      * reversed_operands -- True when ifm and ifm2 were swapped, False otherwise.
      * valid_broadcast -- whether the operand shapes are broadcast-compatible.
    """

    def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool):
        """
        Parameters
        ----------
        func_body : Call
            Body of the composite function: either the binary op call itself, or a
            ``clip`` call whose first argument is the binary op.
        operator_type : str
            Operator name stored on the instance (e.g. "ADD", "MIN").
        has_quantization_parameters : bool
            Whether the binary op call carries scale / zero-point arguments that
            should be attached to the TensorParams.
        """
        clip = None
        if str(func_body.op) == "clip":
            clip = func_body
            binary_op = clip.args[0]
        else:
            binary_op = func_body

        layout = "NHWC"

        def tensor_params(tensor, scale_arg, zero_point_arg):
            """Build TensorParams, attaching quantization args only when present."""
            if not has_quantization_parameters:
                return TensorParams(tensor, layout)
            return TensorParams(
                tensor,
                layout,
                binary_op.args[scale_arg.value],
                binary_op.args[zero_point_arg.value],
            )

        self.ifm = tensor_params(
            binary_op.args[BinaryElementwiseArgs.ifm.value],
            BinaryElementwiseArgs.ifm_scale,
            BinaryElementwiseArgs.ifm_zero_point,
        )
        self.ifm2 = tensor_params(
            binary_op.args[BinaryElementwiseArgs.ifm2.value],
            BinaryElementwiseArgs.ifm2_scale,
            BinaryElementwiseArgs.ifm2_zero_point,
        )
        # The output tensor is described by the binary op call itself.
        self.ofm = tensor_params(
            binary_op,
            BinaryElementwiseArgs.ofm_scale,
            BinaryElementwiseArgs.ofm_zero_point,
        )
        self.activation = clip
        self.operator_type = operator_type

        def can_broadcast(x, y):
            """Return True if y can be broadcast to x over axes 1..3."""
            # The constructor runs before is_valid() rejects non-4D operands, so
            # guard the rank here instead of raising IndexError on shape[1..3].
            if len(x.shape) != 4 or len(y.shape) != 4:
                return False
            return all(x.shape[i] == y.shape[i] or y.shape[i] == 1 for i in range(1, 4))

        if can_broadcast(self.ifm, self.ifm2):
            self.reversed_operands = False
            self.valid_broadcast = True
        elif can_broadcast(self.ifm2, self.ifm):
            self.reversed_operands = True
            self.ifm, self.ifm2 = self.ifm2, self.ifm
            self.valid_broadcast = True
        else:
            # Always define the attribute so callers never hit AttributeError.
            self.reversed_operands = False
            self.valid_broadcast = False

    def is_valid(self):
        """
        This function checks whether BinaryElementwise has compatible attributes with the NPU

        Rejects: an int32 output combined with a fused activation, operands that are
        not rank 4, a batch size other than 1, and operand shapes that cannot be
        broadcast in either order.
        """
        # Convert the OFM's dtype explicitly. Passing the TensorParams object itself
        # to np.dtype only works through NumPy's deprecated fallback of reading an
        # arbitrary object's `.dtype` attribute.
        if np.dtype(self.ofm.dtype) == np.int32 and self.activation is not None:
            return False
        if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
            return False
        if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
            return False
        if not self.valid_broadcast:
            return False
        return True
+
+
class AddParams(BinaryElementwiseParams):
    """
    Parameter extractor for an ``ethosu.add`` composite function, i.e. an
    ethosu.binary_elementwise ADD built from a quantized add.
    """

    composite_name = "ethosu.add"

    def __init__(self, func_body: Call):
        # qnn.add carries scale / zero-point arguments, so they are parsed as well.
        super().__init__(func_body, operator_type="ADD", has_quantization_parameters=True)

    def is_valid(self):
        """
        Return True when the generic binary elementwise checks pass and the two
        inputs and the output are all uint8, int8 or int32.
        """
        tensors = [self.ifm, self.ifm2, self.ofm]
        allowed = [np.uint8, np.int8, np.int32]
        return bool(super().is_valid() and check_valid_dtypes(tensors, supported_dtypes=allowed))
+
+
def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    Build the dataflow pattern for qnn.add, optionally followed by a clip
    (the form a fused activation such as RELU takes).
    """
    # Two arbitrary tensor operands, then six constant arguments
    # (the scale / zero-point operands of qnn.add).
    operands = [wildcard(), wildcard()]
    constants = [is_constant() for _ in range(6)]
    add = is_op("qnn.add")(*operands, *constants)
    return add.optional(is_op("clip"))
+
+
class SubParams(BinaryElementwiseParams):
    """
    Parameter extractor for an ``ethosu.sub`` composite function, i.e. an
    ethosu.binary_elementwise SUB built from a quantized subtract.
    """

    composite_name = "ethosu.sub"

    def __init__(self, func_body: Call):
        # qnn.subtract carries scale / zero-point arguments, so they are parsed as well.
        super().__init__(func_body, operator_type="SUB", has_quantization_parameters=True)

    def is_valid(self):
        """
        Return True when the generic binary elementwise checks pass and the two
        inputs and the output are all uint8, int8 or int32.
        """
        tensors = [self.ifm, self.ifm2, self.ofm]
        allowed = [np.uint8, np.int8, np.int32]
        return bool(super().is_valid() and check_valid_dtypes(tensors, supported_dtypes=allowed))
+
+
def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    Build the dataflow pattern for qnn.subtract, optionally followed by a clip
    (the form a fused activation such as RELU takes).
    """
    # Two arbitrary tensor operands, then six constant arguments
    # (the scale / zero-point operands of qnn.subtract).
    operands = [wildcard(), wildcard()]
    constants = [is_constant() for _ in range(6)]
    subtract = is_op("qnn.subtract")(*operands, *constants)
    return subtract.optional(is_op("clip"))
+
+
class MulParams(BinaryElementwiseParams):
    """
    Parameter extractor for an ``ethosu.mul`` composite function, i.e. an
    ethosu.binary_elementwise MUL built from a quantized multiply.
    """

    composite_name = "ethosu.mul"

    def __init__(self, func_body: Call):
        # qnn.mul carries scale / zero-point arguments, so they are parsed as well.
        super().__init__(func_body, operator_type="MUL", has_quantization_parameters=True)

    def is_valid(self):
        """
        Return True when the generic binary elementwise checks pass and the two
        inputs and the output are all uint8, int8 or int32.
        """
        tensors = [self.ifm, self.ifm2, self.ofm]
        allowed = [np.uint8, np.int8, np.int32]
        return bool(super().is_valid() and check_valid_dtypes(tensors, supported_dtypes=allowed))
+
+
def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    Build the dataflow pattern for qnn.mul, optionally followed by a clip
    (the form a fused activation such as RELU takes).
    """
    # Two arbitrary tensor operands, then six constant arguments
    # (the scale / zero-point operands of qnn.mul).
    operands = [wildcard(), wildcard()]
    constants = [is_constant() for _ in range(6)]
    mul = is_op("qnn.mul")(*operands, *constants)
    return mul.optional(is_op("clip"))
+
+
class MinParams(BinaryElementwiseParams):
    """
    Parameter extractor for an ``ethosu.min`` composite function, i.e. an
    ethosu.binary_elementwise MIN built from a (non-quantized) minimum.
    """

    composite_name = "ethosu.min"

    def __init__(self, func_body: Call):
        # minimum has no scale / zero-point arguments to parse.
        super().__init__(func_body, operator_type="MIN", has_quantization_parameters=False)

    def is_valid(self):
        """
        Return True when the generic binary elementwise checks pass, both inputs
        share a dtype, and every tensor is uint8 or int8.
        """
        if not super().is_valid() or self.ifm.dtype != self.ifm2.dtype:
            return False
        tensors = [self.ifm, self.ifm2, self.ofm]
        return bool(check_valid_dtypes(tensors, supported_dtypes=[np.uint8, np.int8]))
+
+
def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    Build the dataflow pattern for minimum of two arbitrary tensors, optionally
    followed by a clip (the form a fused activation such as RELU takes).
    """
    minimum = is_op("minimum")(wildcard(), wildcard())
    return minimum.optional(is_op("clip"))
+
+
class MaxParams(BinaryElementwiseParams):
    """
    Parameter extractor for an ``ethosu.max`` composite function, i.e. an
    ethosu.binary_elementwise MAX built from a (non-quantized) maximum.
    """

    composite_name = "ethosu.max"

    def __init__(self, func_body: Call):
        # maximum has no scale / zero-point arguments to parse.
        super().__init__(func_body, operator_type="MAX", has_quantization_parameters=False)

    def is_valid(self):
        """
        Return True when the generic binary elementwise checks pass, both inputs
        share a dtype, and every tensor is uint8 or int8.
        """
        if not super().is_valid() or self.ifm.dtype != self.ifm2.dtype:
            return False
        tensors = [self.ifm, self.ifm2, self.ofm]
        return bool(check_valid_dtypes(tensors, supported_dtypes=[np.uint8, np.int8]))
+
+
def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
    """
    Build the dataflow pattern for maximum of two arbitrary tensors, optionally
    followed by a clip (the form a fused activation such as RELU takes).
    """
    maximum = is_op("maximum")(wildcard(), wildcard())
    return maximum.optional(is_op("clip"))
+
+
class ShrParams(BinaryElementwiseParams):
    """
    Parameter extractor for an ``ethosu.shr`` composite function, i.e. an
    ethosu.binary_elementwise SHR (shift right) without quantization arguments.
    """

    composite_name = "ethosu.shr"

    def __init__(self, func_body: Call):
        # The shift op has no scale / zero-point arguments to parse.
        super().__init__(func_body, operator_type="SHR", has_quantization_parameters=False)

    def is_valid(self):
        """
        Return True when the generic binary elementwise checks pass, both inputs
        are int32, and the output is uint8, int8 or int32.
        """
        inputs_ok = lambda: check_valid_dtypes([self.ifm, self.ifm2], supported_dtypes=[np.int32])
        output_ok = lambda: check_valid_dtypes(
            [self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32]
        )
        # Evaluated lazily and in order, so later checks are skipped once one fails.
        return bool(super().is_valid() and inputs_ok() and output_ok())
+
+
+def shr_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for right_shift.
+ """
+ pattern = is_op("right_shift")(wildcard(), wildcard())
Review comment:
Removed
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]