NicolaLancellotti commented on a change in pull request #9442:
URL: https://github.com/apache/tvm/pull/9442#discussion_r743013392



##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -458,6 +458,348 @@ def qnn_avgpool2d_pattern() -> 
tvm.relay.dataflow_pattern.DFPattern:
     return pattern
 
 
+class BinaryElementwiseParams:
+    """
+    This class will parse a call to an ethosu.binary_elementwise composite 
function
+    and extract the parameter information.
+    """
+
+    def __init__(self, func_body: Call, operator_type: str, 
has_quantization_parameters: bool):
+        clip = None
+        if str(func_body.op) == "clip":
+            clip = func_body
+            binary_op = clip.args[0]
+        else:
+            binary_op = func_body
+
+        layout = "NHWC"
+
+        if has_quantization_parameters:
+            self.ifm = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm.value],
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ifm_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value],
+            )
+            self.ifm2 = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm2.value],
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ifm2_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value],
+            )
+            self.ofm = TensorParams(
+                binary_op,
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ofm_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value],
+            )
+        else:
+            self.ifm = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm.value],
+                layout,
+            )
+            self.ifm2 = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm2.value],
+                layout,
+            )
+            self.ofm = TensorParams(
+                binary_op,
+                layout,
+            )
+        self.activation = clip
+        self.operator_type = operator_type
+
+        def brodcastable(x, y):
+            for i in range(1, 4):
+                if x.shape[i] == y.shape[i] or y.shape[i] == 1:
+                    continue
+                return False
+            return True
+
+        if brodcastable(self.ifm, self.ifm2):
+            self.reversed_operands = False
+            self.valid_broadcast = True
+        elif brodcastable(self.ifm2, self.ifm):
+            self.reversed_operands = True
+            self.ifm, self.ifm2 = self.ifm2, self.ifm
+            self.valid_broadcast = True
+        else:
+            self.valid_broadcast = False
+
+    def is_valid(self):
+        """
+        This function checks whether BinaryElementwise has compatible 
attributes with the NPU
+        """
+        if np.dtype(self.ofm) == np.int32 and self.activation is not None:
+            return False
+        if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
+            return False
+        if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
+            return False
+        if not self.valid_broadcast:
+            return False
+        return True
+
+
+class AddParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Add composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.add"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "ADD", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Add has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.add with optional fused RELU 
activation.
+    """
+    pattern = is_op("qnn.add")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class SubParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Sub composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.sub"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "SUB", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Sub has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.subtract with optional fused 
RELU activation.
+    """
+    pattern = is_op("qnn.subtract")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MulParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Mul composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.mul"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MUL", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Mul has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.mul with optional fused RELU 
activation.
+    """
+    pattern = is_op("qnn.mul")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MinParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Min composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.min"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MIN", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Min has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if self.ifm.dtype != self.ifm2.dtype:
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8]
+        ):
+            return False
+        return True
+
+
+def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with optional fused RELU 
activation.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MaxParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Max composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.max"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MAX", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Max has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if self.ifm.dtype != self.ifm2.dtype:
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8]
+        ):
+            return False
+        return True
+
+
+def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with optional fused RELU 
activation.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class ShrParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Shr composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.shr"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "SHR", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Shr has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes([self.ifm, self.ifm2], 
supported_dtypes=[np.int32]):
+            return False
+        if not check_valid_dtypes([self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]):
+            return False
+        return True
+
+
+def shr_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for right_shift.
+    """
+    pattern = is_op("right_shift")(wildcard(), wildcard())

Review comment:
       Removed




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to