ekalda commented on a change in pull request #9442:
URL: https://github.com/apache/tvm/pull/9442#discussion_r742898510



##########
File path: python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
##########
@@ -0,0 +1,169 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Tensor Expressions for binary_elementwise"""
+import operator
+from tvm import te
+from .dma import dma_ofm_compute, dma_ifm_compute
+
+
+def binary_elementwise_compute(
+    ifm: te.Tensor,
+    ifm2: te.Tensor,
+    lut: te.Tensor,
+    operator_type: str,
+    ifm_scale: float,
+    ifm_zero_point: int,
+    ifm2_scale: float,
+    ifm2_zero_point: int,
+    ofm_scale: float,
+    ofm_zero_point: int,
+    ofm_channels: int,
+    reversed_operands: bool,
+    activation: str,
+    clip_min: int,
+    clip_max: int,
+    ifm_layout: str,
+    ifm2_layout: str,
+    ofm_layout: str,
+) -> te.Tensor:
+    """A compute operator representing the capabilities of binary_elementwise 
for the NPU.
+
+    Parameters
+    ----------
+    ifm : te.Tensor
+        The Input Feature Map tensor (IFM).
+    ifm2 : te.Tensor
+        The Input Feature Map tensor 1 (IFM2).
+    lut : te.Tensor
+        The look-up table values to use if activation = "LUT".
+    operator_type: str
+        The type of the binary elementwise operator.
+            "ADD"
+            "SUB"
+            "MUL"
+            "MIN"
+            "MAX"
+            "SHR"
+            "SHL"
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    ifm2_scale : float
+        The quantization scale for the Input Feature Map tensor 1.

Review comment:
       Maybe call this "tensor 2" to comply with the parameter name and also 
with the documentation in the C++ Relay operator implementation.

##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -558,5 +558,193 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethosu_main_0"])
 
 
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False),
+        ([1, 2, 3, 4], [1, 1, 3, 1], False),
+        ([1, 1, 3, 1], [1, 2, 3, 4], True),
+    ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_tflite_binary_elemwise_legalize(
+    operator_type,
+    ifm_shape,
+    ifm2_shape,
+    reversed_operands,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, y):
+                if operator_type == "ADD":
+                    op = tf.math.add(x, y)
+                elif operator_type == "SUB":
+                    op = tf.math.subtract(x, y)
+                elif operator_type == "MUL":
+                    op = tf.math.multiply(x, y)
+                elif operator_type == "MIN":
+                    op = tf.math.minimum(x, y)
+                elif operator_type == "MAX":
+                    op = tf.math.maximum(x, y)
+                if activation_function == "RELU":
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), 
tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+                yield [data.astype(np.float32), data2.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        out_shape = ifm2_shape if reversed_operands else ifm_shape
+        shapes = [ifm_shape, ifm2_shape]
+        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+        op = ext_func.body
+        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+        assert op.args[0].checked_type.dtype == dtype
+        assert list(op.checked_type.shape) == out_shape
+        assert op.checked_type.dtype == dtype
+        assert op.attrs.operator_type == operator_type
+        assert op.attrs.reversed_operands == reversed_operands
+        if activation_function == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    if operator_type == "ADD":
+        rewriter = legalize.AddRewriter()
+        pattern_table = [
+            (
+                ethosu.AddParams.composite_name,
+                ethosu.qnn_add_pattern(),
+                lambda pat: ethosu.AddParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "SUB":
+        rewriter = legalize.SubRewriter()
+        pattern_table = [
+            (
+                ethosu.SubParams.composite_name,
+                ethosu.qnn_subtract_pattern(),
+                lambda pat: ethosu.SubParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MUL":
+        rewriter = legalize.MulRewriter()
+        pattern_table = [
+            (
+                ethosu.MulParams.composite_name,
+                ethosu.qnn_mul_pattern(),
+                lambda pat: ethosu.MulParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MIN":
+        rewriter = legalize.MinRewriter()
+        pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MAX":
+        rewriter = legalize.MaxRewriter()
+        pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
+        ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"x": ifm_shape, "y": ifm2_shape},
+        dtype_dict={"x": dtype, "y": dtype},
+    )
+    mod = partition_ethosu_by_table(mod, pattern_table)
+
+    mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethosu_main_0"]
+    )
+    verify(mod["tvmgen_default_ethosu_main_0"])
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False),
+        ([1, 2, 3, 4], [1, 1, 3, 1], False),
+        ([1, 1, 3, 1], [1, 2, 3, 4], True),
+    ],
+)
+def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, 
reversed_operands):
+    dtype = "int32"
+    operator_type = "SHL"
+
+    def create_graph():
+        input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
+        input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
+        c1 = relay.left_shift(input1, input2)

Review comment:
       Since left shift legalization from Relay is tested, does it mean that 
the NPU left shift does the "actual" left shift but the right shift doesn't, so 
we shouldn't legalize the Relay right shift to the Ethos-U right shift but 
should legalize the left shift? It seems a bit odd to me that one of the shifts 
has a Relay legalization and the other doesn't, especially since I don't know 
what the use of a Relay left shift for the NPU would be in the absence of a 
corresponding TFLite operator. Any thoughts, @manupa-arm @mbaret @lhutton1?

##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -343,5 +343,256 @@ def representative_dataset():
     infra.verify_source(compiled_models, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4]),
+        ([1, 2, 3, 4], [1, 1, 3, 1]),
+        ([1, 1, 3, 1], [1, 2, 3, 4]),
+    ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_ethosu_binary_elementwise(
+    accel_type,
+    operator_type,
+    ifm_shape,
+    ifm2_shape,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        tf.config.run_functions_eagerly(True)
+
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, lhs, rhs):
+                if operator_type == "ADD":
+                    op = tf.math.add(lhs, rhs)
+                elif operator_type == "SUB":
+                    op = tf.math.subtract(lhs, rhs)
+                elif operator_type == "MUL":
+                    op = tf.math.multiply(lhs, rhs)
+                elif operator_type == "MIN":
+                    op = tf.math.minimum(lhs, rhs)
+                elif operator_type == "MAX":
+                    op = tf.math.maximum(lhs, rhs)
+                if activation_function == "RELU":
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), 
tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+                yield [data.astype(np.float32), data2.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape},
+        dtype_dict={"ifm": dtype, "ifm2": dtype},
+    )
+    mod = partition_for_ethosu(mod, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+        output_tolerance=1 if operator_type == "MAX" else 0,

Review comment:
       Why tolerance 1 for MAX?

##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -458,6 +458,348 @@ def qnn_avgpool2d_pattern() -> 
tvm.relay.dataflow_pattern.DFPattern:
     return pattern
 
 
+class BinaryElementwiseParams:
+    """
+    This class will parse a call to a ethosu.binary_elementwise composite 
function
+    and extract the parameter information.
+    """
+
+    def __init__(self, func_body: Call, operator_type: str, 
has_quantization_parameters: bool):
+        clip = None
+        if str(func_body.op) == "clip":
+            clip = func_body
+            binary_op = clip.args[0]
+        else:
+            binary_op = func_body
+
+        layout = "NHWC"
+
+        if has_quantization_parameters:
+            self.ifm = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm.value],
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ifm_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value],
+            )
+            self.ifm2 = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm2.value],
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ifm2_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value],
+            )
+            self.ofm = TensorParams(
+                binary_op,
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ofm_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value],
+            )
+        else:
+            self.ifm = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm.value],
+                layout,
+            )
+            self.ifm2 = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm2.value],
+                layout,
+            )
+            self.ofm = TensorParams(
+                binary_op,
+                layout,
+            )
+        self.activation = clip
+        self.operator_type = operator_type
+
+        def brodcastable(x, y):
+            for i in range(1, 4):
+                if x.shape[i] == y.shape[i] or y.shape[i] == 1:
+                    continue
+                return False
+            return True
+
+        if brodcastable(self.ifm, self.ifm2):
+            self.reversed_operands = False
+            self.valid_broadcast = True
+        elif brodcastable(self.ifm2, self.ifm):
+            self.reversed_operands = True
+            self.ifm, self.ifm2 = self.ifm2, self.ifm
+            self.valid_broadcast = True
+        else:
+            self.valid_broadcast = False
+
+    def is_valid(self):
+        """
+        This function checks whether BinaryElementwise has compatible 
attributes with the NPU
+        """
+        if np.dtype(self.ofm) == np.int32 and self.activation is not None:
+            return False
+        if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
+            return False
+        if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
+            return False
+        if not self.valid_broadcast:
+            return False
+        return True
+
+
+class AddParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to a ethosu.binary_elementwise Add composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.add"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "ADD", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Add has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.add with optional fused RELU 
activation.
+    """
+    pattern = is_op("qnn.add")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class SubParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to a ethosu.binary_elementwise Sub composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.sub"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "SUB", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Sub has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.subtract with optional fused 
RELU activation.
+    """
+    pattern = is_op("qnn.subtract")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MulParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to a ethosu.binary_elementwise Mul composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.mul"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MUL", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Mul has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.mul with optional fused RELU 
activation.
+    """
+    pattern = is_op("qnn.mul")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MinParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to a ethosu.binary_elementwise Min composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.min"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MIN", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Min has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if self.ifm.dtype != self.ifm2.dtype:
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8]
+        ):
+            return False
+        return True
+
+
+def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with optional fused RELU 
activation.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MaxParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to a ethosu.binary_elementwise Max composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.max"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MAX", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Max has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if self.ifm.dtype != self.ifm2.dtype:
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, 
np.int8]
+        ):
+            return False
+        return True
+
+
+def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with optional fused RELU 
activation.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class ShrParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to a ethosu.binary_elementwise Shr composite 
function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.shr"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "SHR", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Shr has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes([self.ifm, self.ifm2], 
supported_dtypes=[np.int32]):
+            return False
+        if not check_valid_dtypes([self.ofm], supported_dtypes=[np.uint8, 
np.int8, np.int32]):
+            return False
+        return True
+
+
+def shr_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for right_shift.
+    """
+    pattern = is_op("right_shift")(wildcard(), wildcard())

Review comment:
       I think we don't need a pattern and params class for right shift if we 
are not going to legalize it for the NPU.

##########
File path: python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Relay operators for binary elementwise operators for Arm(R) Ethos(TM)-U 
NPU"""
+import tvm
+from tvm.relay.op import _make
+from tvm.topi.generic import schedule_injective
+from tvm.relay.op.op import OpStrategy
+from tvm.relay.op import strategy as _strategy
+
+from ..te import binary_elementwise_compute
+
+
+def _extract_ethosu_binary_elementwise_params(attrs, args):
+    """Get the parameters necessary to construct a ethosu_binary_elementwise 
compute TE
+    from a ethosu_binary_elementwise Relay call."""
+    ifm = args[0]
+    ifm2 = args[1]
+    lut = args[2]
+    operator_type = attrs.operator_type
+    ifm_scale = attrs.ifm_scale
+    ifm_zero_point = attrs.ifm_zero_point
+    ifm2_scale = attrs.ifm2_scale
+    ifm2_zero_point = attrs.ifm2_zero_point
+    ofm_scale = attrs.ofm_scale
+    ofm_zero_point = attrs.ofm_zero_point
+    ofm_channels = attrs.ofm_channels
+    reversed_operands = attrs.reversed_operands
+    activation = attrs.activation
+    clip_min = attrs.clip_min
+    clip_max = attrs.clip_max
+    ifm_layout = attrs.ifm_layout
+    ifm2_layout = attrs.ifm2_layout
+    ofm_layout = attrs.ofm_layout
+
+    return (
+        ifm,
+        ifm2,
+        lut,
+        operator_type,
+        ifm_scale,
+        ifm_zero_point,
+        ifm2_scale,
+        ifm2_zero_point,
+        ofm_scale,
+        ofm_zero_point,
+        ofm_channels,
+        reversed_operands,
+        activation,
+        clip_min,
+        clip_max,
+        ifm_layout,
+        ifm2_layout,
+        ofm_layout,
+    )
+
+
+@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMCompute")
+def create_ethosu_binary_elementwise_compute(attrs, args, out_type):
+    """Create an ethosu_binary_elementwise compute op."""
+    params = _extract_ethosu_binary_elementwise_params(attrs, args)
+    op = binary_elementwise_compute(*params)
+    return [op]
+
+
+@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMStrategy")
+def binary_elementwise_strategy_ethosu(attrs, inputs, out_type, target):
+    strategy = OpStrategy()
+    strategy.add_implementation(
+        create_ethosu_binary_elementwise_compute,
+        _strategy.wrap_topi_schedule(schedule_injective),
+        name="ethosu_binary_elementwise",
+    )
+    return strategy
+
+
+def ethosu_binary_elementwise(
+    ifm: tvm.relay.Expr,
+    ifm2: tvm.relay.Expr,
+    lut: tvm.relay.Expr,
+    operator_type: str,
+    ifm_scale: float,
+    ifm_zero_point: int,
+    ifm2_scale: float,
+    ifm2_zero_point: int,
+    ofm_scale: float,
+    ofm_zero_point: int,
+    ofm_channels: int,
+    reversed_operands: bool,
+    ofm_dtype: str,
+    activation: str = "NONE",
+    clip_min: int = 0,
+    clip_max: int = 0,
+    ifm_layout: str = "NHWC",
+    ifm2_layout: str = "NHWC",
+    ofm_layout: str = "NHWC",
+) -> tvm.relay.Call:
+    """This is a quantized binary elementwise operation as supported by
+    the NPU. It accepts either NHWC or NHCWB16 format
+    for the input data.
+
+    Parameters
+    ----------
+    ifm : tvm.relay.Expr
+        The Input Feature Map tensor (IFM).
+    ifm2 : tvm.relay.Expr
+        The Input Feature Map tensor 2 (IFM2).
+    lut : tvm.relay.Expr
+        The look-up table of values to use if activation = "LUT".
+    operator_type: str
+        The type of the binary elementwise operator.
+            "ADD"
+            "SUB"
+            "MUL"
+            "MIN"
+            "MAX"
+            "SHR"
+            "SHL"
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    ifm2_scale : float
+        The quantization scale for the Input Feature Map tensor 2.
+    ifm2_zero_point : int
+        The quantization zero point for the Input Feature Map tensor 2.
+    ofm_scale : float
+        The quantization scale for the Output Feature Map tensor.
+    ofm_zero_point : int
+       The quantization zero point for the Output Feature Map tensor.
+    ofm_channels : int
+        The number of the Output Feature Map channels.
+    reversed_operands : bool
+        Specific if IFM2 is the first operand and IFM is the second operand.

Review comment:
       Maybe "True" instead of "Specific"?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to