This is an automated email from the ASF dual-hosted git repository.

mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e81d391  Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support (#9442)
e81d391 is described below

commit e81d391640ed8c536ecc584ab3b60d756cce47f5
Author: Nicola Lancellotti <[email protected]>
AuthorDate: Thu Nov 11 17:27:40 2021 +0000

    Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support (#9442)
    
    This commit adds support for the binary elementwise primitive operators
    for the Arm(R) Ethos(TM)-U NPU and includes a few minor rewording changes.
---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   | 227 +++++++++++
 .../relay/backend/contrib/ethosu/op/__init__.py    |   1 +
 .../contrib/ethosu/op/binary_elementwise.py        | 215 ++++++++++
 .../relay/backend/contrib/ethosu/op/convolution.py |   4 +-
 .../relay/backend/contrib/ethosu/op/depthwise.py   |   7 +-
 .../tvm/relay/backend/contrib/ethosu/op/pooling.py |   6 +-
 .../relay/backend/contrib/ethosu/te/__init__.py    |   1 +
 .../contrib/ethosu/te/binary_elementwise.py        | 184 +++++++++
 .../contrib/ethosu/tir/binary_elementwise.py       | 102 +++++
 .../tvm/relay/backend/contrib/ethosu/tir/passes.py |   2 +
 .../tvm/relay/backend/contrib/ethosu/tir/spec.py   |  21 +
 .../backend/contrib/ethosu/tir_to_cs_translator.py |  76 +++-
 python/tvm/relay/backend/contrib/ethosu/util.py    |  15 +
 python/tvm/relay/op/contrib/ethosu.py              | 352 ++++++++++++++++-
 src/relay/op/contrib/ethosu/binary_elementwise.cc  | 301 ++++++++++++++
 src/relay/op/contrib/ethosu/common.cc              |  18 +
 src/relay/op/contrib/ethosu/common.h               |  11 +
 src/relay/op/contrib/ethosu/pooling.cc             |   2 +-
 tests/python/contrib/test_ethosu/infra.py          |  53 +++
 tests/python/contrib/test_ethosu/test_codegen.py   | 252 +++++++++++-
 tests/python/contrib/test_ethosu/test_legalize.py  | 188 +++++++++
 .../test_ethosu/test_replace_binary_elementwise.py | 335 ++++++++++++++++
 .../test_ethosu/test_tir_to_cs_translator.py       | 434 +++++++++++++++++++++
 .../contrib/test_ethosu/test_type_inference.py     | 116 ++++++
 24 files changed, 2902 insertions(+), 21 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index c4b70c1..d0d04ce 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -413,6 +413,224 @@ class LegalizeAvgPooling:
         pass
 
 
+class BinaryElementwiseRewriter(DFPatternCallback):
+    """Convert ethosu binary elementwise composite functions to
+    ethosu_binary_elementwise operators"""
+
+    def __init__(
+        self,
+        params_class: Type,
+        pattern: CallPattern,
+    ):
+        super().__init__(require_type=True)
+        self.params_class = params_class
+        self.pattern = pattern
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = self.params_class(post.op.body)
+        params.ifm.tensor = post.args[1] if params.reversed_operands else post.args[0]
+        params.ifm2.tensor = post.args[0] if params.reversed_operands else post.args[1]
+        channels_map = {
+            "NHWC": 3,
+        }
+        if str(params.ofm.layout) not in channels_map.keys():
+            raise UnsupportedLayout(str(params.ofm.layout))
+
+        activation_map = {"clip": "CLIP"}
+        if params.activation:
+            activation = activation_map[params.activation.op.name]
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+
+        # We don't yet support activation functions that need to get legalized to LUTs.
+        lut = relay.const([], dtype="int8")
+
+        return ethosu_ops.ethosu_binary_elementwise(
+            ifm=params.ifm.tensor,
+            ifm2=params.ifm2.tensor,
+            lut=lut,
+            operator_type=params.operator_type,
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            ifm2_scale=float(params.ifm2.q_params.scale_f32),
+            ifm2_zero_point=int(params.ifm2.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=params.ifm.shape[3],
+            ifm2_channels=params.ifm2.shape[3],
+            reversed_operands=params.reversed_operands,
+            ofm_dtype=params.ofm.dtype,
+            activation=activation,
+            clip_min=clip_min,
+            clip_max=clip_max,
+            ifm_layout=str(params.ifm.layout),
+            ifm2_layout=str(params.ifm2.layout),
+            ofm_layout=str(params.ofm.layout),
+        )
+
+
+class AddRewriter(BinaryElementwiseRewriter):
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.AddParams,
+            pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AddParams.composite_name}))(
+                wildcard(), wildcard()
+            ),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeAdd:
+    """This is the pass that wraps the AddRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(AddRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class SubRewriter(BinaryElementwiseRewriter):
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.SubParams,
+            pattern=(wildcard().has_attr({"Composite": ethosu_patterns.SubParams.composite_name}))(
+                wildcard(), wildcard()
+            ),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeSub:
+    """This is the pass that wraps the SubRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(SubRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class MulRewriter(BinaryElementwiseRewriter):
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.MulParams,
+            pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MulParams.composite_name}))(
+                wildcard(), wildcard()
+            ),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeMul:
+    """This is the pass that wraps the MulRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(MulRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class MinRewriter(BinaryElementwiseRewriter):
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.MinParams,
+            pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MinParams.composite_name}))(
+                wildcard(), wildcard()
+            ),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeMin:
+    """This is the pass that wraps the MinRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(MinRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class MaxRewriter(BinaryElementwiseRewriter):
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.MaxParams,
+            pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MaxParams.composite_name}))(
+                wildcard(), wildcard()
+            ),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeMax:
+    """This is the pass that wraps the MaxRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(MaxRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
+class ShlRewriter(BinaryElementwiseRewriter):
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.ShlParams,
+            pattern=(wildcard().has_attr({"Composite": ethosu_patterns.ShlParams.composite_name}))(
+                wildcard(), wildcard()
+            ),
+        )
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeShl:
+    """This is the pass that wraps the ShlRewriter"""
+
+    def transform_module(
+        self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+    ) -> tvm.ir.IRModule:
+        for global_var, func in mod.functions.items():
+            func = rewrite(ShlRewriter(), func)
+            mod.update_func(global_var, func)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        pass
+
+
 @ir.transform.module_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -423,11 +641,20 @@ class LegalizeEthosU:
     def transform_module(
         self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
     ) -> tvm.ir.IRModule:
+        """This is the method that replaces the operations with hardware/codegen supported
+        operations.
+        """
         mod = LegalizeSplit()(mod)
         mod = LegalizeConv2D()(mod)
         mod = LegalizeDepthwiseConv2D()(mod)
         mod = LegalizeMaxPooling()(mod)
         mod = LegalizeAvgPooling()(mod)
+        mod = LegalizeAdd()(mod)
+        mod = LegalizeSub()(mod)
+        mod = LegalizeMul()(mod)
+        mod = LegalizeMin()(mod)
+        mod = LegalizeMax()(mod)
+        mod = LegalizeShl()(mod)
         return mod
 
     def __call__(self, *args, **kwargs):
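
The six rewriter/pass pairs above all follow the same DFPatternCallback idiom:
a composite-function pattern paired with a callback that emits the NPU
operator. A self-contained toy sketch of that mechanism (illustrative only;
the add-to-subtract rewrite below is invented and is not part of this patch):

    from tvm import relay
    from tvm.relay.dataflow_pattern import (
        DFPatternCallback,
        is_op,
        rewrite,
        wildcard,
    )

    class ToyAddRewriter(DFPatternCallback):
        """Rewrites add(x, y) to subtract(x, negative(y)) to show the idiom."""

        def __init__(self):
            super().__init__()
            # Match any call to "add" with two arbitrary arguments.
            self.pattern = is_op("add")(wildcard(), wildcard())

        def callback(self, pre, post, node_map):
            # Build the replacement expression from the matched call's args.
            return relay.subtract(post.args[0], relay.negative(post.args[1]))

    x = relay.var("x", shape=(1, 4), dtype="float32")
    y = relay.var("y", shape=(1, 4), dtype="float32")
    func = relay.Function([x, y], relay.add(x, y))
    print(rewrite(ToyAddRewriter(), func))
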
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
index c9aa59b..05d4053 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
@@ -19,3 +19,4 @@
 from .convolution import ethosu_conv2d
 from .depthwise import ethosu_depthwise_conv2d
 from .pooling import ethosu_pooling
+from .binary_elementwise import ethosu_binary_elementwise
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
new file mode 100644
index 0000000..d4ae18b
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Relay operators for binary elementwise operators for Arm(R) Ethos(TM)-U NPU"""
+from typing import Optional
+import tvm
+from tvm.relay.op import _make
+from tvm.topi.generic import schedule_injective
+from tvm.relay.op.op import OpStrategy
+from tvm.relay.op import strategy as _strategy
+
+from ..te import binary_elementwise_compute
+
+
+def _extract_ethosu_binary_elementwise_params(attrs, args):
+    """Get the parameters necessary to construct an ethosu_binary_elementwise compute TE
+    from an ethosu_binary_elementwise Relay call."""
+    ifm = args[0]
+    ifm2 = args[1]
+    lut = args[2]
+    operator_type = attrs.operator_type
+    ifm_scale = attrs.ifm_scale
+    ifm_zero_point = attrs.ifm_zero_point
+    ifm2_scale = attrs.ifm2_scale
+    ifm2_zero_point = attrs.ifm2_zero_point
+    ofm_scale = attrs.ofm_scale
+    ofm_zero_point = attrs.ofm_zero_point
+    ifm_channels = attrs.ifm_channels
+    ifm2_channels = attrs.ifm2_channels
+    reversed_operands = attrs.reversed_operands
+    activation = attrs.activation
+    clip_min = attrs.clip_min
+    clip_max = attrs.clip_max
+    ifm_layout = attrs.ifm_layout
+    ifm2_layout = attrs.ifm2_layout
+    ofm_layout = attrs.ofm_layout
+    ofm_dtype = attrs.ofm_dtype
+
+    return (
+        ifm,
+        ifm2,
+        lut,
+        operator_type,
+        ifm_scale,
+        ifm_zero_point,
+        ifm2_scale,
+        ifm2_zero_point,
+        ofm_scale,
+        ofm_zero_point,
+        ifm_channels,
+        ifm2_channels,
+        reversed_operands,
+        activation,
+        clip_min,
+        clip_max,
+        ifm_layout,
+        ifm2_layout,
+        ofm_layout,
+        ofm_dtype,
+    )
+
+
+@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMCompute")
+def create_ethosu_binary_elementwise_compute(attrs, args, out_type):
+    """Create an ethosu_binary_elementwise compute op."""
+    params = _extract_ethosu_binary_elementwise_params(attrs, args)
+    op = binary_elementwise_compute(*params)
+    return [op]
+
+
+@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMStrategy")
+def binary_elementwise_strategy_ethosu(attrs, inputs, out_type, target):
+    strategy = OpStrategy()
+    strategy.add_implementation(
+        create_ethosu_binary_elementwise_compute,
+        _strategy.wrap_topi_schedule(schedule_injective),
+        name="ethosu_binary_elementwise",
+    )
+    return strategy
+
+
+def ethosu_binary_elementwise(
+    ifm: tvm.relay.Expr,
+    ifm2: tvm.relay.Expr,
+    lut: tvm.relay.Expr,
+    operator_type: str,
+    ifm_scale: float,
+    ifm_zero_point: int,
+    ifm2_scale: float,
+    ifm2_zero_point: int,
+    ofm_scale: float,
+    ofm_zero_point: int,
+    ifm_channels: int,
+    ifm2_channels: int,
+    reversed_operands: bool,
+    ofm_dtype: str,
+    activation: Optional[str] = "NONE",
+    clip_min: Optional[int] = 0,
+    clip_max: Optional[int] = 0,
+    ifm_layout: Optional[str] = "NHWC",
+    ifm2_layout: Optional[str] = "NHWC",
+    ofm_layout: Optional[str] = "NHWC",
+) -> tvm.relay.Call:
+    """This is a quantized binary elementwise operation as supported by
+    the NPU. It accepts either NHWC or NHCWB16 format
+    for the input data.
+
+    Parameters
+    ----------
+    ifm : tvm.relay.Expr
+        The Input Feature Map tensor (IFM).
+    ifm2 : tvm.relay.Expr
+        The Input Feature Map tensor 2 (IFM2).
+    lut : tvm.relay.Expr
+        The look-up table of values to use if activation = "LUT".
+    operator_type: str
+        The type of the binary elementwise operator.
+            "ADD"
+            "SUB"
+            "MUL"
+            "MIN"
+            "MAX"
+            "SHR"
+            "SHL"
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    ifm2_scale : float
+        The quantization scale for the Input Feature Map tensor 2.
+    ifm2_zero_point : int
+        The quantization zero point for the Input Feature Map tensor 2.
+    ofm_scale : float
+        The quantization scale for the Output Feature Map tensor.
+    ofm_zero_point : int
+       The quantization zero point for the Output Feature Map tensor.
+    ifm_channels : int
+        The number of the Input Feature Map channels.
+    ifm2_channels : int
+        The number of the Input Feature Map 2 channels.
+    reversed_operands : bool
+        True if IFM2 is the first operand and IFM is the second operand.
+    ofm_dtype: str
+        The Output Feature Map tensor type.
+        MUL, ADD, SUB {IFM}->{OFM}:
+          {uint8, int8, int32} -> {uint8, int8, int32}, any pairing
+        MAX, MIN:
+          IFM and OFM must be of the same type, one of:
+          {int8, uint8}
+        SHR {IFM}->{OFM}:
+          {int32}->{int8, uint8, int32}, any pairing
+        SHL:
+          {int32}->{int32} only
+    activation : str, optional
+        The activation function to use.
+            "NONE" - no activation function.
+            "CLIP" - clip the output between clip_min and clip_max.
+            "TANH" - tanh activation function.
+            "SIGMOID" - sigmoid activation function.
+            "LUT" - use a look-up table to perform the activation function.
+        Available activations for activation type:
+            {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT"
+            {int32}: "NONE"
+    clip_min : int, optional
+        The minimum clipping value if activation = "CLIP".
+    clip_max : int, optional
+        The maximum clipping value if activation = "CLIP".
+    ifm_layout : str, optional
+        The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ifm2_layout : str, optional
+        The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16".
+    ofm_layout : str, optional
+        The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+
+    Returns
+    -------
+    out : tvm.relay.Call
+        A call to the ethosu_binary_elementwise op.
+    """
+    return _make.ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        lut,
+        operator_type,
+        ifm_scale,
+        ifm_zero_point,
+        ifm2_scale,
+        ifm2_zero_point,
+        ofm_scale,
+        ofm_zero_point,
+        ifm_channels,
+        ifm2_channels,
+        reversed_operands,
+        activation,
+        clip_min,
+        clip_max,
+        ifm_layout,
+        ifm2_layout,
+        ofm_layout,
+        ofm_dtype,
+    )
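
A minimal usage sketch of the new Relay API defined above (the tensor shapes
and quantization parameters below are invented for illustration):

    from tvm import relay
    from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

    ifm = relay.var("ifm", shape=(1, 8, 8, 4), dtype="int8")
    ifm2 = relay.var("ifm2", shape=(1, 8, 8, 4), dtype="int8")
    lut = relay.const([], dtype="int8")  # empty: no LUT-based activation

    call = ethosu_ops.ethosu_binary_elementwise(
        ifm=ifm,
        ifm2=ifm2,
        lut=lut,
        operator_type="ADD",
        ifm_scale=1.0,
        ifm_zero_point=0,
        ifm2_scale=1.0,
        ifm2_zero_point=0,
        ofm_scale=1.0,
        ofm_zero_point=0,
        ifm_channels=4,
        ifm2_channels=4,
        reversed_operands=False,
        ofm_dtype="int8",
    )
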
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
index 7fb054e..970e366 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
@@ -112,8 +112,8 @@ def ethosu_conv2d(
     ifm_layout: str = "NHWC",
     ofm_layout: str = "NHWC",
 ) -> tvm.relay.Call:
-    """This is a quantized 2D convolution operation as supported by the
-    Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
+    """This is a quantized 2D convolution operation as supported by
+    the NPU. It accepts either NHWC or NHCWB16 format
     for the input data and OHWI format for the kernel weights.
 
     Reference: https://developer.arm.com/documentation/102420/0200/
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
index d1b49ef..d8f2e8b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
@@ -15,7 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-argument
-"""Relay operator for depthwise convolution"""
+"""Relay operator for depthwise convolution for Arm(R) Ethos(TM)-U NPU"""
+
 from typing import Tuple
 
 import tvm
@@ -112,8 +113,8 @@ def ethosu_depthwise_conv2d(
     ifm_layout: str = "NHWC",
     ofm_layout: str = "NHWC",
 ) -> tvm.relay.Call:
-    """This is a quantized 2D depthwise convolution operation as supported by the
-    Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
+    """This is a quantized 2D depthwise convolution operation as supported by
+    the NPU. It accepts either NHWC or NHCWB16 format
     for the input data and OHWI format for the kernel weights.
 
     Reference: https://developer.arm.com/documentation/102420/0200/
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
index f344f61..cc36373 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=unused-argument
-"""Relay operators for pooling"""
+"""Relay operators for pooling for Arm(R) Ethos(TM)-U NPU"""
 from typing import Tuple
 
 import tvm
@@ -107,8 +107,8 @@ def ethosu_pooling(
     ifm_layout: str = "NHWC",
     ofm_layout: str = "NHWC",
 ) -> tvm.relay.Call:
-    """This is a quantized 2D pooling operation as supported by the
-    Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
+    """This is a quantized 2D pooling operation as supported by
+    the NPU. It accepts either NHWC or NHCWB16 format
     for the input data.
 
     Parameters
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
index e2eb28f..5c26236 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
@@ -19,3 +19,4 @@
 from .convolution import *
 from .depthwise import *
 from .pooling import *
+from .binary_elementwise import *
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
new file mode 100644
index 0000000..84d4e1b
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Tensor Expressions for binary_elementwise"""
+import operator
+from tvm import te
+from .dma import dma_ofm_compute, dma_ifm_compute
+
+
+def binary_elementwise_compute(
+    ifm: te.Tensor,
+    ifm2: te.Tensor,
+    lut: te.Tensor,
+    operator_type: str,
+    ifm_scale: float,
+    ifm_zero_point: int,
+    ifm2_scale: float,
+    ifm2_zero_point: int,
+    ofm_scale: float,
+    ofm_zero_point: int,
+    ifm_channels: int,
+    ifm2_channels: int,
+    reversed_operands: bool,
+    activation: str,
+    clip_min: int,
+    clip_max: int,
+    ifm_layout: str,
+    ifm2_layout: str,
+    ofm_layout: str,
+    ofm_dtype: str,
+) -> te.Tensor:
+    """A compute operator representing the capabilities of binary_elementwise for the NPU.
+
+    Parameters
+    ----------
+    ifm : te.Tensor
+        The Input Feature Map tensor (IFM).
+    ifm2 : te.Tensor
+        The Input Feature Map tensor 2 (IFM2).
+    lut : te.Tensor
+        The look-up table values to use if activation = "LUT".
+    operator_type: str
+        The type of the binary elementwise operator.
+            "ADD"
+            "SUB"
+            "MUL"
+            "MIN"
+            "MAX"
+            "SHR"
+            "SHL"
+    ifm_scale : float
+        The quantization scale for the Input Feature Map tensor.
+    ifm_zero_point : int
+        The quantization zero point for the Input Feature Map tensor.
+    ifm2_scale : float
+        The quantization scale for the Input Feature Map tensor 2.
+    ifm2_zero_point : int
+        The quantization zero point for the Input Feature Map tensor 2.
+    ofm_scale : float
+        The quantization scale for the Output Feature Map tensor.
+    ofm_zero_point : int
+        The quantization zero point for the Output Feature Map tensor.
+    ifm_channels : int
+        The number of the Input Feature Map channels.
+    ifm2_channels : int
+        The number of the Input Feature Map 2 channels.
+    reversed_operands : bool
+        True if IFM2 is the first operand and IFM is the second operand.
+    activation : str
+        The activation function to use.
+            "NONE" - no activation function.
+            "CLIP" - clip the output between clip_min and clip_max.
+            "TANH" - tanh activation function.
+            "SIGMOID" - sigmoid activation function.
+            "LUT" - use a look-up table to perform the activation function.
+        Available activations for activation type:
+            {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT"
+            {int32}: "NONE"
+    clip_min : int
+        The minimum clipping value if activation = "CLIP".
+    clip_max : int
+        The maximum clipping value if activation = "CLIP".
+    ifm_layout : str, optional
+        The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ifm2_layout : str, optional
+        The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16".
+    ofm_layout : str, optional
+        The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16".
+    ofm_dtype: str
+        The Output Feature Map tensor type.
+        MUL, ADD, SUB {IFM}->{OFM}:
+          {uint8, int8, int32} -> {uint8, int8, int32}, any pairing
+        MAX, MIN:
+          IFM and OFM must be of the same type, one of:
+          {int8, uint8}
+        SHR {IFM}->{OFM}:
+          {int32}->{int8, uint8, int32}, any pairing
+        SHL:
+          {int32}->{int32} only
+
+    Returns
+    -------
+    te.Tensor
+        The Output Feature Map tensor.
+    """
+    # Compute operation for the IFM DMA pipeline
+    dmaed_ifm = dma_ifm_compute(
+        ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0)
+    )
+    dmaed_ifm2 = dma_ifm_compute(
+        ifm2, ifm2_layout, ifm2_zero_point, ifm2_scale, ifm2_channels, (0, 0, 0, 0)
+    )
+
+    # Binary elementwise compute operation
+    ofm_height = dmaed_ifm.shape[1]
+    ofm_width = dmaed_ifm.shape[2]
+
+    binary_elementwise_attrs = {
+        "op": "ethosu_binary_elementwise",
+        "operator_type": operator_type,
+        "reversed_operands": reversed_operands,
+        "activation": activation,
+        "clip_min": clip_min,
+        "clip_max": clip_max,
+    }
+
+    operators = {
+        "ADD": operator.add,
+        "SUB": operator.sub,
+        "MUL": operator.mul,
+        "MIN": te.min,
+        "MAX": te.max,
+        "SHR": operator.add,
+        "SHL": operator.add,
+    }
+    broadcast = [value == 1 for value in dmaed_ifm2.shape]
+
+    if reversed_operands:
+        binary_elementwise = te.compute(
+            (1, ofm_height, ofm_width, ifm_channels),
+            lambda nn, hh, ww, cc: operators[operator_type](
+                dmaed_ifm2(
+                    0 if broadcast[0] else nn,
+                    0 if broadcast[1] else hh,
+                    0 if broadcast[2] else ww,
+                    0 if broadcast[3] else cc,
+                ).astype(ifm.dtype),
+                dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
+            ).astype(ofm_dtype),
+            name="ethosu_binary_elementwise",
+            attrs=binary_elementwise_attrs,
+        )
+    else:
+        binary_elementwise = te.compute(
+            (1, ofm_height, ofm_width, ifm_channels),
+            lambda nn, hh, ww, cc: operators[operator_type](
+                dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
+                dmaed_ifm2(
+                    0 if broadcast[0] else nn,
+                    0 if broadcast[1] else hh,
+                    0 if broadcast[2] else ww,
+                    0 if broadcast[3] else cc,
+                ).astype(ifm.dtype),
+            ).astype(ofm_dtype),
+            name="ethosu_binary_elementwise",
+            attrs=binary_elementwise_attrs,
+        )
+
+    # Compute operation for the OFM DMA pipeline
+    return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ifm_channels)
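
The broadcast handling above pins any IFM2 axis of extent 1 to index 0, so a
smaller IFM2 is repeated across the IFM. A pure-NumPy sketch of the same
indexing rule (toy values, not from the patch):

    import numpy as np

    ifm = np.arange(24, dtype=np.int32).reshape(1, 2, 3, 4)  # NHWC
    ifm2 = np.array([10, 20, 30, 40], dtype=np.int32).reshape(1, 1, 1, 4)
    broadcast = [dim == 1 for dim in ifm2.shape]

    out = np.empty_like(ifm)
    for nn in range(1):
        for hh in range(2):
            for ww in range(3):
                for cc in range(4):
                    out[nn, hh, ww, cc] = ifm[nn, hh, ww, cc] + ifm2[
                        0 if broadcast[0] else nn,
                        0 if broadcast[1] else hh,
                        0 if broadcast[2] else ww,
                        0 if broadcast[3] else cc,
                    ]
    assert (out == ifm + ifm2).all()  # matches NumPy's own broadcasting
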
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
new file mode 100644
index 0000000..1ea24ed
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Extract information from the binary_elementwise operators in TIR."""
+from typing import Dict, Tuple
+import tvm
+from .utils import get_outer_loops, get_op_attrs
+from .dma import get_ifm_params, get_ofm_params
+from .spec import SerialActivation, SerialBinaryElementwise
+
+
+def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
+    """When the datatypes of the ifm, ifm2 and ofm do not match,
+    casts are inserted in TE to handle the difference in these types.
+    Since TIR is not directly run on the NPU we can simply ignore
+    these, and allow the NPU to handle the difference in datatypes
+    itself.
+
+    Parameters
+    ----------
+    tir_load : tvm.tir.expr.Load
+
+    Returns
+    -------
+    tvm.tir.Var
+    """
+    return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load
+
+
+def get_binary_elementwise_params(
+    stmt: tvm.tir.AttrStmt,
+    producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
+    consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
+) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
+    """Get the parameters necessary to construct a call_extern for a binary_elementwise.
+
+    Parameters
+    ----------
+    stmt : tvm.tir.AttrStmt
+        The outermost attribute statement of a binary elementwise loop nest.
+    producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
+        A dictionary to associate pointers with the loop nest
+        that produces their values.
+    consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
+        A dictionary to associate pointers with the loop nest
+        that consumes their values.
+
+    Returns
+    -------
+    SerialBinaryElementwise
+        The parameters needed to construct a binary elementwise operator.
+    output_pointer : tvm.tir.Var
+        The output pointer of the binary elementwise operation.
+    replace_pointer : tvm.tir.Var
+        The output pointer of the DMA write operation, which is to replace
+        the binary elementwise output pointer.
+    """
+    attrs, body = get_op_attrs(stmt)
+    reversed_operands = attrs["reversed_operands"]
+
+    _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
+    op = ignore_cast(inner.value)
+    input_pointer = ignore_cast(op.a).buffer_var
+    input_pointer1 = ignore_cast(op.b).buffer_var
+
+    if reversed_operands:
+        input_pointer, input_pointer1 = input_pointer1, input_pointer
+    output_pointer = inner.buffer_var
+    # Get feature map info
+    serial_ifm, _ = get_ifm_params(input_pointer, producers)
+    serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
+    serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
+    # Get activation info
+    serial_activation = SerialActivation(
+        op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
+    )
+    return (
+        SerialBinaryElementwise(
+            ifm=serial_ifm,
+            ifm2=serial_ifm2,
+            ofm=serial_ofm,
+            operator_type=attrs["operator_type"],
+            reversed_operands=reversed_operands,
+            activation=serial_activation,
+        ),
+        output_pointer,
+        replace_pointer,
+    )
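
A quick illustration of what the ignore_cast helper above does (values are
arbitrary; this only assumes tvm is importable):

    import tvm
    from tvm.relay.backend.contrib.ethosu.tir.binary_elementwise import ignore_cast

    x = tvm.tir.Var("x", "int8")
    casted = tvm.tir.Cast("int32", x)
    assert ignore_cast(casted).same_as(x)  # the cast is stripped
    assert ignore_cast(x).same_as(x)       # non-casts pass through unchanged
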
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 2f5d7ab..a5678d1 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -23,6 +23,7 @@ from tvm.relay.backend.contrib.ethosu import vela_api
 from .convolution import get_conv2d_params
 from .depthwise import get_depthwise_conv2d_params
 from .pooling import get_pooling_params
+from .binary_elementwise import get_binary_elementwise_params
 from .transform import get_copy_params
 from .utils import get_weights_pointer, get_scale_bias_pointer
 
@@ -56,6 +57,7 @@ def ReplaceOperators():
         "ethosu_copy": get_copy_params,
         "ethosu_depthwise_conv2d": get_depthwise_conv2d_params,
         "ethosu_pooling": get_pooling_params,
+        "ethosu_binary_elementwise": get_binary_elementwise_params,
     }
     pointer_to_producer = {}
     pointer_to_consumer = {}
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
index ff019c7..269238a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
@@ -261,3 +261,24 @@ class SerialPooling(SerializableFormat):
         self.padding = padding
         self.activation = activation
         self.upscale = upscale
+
+
+class SerialBinaryElementwise(SerializableFormat):
+    """Specialization class to retrieve arguments of
+    an ethosu.binary_elementwise tir extern call on a predefined ordering"""
+
+    def __init__(
+        self,
+        ifm: SerialFeatureMap,
+        ifm2: SerialFeatureMap,
+        ofm: SerialFeatureMap,
+        operator_type: str,
+        reversed_operands: bool,
+        activation: SerialActivation,
+    ):
+        self.ifm = ifm
+        self.ifm2 = ifm2
+        self.ofm = ofm
+        self.operator_type = operator_type
+        self.reversed_operands = reversed_operands
+        self.activation = activation
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 8616695..f82d7bb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -213,7 +213,10 @@ def assign_addresses(buffer_info, npu_ops):
         buffer = npu_fm.tiles.addresses[0].buffer_var
         assert buffer in buffer_addresses.keys()
         address, buffer_type = buffer_addresses[buffer]
-        npu_fm.tiles.addresses[0] = address
+        index = npu_fm.tiles.addresses[0].index * (
+            np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8
+        )
+        npu_fm.tiles.addresses[0] = address + int(index)
         npu_fm.region = _REGION_MAP[buffer_type]
         return npu_fm
 
@@ -304,6 +307,7 @@ def translate_ethosu_tir_call_extern(tir_call_extern):
         "ethosu_copy": translate_ethosu_copy,
         "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d,
         "ethosu_pooling": translate_ethosu_pooling,
+        "ethosu_binary_elementwise": translate_ethosu_binary_elementwise,
     }
     ext_call_type = tir_call_extern.args[0].value
     assert ext_call_type in supported_call_extern.keys(), f"{ext_call_type} is not yet supported"
@@ -482,6 +486,7 @@ def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.N
     }
     layout = str(serial_feature_map.layout.value)
     data_type = str(serial_feature_map.data_type.value)
+    data_type_bytes = np.iinfo(np.dtype(data_type)).bits // 8
     assert layout in layout_map.keys()
     assert data_type in datatype_map.keys()
     nfm = vapi.NpuFeatureMap()
@@ -507,9 +512,9 @@ def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.N
     )
     nfm.layout = layout_map[layout]
     nfm.strides = vapi.NpuShape3D(
-        int(serial_feature_map.stride_h),
-        int(serial_feature_map.stride_w),
-        int(serial_feature_map.stride_c),
+        int(serial_feature_map.stride_h.value) * data_type_bytes,
+        int(serial_feature_map.stride_w.value) * data_type_bytes,
+        int(serial_feature_map.stride_c.value) * data_type_bytes,
     )
     return nfm
 
@@ -677,3 +682,66 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling):
     npu_pooling_op.block_config = block_config
 
     return npu_pooling_op
+
+
+def translate_ethosu_binary_elementwise(
+    tir_call_extern: tvm.tir.Call,
+) -> vapi.NpuElementWiseOperation:
+    """This function will translate a TIR call_extern
+    as produced by the NPU Relay-to-TIR compilation.
+
+    Parameters
+    ----------
+    tir_call_extern : tvm.tir.Call
+        This should be a TIR call_extern that follows the ordering agreed
+        upon with the TIR compiler. See SerialBinaryElementwise in
+        tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering.
+
+    Returns
+    -------
+    ethosu.vela.api.NpuElementWiseOperation
+        The vela object containing the params of ethosu_binary_elementwise
+    """
+    serial_object = spec.create_serial_object(
+        spec.SerialBinaryElementwise, tir_call_extern.args[1:]
+    )
+    return _create_npu_op_binary_elementwise(serial_object)
+
+
+def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBinaryElementwise):
+    operator_type = serial_binary_elementwise.operator_type
+    if operator_type == "ADD":
+        op = vapi.NpuElementWiseOp.ADD
+    elif operator_type == "SUB":
+        op = vapi.NpuElementWiseOp.SUB
+    elif operator_type == "MUL":
+        op = vapi.NpuElementWiseOp.MUL
+    elif operator_type == "MIN":
+        op = vapi.NpuElementWiseOp.MIN
+    elif operator_type == "MAX":
+        op = vapi.NpuElementWiseOp.MAX
+    elif operator_type == "SHR":
+        op = vapi.NpuElementWiseOp.SHR
+    elif operator_type == "SHL":
+        op = vapi.NpuElementWiseOp.SHL
+
+    npu_binary_elementwise_op = vapi.NpuElementWiseOperation(op)
+    npu_binary_elementwise_op.ifm = _create_npu_feature_map(serial_binary_elementwise.ifm)
+    npu_binary_elementwise_op.ifm2 = _create_npu_feature_map(serial_binary_elementwise.ifm2)
+    npu_binary_elementwise_op.ofm = _create_npu_feature_map(serial_binary_elementwise.ofm)
+    npu_binary_elementwise_op.reversed_operands = serial_binary_elementwise.reversed_operands
+
+    npu_binary_elementwise_op.activation = _create_npu_activation(
+        serial_binary_elementwise.activation
+    )
+    if (
+        npu_binary_elementwise_op.activation
+        and npu_binary_elementwise_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
+    ):
+        _convert_clip_bounds(npu_binary_elementwise_op)
+
+    target_accel_config = vela_api.get_accelerator_config()
+    block_config = vela_api.get_optimal_block_config(npu_binary_elementwise_op, target_accel_config)
+    npu_binary_elementwise_op.block_config = block_config
+
+    return npu_binary_elementwise_op
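
The operator_type dispatch in _create_npu_op_binary_elementwise could equally
be a table lookup; a sketch only (equivalent behaviour, not part of the patch,
and the vapi import alias is assumed to match the one this module uses):

    import ethosu.vela.api as vapi

    NPU_ELEMENTWISE_OPS = {
        "ADD": vapi.NpuElementWiseOp.ADD,
        "SUB": vapi.NpuElementWiseOp.SUB,
        "MUL": vapi.NpuElementWiseOp.MUL,
        "MIN": vapi.NpuElementWiseOp.MIN,
        "MAX": vapi.NpuElementWiseOp.MAX,
        "SHR": vapi.NpuElementWiseOp.SHR,
        "SHL": vapi.NpuElementWiseOp.SHL,
    }

    def lookup_npu_elementwise_op(operator_type: str):
        # An unknown operator_type raises KeyError here, whereas the
        # if/elif chain above would leave `op` unbound.
        return NPU_ELEMENTWISE_OPS[operator_type]
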
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index ee47e4a..8afb6eb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -75,6 +75,21 @@ class ClipArgs(Enum):
     A_MAX = 2
 
 
+class BinaryElementwiseArgs(Enum):
+    """This is a helper enum to access the correct index
+    of binary elementwise arguments
+    """
+
+    ifm = 0
+    ifm2 = 1
+    ifm_scale = 2
+    ifm_zero_point = 3
+    ifm2_scale = 4
+    ifm2_zero_point = 5
+    ofm_scale = 6
+    ofm_zero_point = 7
+
+
 def is_composite_func(func: relay.Function, name: str) -> bool:
     """
     This method checks whether the call is to
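
The enum mirrors the positional argument order of the qnn binary elementwise
calls (qnn.add, qnn.subtract, qnn.mul). A small sketch of how detection code
can index arguments with it (the helper name here is hypothetical):

    from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs

    def extract_ifm_scale(qnn_binary_call):
        # For e.g. qnn.add the args are (ifm, ifm2, ifm_scale, ifm_zero_point,
        # ifm2_scale, ifm2_zero_point, ofm_scale, ofm_zero_point),
        # matching the enum above.
        return qnn_binary_call.args[BinaryElementwiseArgs.ifm_scale.value]
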
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index a152235..25538ca 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -40,6 +40,7 @@ try:
     from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs  # type: ignore
     from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
     from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+    from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs
     from tvm.relay.backend.contrib.ethosu.util import get_dim_value
 except ImportError:
     vapi = None
@@ -99,9 +100,8 @@ def check_strides(strides: List[int]) -> bool:
     return True
 
 
-def check_valid_dtypes(tensor_params: List[TensorParams]) -> bool:
+def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes: List[type]) -> bool:
     """This function checks whether dtypes are supported by the NPU"""
-    supported_dtypes = (np.uint8, np.int8)
     for tep in tensor_params:
         # Check for dtypes
         if np.dtype(tep.dtype) not in supported_dtypes:
@@ -248,7 +248,7 @@ class QnnConv2DParams:
         This function checks whether QnnConv2D has compatible attributes with the NPU
         """
         tensor_params = [self.weights, self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params):
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
             return False
         if not check_weights(self.weights, self.dilation):
             return False
@@ -287,7 +287,7 @@ class QnnDepthwiseConv2DParams(QnnConv2DParams):
         Checks whether QnnDepthwiseConv2D + activation function has compatible attributes with HW
         """
         tensor_params = [self.weights, self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params):
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
             return False
         if not check_weights(self.weights, self.dilation):
             return False
@@ -373,7 +373,7 @@ class MaxPool2DParams:
         This function checks whether MaxPool2D has compatible attributes with the NPU
         """
         tensor_params = [self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params):
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
             return False
         if self.ifm.dtype != self.ofm.dtype:
             return False
@@ -432,7 +432,7 @@ class AvgPool2DParams:
         This function checks whether AvgPool2D has compatible attributes with the NPU
         """
         tensor_params = [self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params):
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
             return False
         if self.ifm.dtype != self.ofm.dtype:
             return False
@@ -458,6 +458,316 @@ def qnn_avgpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return pattern
 
 
+class BinaryElementwiseParams:
+    """
+    This class will parse a call to an ethosu.binary_elementwise composite function
+    and extract the parameter information.
+    """
+
+    def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool):
+        clip = None
+        if str(func_body.op) == "clip":
+            clip = func_body
+            binary_op = clip.args[0]
+        else:
+            binary_op = func_body
+
+        layout = "NHWC"
+
+        if has_quantization_parameters:
+            self.ifm = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm.value],
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ifm_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value],
+            )
+            self.ifm2 = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm2.value],
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ifm2_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value],
+            )
+            self.ofm = TensorParams(
+                binary_op,
+                layout,
+                binary_op.args[BinaryElementwiseArgs.ofm_scale.value],
+                binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value],
+            )
+        else:
+            self.ifm = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm.value],
+                layout,
+            )
+            self.ifm2 = TensorParams(
+                binary_op.args[BinaryElementwiseArgs.ifm2.value],
+                layout,
+            )
+            self.ofm = TensorParams(
+                binary_op,
+                layout,
+            )
+        self.activation = clip
+        self.operator_type = operator_type
+
+        def can_broadcast(x, y):
+            for i in range(1, 4):
+                if x.shape[i] == y.shape[i] or y.shape[i] == 1:
+                    continue
+                return False
+            return True
+
+        if can_broadcast(self.ifm, self.ifm2):
+            self.reversed_operands = False
+            self.valid_broadcast = True
+        elif can_broadcast(self.ifm2, self.ifm):
+            self.reversed_operands = True
+            self.ifm, self.ifm2 = self.ifm2, self.ifm
+            self.valid_broadcast = True
+        else:
+            self.valid_broadcast = False
+
+    def is_valid(self):
+        """
+        This function checks whether BinaryElementwise has compatible attributes with the NPU
+        """
+        if np.dtype(self.ofm) == np.int32 and self.activation is not None:
+            return False
+        if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
+            return False
+        if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
+            return False
+        if not self.valid_broadcast:
+            return False
+        return True
+
+
+class AddParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Add composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.add"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "ADD", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Add has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.add with optional fused RELU activation.
+    """
+    pattern = is_op("qnn.add")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class SubParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Sub composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.sub"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "SUB", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Sub has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.subtract with optional fused RELU activation.
+    """
+    pattern = is_op("qnn.subtract")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MulParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Mul composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.mul"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MUL", True)
+
+    def is_valid(self):
+        """
+        This function checks whether Mul has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32]
+        ):
+            return False
+        return True
+
+
+def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for qnn.mul with optional fused RELU activation.
+    """
+    pattern = is_op("qnn.mul")(
+        wildcard(),
+        wildcard(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MinParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Min composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.min"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MIN", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Min has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if self.ifm.dtype != self.ifm2.dtype:
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
+        ):
+            return False
+        return True
+
+
+def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for minimum with optional fused RELU activation.
+    """
+    pattern = is_op("minimum")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class MaxParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Max composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.max"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "MAX", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Max has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if self.ifm.dtype != self.ifm2.dtype:
+            return False
+        if not check_valid_dtypes(
+            [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8]
+        ):
+            return False
+        return True
+
+
+def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for maximum with optional fused RELU activation.
+    """
+    pattern = is_op("maximum")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
+class ShlParams(BinaryElementwiseParams):
+    """
+    This class will parse a call to an ethosu.binary_elementwise Shl composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethosu.shl"
+
+    def __init__(self, func_body: Call):
+        BinaryElementwiseParams.__init__(self, func_body, "SHL", False)
+
+    def is_valid(self):
+        """
+        This function checks whether Shl has compatible attributes with the NPU
+        """
+        if not super().is_valid():
+            return False
+        if not check_valid_dtypes([self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.int32]):
+            return False
+        return True
+
+
+def shl_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for left_shift with optional fused RELU activation.
+    """
+    pattern = is_op("left_shift")(wildcard(), wildcard())
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
 @register_pattern_table("ethosu")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -481,6 +791,36 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             qnn_avgpool2d_pattern(),
             lambda pat: AvgPool2DParams(pat).is_valid(),
         ),
+        (
+            AddParams.composite_name,
+            qnn_add_pattern(),
+            lambda pat: AddParams(pat).is_valid(),
+        ),
+        (
+            SubParams.composite_name,
+            qnn_subtract_pattern(),
+            lambda pat: SubParams(pat).is_valid(),
+        ),
+        (
+            MulParams.composite_name,
+            qnn_mul_pattern(),
+            lambda pat: MulParams(pat).is_valid(),
+        ),
+        (
+            MinParams.composite_name,
+            minimum_pattern(),
+            lambda pat: MinParams(pat).is_valid(),
+        ),
+        (
+            MaxParams.composite_name,
+            maximum_pattern(),
+            lambda pat: MaxParams(pat).is_valid(),
+        ),
+        (
+            ShlParams.composite_name,
+            shl_pattern(),
+            lambda pat: ShlParams(pat).is_valid(),
+        ),
     ]
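
With these patterns registered, a quantized add is matched as an ethosu.add
composite during partitioning. A rough end-to-end sketch (it assumes the
partition_for_ethosu helper already defined in this file; shapes and
quantization parameters are invented):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.ethosu import partition_for_ethosu

    ifm = relay.var("ifm", shape=(1, 8, 8, 4), dtype="int8")
    ifm2 = relay.var("ifm2", shape=(1, 8, 8, 4), dtype="int8")
    add = relay.qnn.op.add(
        ifm,
        ifm2,
        lhs_scale=relay.const(1.0, "float32"),
        lhs_zero_point=relay.const(0, "int32"),
        rhs_scale=relay.const(1.0, "float32"),
        rhs_zero_point=relay.const(0, "int32"),
        output_scale=relay.const(1.0, "float32"),
        output_zero_point=relay.const(0, "int32"),
    )
    mod = tvm.IRModule.from_expr(relay.Function([ifm, ifm2], add))
    mod = partition_for_ethosu(mod)  # qnn.add should be offloaded as ethosu.add
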
 
 
diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc
new file mode 100644
index 0000000..5b4900e
--- /dev/null
+++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/contrib/ethosu/binary_elementwise.cc
+ * \brief Binary elementwise operator definitions for the Arm(R) Ethos(TM)-U NPU.
+ */
+#include <tvm/relay/op.h>
+
+#include "common.h"
+
+namespace tvm {
+namespace relay {
+namespace op {
+namespace contrib {
+namespace ethosu {
+
+/*! \brief Attributes used by the Ethos(TM)-U NPU binary elementwise operators 
*/
+struct EthosuBinaryElementwiseAttrs : public 
tvm::AttrsNode<EthosuBinaryElementwiseAttrs> {
+  String operator_type;
+  double ifm_scale;
+  int ifm_zero_point;
+  double ifm2_scale;
+  int ifm2_zero_point;
+  double ofm_scale;
+  int ofm_zero_point;
+  IndexExpr ifm_channels;
+  IndexExpr ifm2_channels;
+  bool reversed_operands;
+  String activation;
+  int clip_min;
+  int clip_max;
+  String ifm_layout;
+  String ifm2_layout;
+  String ofm_layout;
+  String ofm_dtype;
+
+  TVM_DECLARE_ATTRS(EthosuBinaryElementwiseAttrs, 
"relay.attrs.EthosuBinaryElementwiseAttrs") {
+    TVM_ATTR_FIELD(operator_type)
+        .describe(
+            "The type of the binary elementwise operator."
+            "'ADD'"
+            "'SUB'"
+            "'MUL'"
+            "'MIN'"
+            "'MAX'"
+            "'SHR'"
+            "'SHL'");
+    TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input 
Feature Map tensor.");
+    TVM_ATTR_FIELD(ifm_zero_point)
+        .describe("The quantization zero point for the Input Feature Map 
tensor.");
+    TVM_ATTR_FIELD(ifm2_scale)
+        .describe("The quantization scale for the Input Feature Map tensor 
2.");
+    TVM_ATTR_FIELD(ifm2_zero_point)
+        .describe("The quantization zero point for the Input Feature Map 
tensor 2.");
+    TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output 
Feature Map tensor.");
+    TVM_ATTR_FIELD(ofm_zero_point)
+        .describe("The quantization zero point for the Output Feature Map 
tensor.");
+    TVM_ATTR_FIELD(ifm_channels).describe("The number of the Input Feature Map 
channels.");
+    TVM_ATTR_FIELD(ifm2_channels).describe("The number of the Input Feature 
Map 2 channels.");
+    TVM_ATTR_FIELD(reversed_operands)
+        .describe("True if IFM2 is the first operand and IFM is the second 
operand.")
+        .set_default(false);
+    TVM_ATTR_FIELD(activation)
+        .describe(
+            "The activation function to use. "
+            "'NONE' - no activation function. "
+            "'CLIP' - clip the output between clip_min and clip_max. "
+            "'TANH' - tanh activation function. "
+            "'SIGMOID' - sigmoid activation function. "
+            "'LUT' - use a look-up table to perform the activation function."
+            "Available activations for activation type:"
+            "{int8, uint8}: 'NONE', 'CLIP', 'TANH', 'SIGMOID', 'LUT'"
+            "{int32}: 'NONE'")
+        .set_default("NONE");
+    TVM_ATTR_FIELD(clip_min)
+        .describe("The minimum clipping value if activation = 'CLIP'.")
+        .set_default(0);
+    TVM_ATTR_FIELD(clip_max)
+        .describe("The maximum clipping value if activation = 'CLIP'.")
+        .set_default(0);
+    TVM_ATTR_FIELD(ifm_layout)
+        .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' 
or 'NHCWB16'.")
+        .set_default("NHWC");
+    TVM_ATTR_FIELD(ifm2_layout)
+        .describe("The layout of the Input Feature Map tensor 2. Can be 'NHWC' 
or 'NHCWB16'.")
+        .set_default("NHWC");
+    TVM_ATTR_FIELD(ofm_layout)
+        .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' 
or 'NHCWB16'.")
+        .set_default("NHWC");
+    TVM_ATTR_FIELD(ofm_dtype).describe(
+        "The Output Feature Map tensor type."
+        "MUL, ADD, SUB {IFM}->{OFM}:"
+        "  {uint8, int8 int32} -> {uint8, int8, int32}, any pairing"
+        "MAX, MIN:"
+        "  IFM and OFM must be of the same type, one of:"
+        "  {int8, uint8}"
+        "SHR {IFM}->{OFM}:"
+        "  {int32}->{int8, uint8, int32}, any pairing"
+        "SHL:"
+        "  {int32}->{int32} only");
+  }
+};
+
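As a cross-check of the describe() strings above, the legal dtype pairings can
be summarised in a few lines of Python (illustrative only, not part of the
commit):

    # Illustrative only: mirrors the ofm_dtype pairing rules described above.
    def valid_dtype_pairing(operator_type: str, ifm_dtype: str, ofm_dtype: str) -> bool:
        if operator_type in ("ADD", "SUB", "MUL"):
            allowed = {"uint8", "int8", "int32"}
            return ifm_dtype in allowed and ofm_dtype in allowed  # any pairing
        if operator_type in ("MIN", "MAX"):
            return ifm_dtype in ("uint8", "int8") and ifm_dtype == ofm_dtype
        if operator_type == "SHR":
            return ifm_dtype == "int32" and ofm_dtype in ("uint8", "int8", "int32")
        if operator_type == "SHL":
            return ifm_dtype == "int32" and ofm_dtype == "int32"
        return False
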
+TVM_REGISTER_NODE_TYPE(EthosuBinaryElementwiseAttrs);
+
+bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
+                                const TypeReporter& reporter) {
+  const int ifm_index = 0;
+  const int ifm2_index = 1;
+  const int result_index = 3;
+  ICHECK_EQ(types.size(), result_index + 1);
+
+  const auto* ifm = types[ifm_index].as<TensorTypeNode>();
+  const auto* ifm2 = types[ifm2_index].as<TensorTypeNode>();
+  if (ifm == nullptr) return false;
+  if (ifm2 == nullptr) return false;
+
+  const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>();
+  ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be 
nullptr.";
+
+  String operator_type = param->operator_type;
+  auto ifm_dtype = ifm->dtype;
+  auto ifm2_dtype = ifm2->dtype;
+  DataType ofm_dtype;
+
+  if (param->ofm_dtype == "int8") {
+    ofm_dtype = DataType::Int(8);
+  } else if (param->ofm_dtype == "uint8") {
+    ofm_dtype = DataType::UInt(8);
+  } else if (param->ofm_dtype == "int32") {
+    ofm_dtype = DataType::Int(32);
+  }
+
+  if (ifm_dtype != ifm2_dtype) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "Invalid operator: expected 
ethosu_binary_elementwise "
+                                     << "type for ifm2 to be the same as ifm but was " << ifm2_dtype
+                                     << " instead of " << ifm_dtype);
+    return false;
+  }
+
+  if (operator_type == "ADD" || operator_type == "SUB" || operator_type == 
"MUL") {
+    if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
+        ifm_dtype != DataType::Int(32)) {
+      reporter->GetDiagCtx().EmitFatal(
+          Diagnostic::Error(reporter->GetSpan())
+          << "Invalid operator: expected ethosu_binary_elementwise " << 
operator_type
+          << " type(uint8) or type(int8) or type(int32) for ifm but was " << 
ifm_dtype);
+      return false;
+    }
+    if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
+        ofm_dtype != DataType::Int(32)) {
+      reporter->GetDiagCtx().EmitFatal(
+          Diagnostic::Error(reporter->GetSpan())
+          << "Invalid operator: expected ethosu_binary_elementwise " << 
operator_type
+          << " type(uint8) or type(int8) or type(int32) for ofm but was " << 
ofm_dtype);
+      return false;
+    }
+  } else if (operator_type == "MIN" || operator_type == "MAX") {
+    if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) {
+      reporter->GetDiagCtx().EmitFatal(
+          Diagnostic::Error(reporter->GetSpan())
+          << "Invalid operator: expected ethosu_binary_elementwise " << 
operator_type
+          << " type(uint8) or type(int8) for ifm but was " << ifm_dtype);
+      return false;
+    }
+    if (ifm_dtype != ofm_dtype) {
+      reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                       << "Invalid operator: expected 
ethosu_binary_elementwise "
+                                       << operator_type
+                                       << " type for ofm to be the same as ifm but was " << ofm_dtype
+                                       << " instead of " << ifm_dtype);
+      return false;
+    }
+  } else if (operator_type == "SHR") {
+    if (ifm_dtype != DataType::Int(32)) {
+      reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                       << "Invalid operator: expected 
ethosu_binary_elementwise "
+                                       << operator_type << " type(int32) for 
ifm but was "
+                                       << ifm_dtype);
+      return false;
+    }
+    if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
+        ofm_dtype != DataType::Int(32)) {
+      reporter->GetDiagCtx().EmitFatal(
+          Diagnostic::Error(reporter->GetSpan())
+          << "Invalid operator: expected ethosu_binary_elementwise " << 
operator_type
+          << " type(uint8) or type(int8) or type(int32) for ofm but was " << 
ofm_dtype);
+      return false;
+    }
+  } else if (operator_type == "SHL") {
+    if (ifm_dtype != DataType::Int(32)) {
+      reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                       << "Invalid operator: expected 
ethosu_binary_elementwise "
+                                       << operator_type << " type(int32) for 
ifm but was "
+                                       << ifm_dtype);
+
+      return false;
+    }
+    if (ofm_dtype != DataType::Int(32)) {
+      reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                       << "Invalid operator: expected 
ethosu_binary_elementwise "
+                                       << operator_type << " type(int32) for 
ofm but was "
+                                       << ofm_dtype);
+      return false;
+    }
+  } else {
+    reporter->GetDiagCtx().EmitFatal(
+        Diagnostic::Error(reporter->GetSpan())
+        << "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 
'SUB' or 'MUL' or "
+        << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << 
param->operator_type);
+    return false;
+  }
+
+  // Assign ofm type
+  auto ofm_shape = EthosuInferBinaryElementwiseOutputShape(ifm->shape, 
param->ifm_layout,
+                                                           param->ofm_layout, 
param->ifm_channels);
+  reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype));
+  return true;
+}
+
+Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String 
operator_type,
+                                 double ifm_scale, int ifm_zero_point, double 
ifm2_scale,
+                                 int ifm2_zero_point, double ofm_scale, int 
ofm_zero_point,
+                                 IndexExpr ifm_channels, IndexExpr 
ifm2_channels,
+                                 bool reversed_operands, String activation, 
int clip_min,
+                                 int clip_max, String ifm_layout, String 
ifm2_layout,
+                                 String ofm_layout, String ofm_dtype) {
+  auto attrs = make_object<EthosuBinaryElementwiseAttrs>();
+
+  attrs->operator_type = std::move(operator_type);
+  attrs->ifm_scale = ifm_scale;
+  attrs->ifm_zero_point = ifm_zero_point;
+  attrs->ifm2_scale = ifm2_scale;
+  attrs->ifm2_zero_point = ifm2_zero_point;
+  attrs->ofm_scale = ofm_scale;
+  attrs->ofm_zero_point = ofm_zero_point;
+  attrs->ifm_channels = std::move(ifm_channels);
+  attrs->ifm2_channels = std::move(ifm2_channels);
+  attrs->reversed_operands = reversed_operands;
+  attrs->activation = std::move(activation);
+  attrs->clip_min = clip_min;
+  attrs->clip_max = clip_max;
+  attrs->ifm_layout = std::move(ifm_layout);
+  attrs->ifm2_layout = std::move(ifm2_layout);
+  attrs->ofm_layout = std::move(ofm_layout);
+  attrs->ofm_dtype = std::move(ofm_dtype);
+
+  static const Op& op = Op::Get("contrib.ethosu.binary_elementwise");
+  return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.ethosu_binary_elementwise")
+    .set_body_typed(MakeEthosuBinaryElementwise);
+
+RELAY_REGISTER_OP("contrib.ethosu.binary_elementwise")
+    .describe(R"code(Arm(R) Ethos(TM)-U NPU quantized binary elementwise 
operator.
+
+This Relay operator corresponds to the hardware-implemented quantized
+binary elementwise operation found on the Ethos(TM)-U NPU. It accepts either
+NHWC or NHCWB16 format for the input data (the input feature maps, or IFMs).
+
+Reference: https://developer.arm.com/documentation/102420/0200/
+
+- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
+           NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
+- **ifm2**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
+            NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
+- **ofm**: (1, ofm_height, ofm_width, ifm_channels)
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<EthosuBinaryElementwiseAttrs>()
+    .set_num_inputs(3)
+    .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).")
+    .add_argument("ifm2", "Tensor", "The Input Feature Map tensor 2 (IFM2).")
+    .add_argument("lut", "Tensor", "The look-up table of values to use if 
activation = 'LUT'")
+    .set_support_level(11)
+    .add_type_rel("EthosuBinaryElementwise", EthosuBinaryElementwiseRel);
+
+}  // namespace ethosu
+}  // namespace contrib
+}  // namespace op
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/contrib/ethosu/common.cc 
b/src/relay/op/contrib/ethosu/common.cc
index bdda81b..bdaa9da 100644
--- a/src/relay/op/contrib/ethosu/common.cc
+++ b/src/relay/op/contrib/ethosu/common.cc
@@ -32,6 +32,24 @@ namespace op {
 namespace contrib {
 namespace ethosu {
 
+Array<IndexExpr> EthosuInferBinaryElementwiseOutputShape(Array<IndexExpr> 
ifm_shape,
+                                                         String ifm_layout, 
String ofm_layout,
+                                                         IndexExpr 
ofm_channels) {
+  // In the case of NHCWB16, convert the ifm shape to NHW (C not required for 
this function)
+  if (ifm_layout == "NHCWB16") {
+    ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]};
+  }
+  Array<IndexExpr> oshape({ifm_shape[0], ifm_shape[1], ifm_shape[2], 
ofm_channels});
+
+  // If the ofm is NHCWB16, convert the layout
+  if (ofm_layout == "NHCWB16") {
+    int channel_bricks = 1 + (oshape[3].as<IntImmNode>()->value - 1) / 16;
+    oshape = {oshape[0], oshape[1], channel_bricks, oshape[2], 16};
+  }
+
+  return oshape;
+}
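In Python terms, the function above behaves like the following sketch (an
illustration of the same logic, not part of the commit):

    # Illustrative Python equivalent of EthosuInferBinaryElementwiseOutputShape
    def infer_output_shape(ifm_shape, ifm_layout, ofm_layout, ofm_channels):
        if ifm_layout == "NHCWB16":
            # Keep only N, H, W; the channel count comes from ofm_channels
            ifm_shape = [ifm_shape[0], ifm_shape[1], ifm_shape[3]]
        oshape = [ifm_shape[0], ifm_shape[1], ifm_shape[2], ofm_channels]
        if ofm_layout == "NHCWB16":
            channel_bricks = 1 + (oshape[3] - 1) // 16
            oshape = [oshape[0], oshape[1], channel_bricks, oshape[2], 16]
        return oshape

    # e.g. NHWC (1, 8, 9, 40) with an NHCWB16 OFM gives (1, 8, 3, 9, 16):
    # 40 channels pack into ceil(40 / 16) = 3 bricks of 16.
    assert infer_output_shape([1, 8, 9, 40], "NHWC", "NHCWB16", 40) == [1, 8, 3, 9, 16]
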
+
 Array<IndexExpr> EthosuInferKernelOutput(Array<IndexExpr> ifm_shape, String 
ifm_layout,
                                          String ofm_layout, Array<IndexExpr> 
kernel_shape,
                                          IndexExpr ofm_channels, 
Array<IndexExpr> dilation,
diff --git a/src/relay/op/contrib/ethosu/common.h 
b/src/relay/op/contrib/ethosu/common.h
index b5377e6..574fb91 100644
--- a/src/relay/op/contrib/ethosu/common.h
+++ b/src/relay/op/contrib/ethosu/common.h
@@ -33,6 +33,17 @@ namespace op {
 namespace contrib {
 namespace ethosu {
 
+/*! \brief Infer the output tensor shape for binary elementwise operators.
+ * \param ifm_shape The shape of Input Feature Map.
+ * \param ifm_layout The layout of the IFM (NHWC or NHCWB16).
+ * \param ofm_layout The layout of the OFM (NHWC or NHCWB16).
+ * \param ofm_channels The number of Output Feature Map channels.
+ * \return The shape of the output tensor.
+ */
+Array<IndexExpr> EthosuInferBinaryElementwiseOutputShape(Array<IndexExpr> 
ifm_shape,
+                                                         String ifm_layout, 
String ofm_layout,
+                                                         IndexExpr 
ofm_channels);
+
 /*! \brief Infer the output tensor shape for convolution and pooling operators.
  * \param ifm_shape The shape of Input Feature Map.
  * \param ifm_layout The layout of the IFM (NHWC or NHCWB16).
diff --git a/src/relay/op/contrib/ethosu/pooling.cc 
b/src/relay/op/contrib/ethosu/pooling.cc
index 86f14f3..bcf54fb 100644
--- a/src/relay/op/contrib/ethosu/pooling.cc
+++ b/src/relay/op/contrib/ethosu/pooling.cc
@@ -19,7 +19,7 @@
 
 /*!
  * \file src/relay/op/contrib/ethosu/pooling.cc
- * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU 
convolution ops.
+ * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU.
  */
 #include <tvm/relay/op.h>
 
diff --git a/tests/python/contrib/test_ethosu/infra.py 
b/tests/python/contrib/test_ethosu/infra.py
index 58862c5..17d3fad 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -509,3 +509,56 @@ def make_ethosu_pooling(
         ofm_layout=ofm_layout,
     )
     return pooling
+
+
+def get_binary_elementwise_args(call, include_buffers=False):
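+    """Extract the argument list of an ethosu_binary_elementwise tir.Call,
+    unwrapping immediates to their Python values and, unless include_buffers
+    is set, Loads to their buffer indices."""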
+    args = call.args
+    binary_elementwise_args = []
+
+    for i, arg in enumerate(args):
+        if isinstance(arg, (tvm.tir.expr.IntImm, tvm.tir.expr.FloatImm)):
+            binary_elementwise_args.append(arg.value)
+        elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers:
+            binary_elementwise_args.append(arg.index)
+        else:
+            binary_elementwise_args.append(arg)
+
+    return binary_elementwise_args
+
+
+def make_ethosu_binary_elementwise(
+    ifm,
+    ifm2,
+    ifm_channels,
+    ifm2_channels,
+    operator_type,
+    ofm_dtype,
+    reversed_operands=False,
+    activation="NONE",
+    ifm_layout="NHWC",
+    ifm2_layout="NHWC",
+    ofm_layout="NHWC",
+):
+    ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
+        ifm=ifm,
+        ifm2=ifm2,
+        lut=relay.const([], dtype="int8"),
+        operator_type=operator_type,
+        ifm_scale=1,
+        ifm_zero_point=0,
+        ifm2_scale=1,
+        ifm2_zero_point=0,
+        ofm_scale=1,
+        ofm_zero_point=0,
+        ifm_channels=ifm_channels,
+        ifm2_channels=ifm2_channels,
+        reversed_operands=reversed_operands,
+        activation=activation,
+        ofm_dtype=ofm_dtype,
+        clip_min=10 if activation == "CLIP" else 0,
+        clip_max=100 if activation == "CLIP" else 0,
+        ifm_layout=ifm_layout,
+        ifm2_layout=ifm2_layout,
+        ofm_layout=ofm_layout,
+    )
+    return ethosu_binary_elementwise
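A minimal usage sketch for this helper (illustrative only; it relies on the
relay import already present in this file):

    # Build a simple int8 ADD the same way the tests below do
    ifm = relay.var("ifm", shape=(1, 5, 9, 3), dtype="int8")
    ifm2 = relay.var("ifm2", shape=(1, 5, 9, 3), dtype="int8")
    add_op = make_ethosu_binary_elementwise(ifm, ifm2, 3, 3, "ADD", "int8")
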
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py 
b/tests/python/contrib/test_ethosu/test_codegen.py
index 478a3c2..a5686c8 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -276,8 +276,6 @@ def test_ethosu_pooling(
     dtype = "int8"
 
     def create_tflite_graph():
-        tf.config.run_functions_eagerly(True)
-
         class Model(tf.Module):
             @tf.function
             def tf_function(self, x):
@@ -343,5 +341,255 @@ def test_ethosu_pooling(
     infra.verify_source(compiled_models, accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4]),
+        ([1, 2, 3, 4], [1, 1, 1, 1]),
+        ([1, 1, 1, 1], [1, 2, 3, 4]),
+    ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_ethosu_binary_elementwise(
+    accel_type,
+    operator_type,
+    ifm_shape,
+    ifm2_shape,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, lhs, rhs):
+                if operator_type == "ADD":
+                    op = tf.math.add(lhs, rhs)
+                elif operator_type == "SUB":
+                    op = tf.math.subtract(lhs, rhs)
+                elif operator_type == "MUL":
+                    op = tf.math.multiply(lhs, rhs)
+                elif operator_type == "MIN":
+                    op = tf.math.minimum(lhs, rhs)
+                elif operator_type == "MAX":
+                    op = tf.math.maximum(lhs, rhs)
+                if activation_function == "RELU":
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), 
tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+                yield [data.astype(np.float32), data2.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape},
+        dtype_dict={"ifm": dtype, "ifm2": dtype},
+    )
+    mod = partition_for_ethosu(mod, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+        output_tolerance=1 if operator_type == "MAX" else 0,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload 
module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+    cmms = get_cs(ethosu_module)
+    cmms = bytes.fromhex(cmms)
+
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4]),
+        ([1, 2, 3, 4], [1, 1, 3, 1]),
+        ([1, 1, 3, 1], [1, 2, 3, 4]),
+    ],
+)
+def test_ethosu_left_shift_binary_elemwise(
+    accel_type,
+    ifm_shape,
+    ifm2_shape,
+):
+    dtype = "int32"
+
+    def create_model():
+        ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+        ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+        c1 = relay.left_shift(ifm, ifm2)
+        f = relay.Function([ifm, ifm2], c1)
+        mod = tvm.IRModule()
+        mod["main"] = f
+        return mod
+
+    relay_mod = create_model()
+    mod = partition_for_ethosu(relay_mod)
+
+    # Generate reference data
+    in_min, in_max = util.get_range_for_dtype_str(dtype)
+    input_data = {
+        "ifm": np.random.randint(in_min, high=in_max, size=ifm_shape, 
dtype=dtype),
+        "ifm2": np.random.randint(0, high=32, size=ifm2_shape, dtype=dtype),
+    }
+    output_data = generate_ref_data(relay_mod, input_data)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload 
module
+    imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+    cmms = get_cs(ethosu_module)
+    cmms = bytes.fromhex(cmms)
+
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands, ofm_dtype",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False, "int8"),
+        ([1, 2, 3, 1], [1, 1, 3, 1], False, "int32"),
+        ([1, 1, 3, 1], [1, 2, 3, 1], True, "int32"),
+    ],
+)
+def test_ethosu_right_shift_binary_elemwise(
+    ifm_shape, ifm2_shape, reversed_operands, accel_type, ofm_dtype
+):
+    dtype = "int32"
+
+    def create_model():
+        ifm_count = int(np.prod(ifm_shape))
+        ifm2_count = int(np.prod(ifm2_shape))
+
+        # Create a "partitioned" Relay function
+        ifms = relay.var("ifms", shape=[ifm_count + ifm2_count], dtype=dtype)
+        split = relay.split(ifms, [ifm_count])
+        ifm = relay.reshape(split[0], newshape=ifm_shape)
+        ifm2 = relay.reshape(split[1], newshape=ifm2_shape)
+        shr_op = infra.make_ethosu_binary_elementwise(
+            ifm, ifm2, ifm_shape[3], ifm2_shape[3], "SHR", ofm_dtype, 
reversed_operands
+        )
+
+        glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0")
+        func = (
+            relay.Function([ifms], shr_op)
+            .with_attr("Inline", 1)
+            .with_attr("Compiler", "ethosu")
+            .with_attr("global_symbol", "tvmgen_default_ethosu_main_0")
+            .with_attr("Primitive", 1)
+        )
+        mod = tvm.IRModule()
+        mod[glb_ethosu] = func
+        mod = relay.transform.InferType()(mod)
+
+        # Main
+        ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+        ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+        call = relay.Call(
+            glb_ethosu,
+            [
+                relay.concatenate(
+                    data=(
+                        relay.reshape(ifm, newshape=ifm_count),
+                        relay.reshape(ifm2, newshape=ifm2_count),
+                    ),
+                    axis=0,
+                )
+            ],
+        )
+        mod["main"] = relay.Function([ifm, ifm2], call)
+        mod = relay.transform.InferType()(mod)
+        return mod
+
+    mod = create_model()
+
+    # Generate reference data
+    in_min, in_max = 18, 19
+    lhs = np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype)
+    rhs = np.random.randint(1, high=2, size=ifm2_shape, dtype=dtype)
+    input_data = {
+        "ifm": lhs,
+        "ifm2": rhs,
+    }
+
+    if reversed_operands:
+        lhs = np.broadcast_to(lhs, ifm2_shape)
+        lhs, rhs = rhs, lhs
+    else:
+        rhs = np.broadcast_to(rhs, ifm_shape)
+
+    def rounding_right_shift(lhs, rhs):
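+        # Round to nearest: add half of the divisor (2**(rhs - 1)) before shifting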
+        r = 1 << (rhs - 1)
+        return (lhs + r) >> rhs
+
+    output_data = np.array(
+        [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)]
+    ).astype(ofm_dtype)
+
+    compiled_model = infra.build_source(mod, input_data, [output_data], 
accel_type)
+
+    imported_modules = compiled_model[0].executor_factory.lib.imported_modules
+    assert len(imported_modules) == 2
+    ethosu_module = imported_modules[0]
+
+    # Verify generated C source
+    get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+    cmms = get_cs(ethosu_module)
+    cmms = bytes.fromhex(cmms)
+
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_model, accel_type)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py 
b/tests/python/contrib/test_ethosu/test_legalize.py
index fc03a98..2a84a23 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -558,5 +558,193 @@ def test_tflite_pool2d_legalize(
     verify(mod["tvmgen_default_ethosu_main_0"])
 
 
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False),
+        ([1, 2, 3, 4], [1, 1, 3, 1], False),
+        ([1, 1, 3, 1], [1, 2, 3, 4], True),
+    ],
+)
+@pytest.mark.parametrize("activation_function", ["NONE", "RELU"])
+def test_tflite_binary_elemwise_legalize(
+    operator_type,
+    ifm_shape,
+    ifm2_shape,
+    reversed_operands,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x, y):
+                if operator_type == "ADD":
+                    op = tf.math.add(x, y)
+                elif operator_type == "SUB":
+                    op = tf.math.subtract(x, y)
+                elif operator_type == "MUL":
+                    op = tf.math.multiply(x, y)
+                elif operator_type == "MIN":
+                    op = tf.math.minimum(x, y)
+                elif operator_type == "MAX":
+                    op = tf.math.maximum(x, y)
+                if activation_function == "RELU":
+                    op = tf.nn.relu(op)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32), 
tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+                yield [data.astype(np.float32), data2.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        out_shape = ifm2_shape if reversed_operands else ifm_shape
+        shapes = [ifm_shape, ifm2_shape]
+        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+        op = ext_func.body
+        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+        assert op.args[0].checked_type.dtype == dtype
+        assert list(op.checked_type.shape) == out_shape
+        assert op.checked_type.dtype == dtype
+        assert op.attrs.operator_type == operator_type
+        assert op.attrs.reversed_operands == reversed_operands
+        if activation_function == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    if operator_type == "ADD":
+        rewriter = legalize.AddRewriter()
+        pattern_table = [
+            (
+                ethosu.AddParams.composite_name,
+                ethosu.qnn_add_pattern(),
+                lambda pat: ethosu.AddParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "SUB":
+        rewriter = legalize.SubRewriter()
+        pattern_table = [
+            (
+                ethosu.SubParams.composite_name,
+                ethosu.qnn_subtract_pattern(),
+                lambda pat: ethosu.SubParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MUL":
+        rewriter = legalize.MulRewriter()
+        pattern_table = [
+            (
+                ethosu.MulParams.composite_name,
+                ethosu.qnn_mul_pattern(),
+                lambda pat: ethosu.MulParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MIN":
+        rewriter = legalize.MinRewriter()
+        pattern_table = [
+            (
+                ethosu.MinParams.composite_name,
+                ethosu.minimum_pattern(),
+                lambda pat: ethosu.MinParams(pat).is_valid(),
+            ),
+        ]
+    elif operator_type == "MAX":
+        rewriter = legalize.MaxRewriter()
+        pattern_table = [
+            (
+                ethosu.MaxParams.composite_name,
+                ethosu.maximum_pattern(),
+                lambda pat: ethosu.MaxParams(pat).is_valid(),
+            ),
+        ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, _ = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"x": ifm_shape, "y": ifm2_shape},
+        dtype_dict={"x": dtype, "y": dtype},
+    )
+    mod = partition_ethosu_by_table(mod, pattern_table)
+
+    mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethosu_main_0"]
+    )
+    verify(mod["tvmgen_default_ethosu_main_0"])
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, reversed_operands",
+    [
+        ([1, 2, 3, 4], [1, 2, 3, 4], False),
+        ([1, 2, 3, 4], [1, 1, 3, 1], False),
+        ([1, 1, 3, 1], [1, 2, 3, 4], True),
+    ],
+)
+def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, 
reversed_operands):
+    dtype = "int32"
+    operator_type = "SHL"
+
+    def create_graph():
+        input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
+        input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
+        c1 = relay.left_shift(input1, input2)
+        f = relay.Function([input1, input2], c1)
+        mod = tvm.IRModule()
+        mod["main"] = f
+        return mod
+
+    def verify(ext_func):
+        out_shape = ifm2_shape if reversed_operands else ifm_shape
+        shapes = [ifm_shape, ifm2_shape]
+        ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+        op = ext_func.body
+        assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+        assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+        assert op.args[0].checked_type.dtype == dtype
+        assert list(op.checked_type.shape) == out_shape
+        assert op.checked_type.dtype == dtype
+        assert op.attrs.operator_type == operator_type
+        assert op.attrs.reversed_operands == reversed_operands
+        assert str(op.attrs.activation) == "NONE"
+
+    rewriter = legalize.ShlRewriter()
+    pattern_table = [
+        (
+            ethosu.ShlParams.composite_name,
+            ethosu.shl_pattern(),
+            lambda pat: ethosu.ShlParams(pat).is_valid(),
+        ),
+    ]
+
+    mod = create_graph()
+    mod = partition_ethosu_by_table(mod, pattern_table)
+
+    mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethosu_main_0"]
+    )
+    verify(mod["tvmgen_default_ethosu_main_0"])
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git 
a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py 
b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
new file mode 100644
index 0000000..6dcd9da
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import tvm
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+from tvm.relay.backend.contrib.ethosu.tir import spec
+from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from .infra import make_ethosu_binary_elementwise, get_binary_elementwise_args
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout, 
ofm_layout",
+    [
+        ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"),
+        ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"),
+        ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"),
+        ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"),
+        # Broadcast
+        ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"),
+        ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"),
+    ],
+)
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
+@pytest.mark.parametrize("activation", ["NONE", "CLIP"])
+def test_binary_elementwise_single(
+    ifm_shape,
+    ifm2_shape,
+    ifm_channels,
+    ifm2_channels,
+    ifm_layout,
+    ofm_layout,
+    operator_type,
+    activation,
+):
+    dtype = "int8"
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+    ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        dtype,
+        False,
+        activation,
+        ifm_layout,
+        ifm_layout,
+        ofm_layout,
+    )
+    func = relay.Function(relay.analysis.free_vars(binary_elementwise), 
binary_elementwise)
+    func = run_opt_pass(func, relay.transform.InferType())
+    mod, _ = lower_to_tir(func)
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_binary_elementwise_args(stmt))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+    if ifm_layout == "NHWC":
+        ifm_stride_c = 1
+        ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1
+        ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1
+
+        ifm2_stride_c = 1
+        ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1
+        ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1 
else 1
+
+        ofm_height = ifm_shape[1]
+        ofm_width = ifm_shape[2]
+    else:
+        ifm_stride_w = 16
+        ifm_stride_c = 16 * ifm_shape[3]
+        ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
+
+        ifm2_stride_w = 16
+        ifm2_stride_c = 16 * ifm2_shape[3]
+        ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3]
+
+        ofm_height = ifm_shape[1]
+        ofm_width = ifm_shape[3]
+
+    if ofm_layout == "NHWC":
+        ofm_stride_c = 1
+        ofm_stride_w = ifm_channels if ofm_width > 1 else 1
+        ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1
+    else:
+        ofm_stride_w = 16
+        ofm_stride_c = 16 * ofm_width
+        ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1)
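+        # e.g. ifm_channels=40 packs into ceil(40 / 16) = 3 bricks, so
+        # ofm_stride_h = 16 * ofm_width * 3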
+
+    serial_binary_elementwise = spec.SerialBinaryElementwise(
+        ifm=spec.SerialFeatureMap(
+            data_type=dtype,
+            height=ifm_shape[1],
+            width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+            channels=ifm_channels,
+            tile_height_0=ifm_shape[1],
+            tile_height_1=0,
+            tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else 
ifm_shape[3],
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ifm_layout,
+            stride_h=ifm_stride_h,
+            stride_w=ifm_stride_w,
+            stride_c=ifm_stride_c,
+        ),
+        ifm2=spec.SerialFeatureMap(
+            data_type=dtype,
+            height=ifm2_shape[1],
+            width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3],
+            channels=ifm2_channels,
+            tile_height_0=ifm2_shape[1],
+            tile_height_1=0,
+            tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else 
ifm2_shape[3],
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ifm_layout,
+            stride_h=ifm2_stride_h,
+            stride_w=ifm2_stride_w,
+            stride_c=ifm2_stride_c,
+        ),
+        ofm=spec.SerialFeatureMap(
+            data_type=dtype,
+            height=ofm_height,
+            width=ofm_width,
+            channels=ifm_channels,
+            tile_height_0=ofm_height,
+            tile_height_1=0,
+            tile_width_0=ofm_width,
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ofm_layout,
+            stride_h=ofm_stride_h,
+            stride_w=ofm_stride_w,
+            stride_c=ofm_stride_c,
+        ),
+        operator_type=operator_type,
+        reversed_operands=False,
+        activation=spec.SerialActivation(
+            op=activation,
+            clip_min=10 if activation == "CLIP" else 0,
+            clip_max=100 if activation == "CLIP" else 0,
+        ),
+    )
+
+    assert data[0] == ["ethosu_binary_elementwise"] + 
list(serial_binary_elementwise)
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout, 
ofm_layout",
+    [
+        ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"),
+        ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"),
+        ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"),
+        ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"),
+        # Broadcast
+        ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"),
+        ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"),
+    ],
+)
+@pytest.mark.parametrize("operator_type", ["SHR", "SHL"])
+def test_shift_binary_elementwise_single(
+    ifm_shape,
+    ifm2_shape,
+    ifm_channels,
+    ifm2_channels,
+    ifm_layout,
+    ofm_layout,
+    operator_type,
+):
+    dtype = "int32"
+    activation = "NONE"  # Only NONE is available if the activation type is 
int32
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+    ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        dtype,
+        False,
+        "NONE",
+        ifm_layout,
+        ifm_layout,
+        ofm_layout,
+    )
+    func = relay.Function(relay.analysis.free_vars(binary_elementwise), 
binary_elementwise)
+    func = run_opt_pass(func, relay.transform.InferType())
+    mod, _ = lower_to_tir(func)
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_binary_elementwise_args(stmt))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+    if ifm_layout == "NHWC":
+        ifm_stride_c = 1
+        ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1
+        ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1
+
+        ifm2_stride_c = 1
+        ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1
+        ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1 
else 1
+
+        ofm_height = ifm_shape[1]
+        ofm_width = ifm_shape[2]
+    else:
+        ifm_stride_w = 16
+        ifm_stride_c = 16 * ifm_shape[3]
+        ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
+
+        ifm2_stride_w = 16
+        ifm2_stride_c = 16 * ifm2_shape[3]
+        ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3]
+
+        ofm_height = ifm_shape[1]
+        ofm_width = ifm_shape[3]
+
+    if ofm_layout == "NHWC":
+        ofm_stride_c = 1
+        ofm_stride_w = ifm_channels if ofm_width > 1 else 1
+        ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1
+    else:
+        ofm_stride_w = 16
+        ofm_stride_c = 16 * ofm_width
+        ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1)
+
+    serial_binary_elementwise = spec.SerialBinaryElementwise(
+        ifm=spec.SerialFeatureMap(
+            data_type=dtype,
+            height=ifm_shape[1],
+            width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+            channels=ifm_channels,
+            tile_height_0=ifm_shape[1],
+            tile_height_1=0,
+            tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else 
ifm_shape[3],
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ifm_layout,
+            stride_h=ifm_stride_h,
+            stride_w=ifm_stride_w,
+            stride_c=ifm_stride_c,
+        ),
+        ifm2=spec.SerialFeatureMap(
+            data_type=dtype,
+            height=ifm2_shape[1],
+            width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3],
+            channels=ifm2_channels,
+            tile_height_0=ifm2_shape[1],
+            tile_height_1=0,
+            tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else 
ifm2_shape[3],
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ifm_layout,
+            stride_h=ifm2_stride_h,
+            stride_w=ifm2_stride_w,
+            stride_c=ifm2_stride_c,
+        ),
+        ofm=spec.SerialFeatureMap(
+            data_type=dtype,
+            height=ofm_height,
+            width=ofm_width,
+            channels=ifm_channels,
+            tile_height_0=ofm_height,
+            tile_height_1=0,
+            tile_width_0=ofm_width,
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ofm_layout,
+            stride_h=ofm_stride_h,
+            stride_w=ofm_stride_w,
+            stride_c=ofm_stride_c,
+        ),
+        operator_type=operator_type,
+        reversed_operands=False,
+        activation=spec.SerialActivation(
+            op=activation,
+            clip_min=0,
+            clip_max=0,
+        ),
+    )
+
+    assert data[0] == ["ethosu_binary_elementwise"] + 
list(serial_binary_elementwise)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py 
b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index f4b83a4..ab1bad2 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -913,5 +913,439 @@ def test_translate_ethosu_pooling():
     assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE
 
 
+# fmt: off
+"""A ethosu_binary_elementwise ADD tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseAdd:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(
+            placeholder, [270], dtype="int8", elem_offset=0, align=128, 
offset_factor=1
+        )
+        ethosu_write_2 = T.match_buffer(
+            ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, 
align=128, offset_factor=1
+        )
+        # body
+        T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 
3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 
135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, 
T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 
3, 1, "ADD", 0, "CLIP", 10, 100, dtype="int8"))
+
+    __tvm_meta__ = None
+# fmt: on
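
For orientation, the positional arguments in the call_extern above appear to
flatten the spec.SerialBinaryElementwise structure (cf. the equality assertion
in test_replace_binary_elementwise.py):

    # ["ethosu_binary_elementwise",
    #  <ifm SerialFeatureMap fields>, <ifm2 SerialFeatureMap fields>,
    #  <ofm SerialFeatureMap fields>,
    #  operator_type, reversed_operands, activation op, clip_min, clip_max]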
+
+# fmt: off
+"""A ethosu_binary_elementwise SUB tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseSub:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", 
elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], 
dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 
5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 
135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, 
T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 
3, 1, "SUB", 0, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""A ethosu_binary_elementwise MUL tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseMul:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", 
elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], 
dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 
5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 
135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, 
T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 
3, 1, "MUL", 0, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise MIN tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseMin:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", 
elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], 
dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 
5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 
135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, 
T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 
3, 1, "MIN", 0, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise Max tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseMax:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", 
elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], 
dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 
5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 
135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, 
T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 
3, 1, "MAX", 0, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise SHR tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseShr:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", 
elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], 
dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 
3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", 
placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, 
"int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, 
T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, dtype="int32"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise SHL tir testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseShl:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", 
elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], 
dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 
3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, 
"NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", 
placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, 
"int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, 
T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, dtype="int32"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"])
+def test_translate_ethosu_binary_elementwise(operator_type):
+    if operator_type == "SHR" or operator_type == "SHL":
+        data_type = vapi.NpuDataType.INT32
+        data_type_bytes = 4
+    else:
+        data_type = vapi.NpuDataType.INT8
+        data_type_bytes = 1
+
+    def extract_ethosu_binary_elementwise_call_extern(mod):
+        # There should only be a single function
+        assert len(mod.functions.items()) == 1
+        primfunc = mod.functions.items()[0][1]
+
+        ethosu_binary_elementwise_calls = list()
+
+        def populate_ethosu_binary_elementwise_calls(stmt):
+            if (
+                isinstance(stmt, tvm.tir.Call)
+                and stmt.op.name == "tir.call_extern"
+                and stmt.args[0] == "ethosu_binary_elementwise"
+            ):
+                ethosu_binary_elementwise_calls.append(stmt)
+
+        stmt_functor.post_order_visit(primfunc.body, 
populate_ethosu_binary_elementwise_calls)
+        return ethosu_binary_elementwise_calls[0]
+
+    if operator_type == "ADD":
+        binary_elementwise = SingleEthosuBinaryElementwiseAdd
+    elif operator_type == "SUB":
+        binary_elementwise = SingleEthosuBinaryElementwiseSub
+    elif operator_type == "MUL":
+        binary_elementwise = SingleEthosuBinaryElementwiseMul
+    elif operator_type == "MIN":
+        binary_elementwise = SingleEthosuBinaryElementwiseMin
+    elif operator_type == "MAX":
+        binary_elementwise = SingleEthosuBinaryElementwiseMax
+    elif operator_type == "SHR":
+        binary_elementwise = SingleEthosuBinaryElementwiseShr
+    elif operator_type == "SHL":
+        binary_elementwise = SingleEthosuBinaryElementwiseShl
+    binary_elementwise_call = 
extract_ethosu_binary_elementwise_call_extern(binary_elementwise)
+    npu_op = 
tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call)
+
+    # Compare IFM
+    assert npu_op.ifm.data_type == data_type
+    assert npu_op.ifm.shape == vapi.NpuShape3D(5, 9, 3)
+    assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 
0]).height_0
+    assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 
0]).height_1
+    assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 
0]).width_0
+    assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ifm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ifm.strides == vapi.NpuShape3D(
+        27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes
+    )
+    # Compare IFM2
+    assert npu_op.ifm2.data_type == data_type
+    assert npu_op.ifm2.shape == vapi.NpuShape3D(5, 9, 3)
+    assert npu_op.ifm2.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0
+    assert npu_op.ifm2.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1
+    assert npu_op.ifm2.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0
+    assert npu_op.ifm2.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ifm2.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ifm2.strides == vapi.NpuShape3D(
+        27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes
+    )
+    # Compare OFM
+    assert npu_op.ofm.data_type == data_type
+    assert npu_op.ofm.shape == vapi.NpuShape3D(5, 9, 3)
+    assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0
+    assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1
+    assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0
+    assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ofm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ofm.strides == vapi.NpuShape3D(
+        27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes
+    )
+    # Compare op type
+    if operator_type == "ADD":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.ADD
+    elif operator_type == "SUB":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SUB
+    elif operator_type == "MUL":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MUL
+    elif operator_type == "MIN":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MIN
+    elif operator_type == "MAX":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MAX
+    elif operator_type == "SHR":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHR
+    elif operator_type == "SHL":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHL
+    # Compare reversed_operands
+    assert npu_op.reversed_operands == False
+    # Compare activation
+    if operator_type == "SHR":
+        assert npu_op.activation is None
+    else:
+        assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
+        assert npu_op.activation.min == 10
+        assert npu_op.activation.max == 100
+
+
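+# The testcases below broadcast a (1, 3, 1) IFM2 against a (2, 3, 4) IFM, so
+# the serialized IFM2 strides collapse to (1, 1, 1).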
+# fmt: off
+"""An ethosu_binary_elementwise ADD with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseAddBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""An ethosu_binary_elementwise SUB with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseSubBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""An ethosu_binary_elementwise MUL with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseMulBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""An ethosu_binary_elementwise MIN with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseMinBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""An ethosu_binary_elementwise MAX with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseMaxBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""An ethosu_binary_elementwise SHR with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseShrBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, dtype="int32"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""An ethosu_binary_elementwise SHL with broadcasting TIR testcase for the translator"""
[email protected]_module
+class SingleEthosuBinaryElementwiseShlBroadcasting:
+    @T.prim_func
+    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, dtype="int32"))
+    __tvm_meta__ = None
+# fmt: on
+
+
[email protected]("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"])
+def test_translate_ethosu_binary_elementwise_broadcasting(operator_type):
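+    # These testcases serialize reversed_operands as 1 (the argument after
+    # the operator name), so the translated operation is expected to report
+    # reversed_operands == True.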
+    if operator_type == "SHR" or operator_type == "SHL":
+        data_type = vapi.NpuDataType.INT32
+        data_type_bytes = 4
+    else:
+        data_type = vapi.NpuDataType.INT8
+        data_type_bytes = 1
+
+    def extract_ethosu_binary_elementwise_broadcasting_call_extern(mod):
+        # There should only be a single function
+        assert len(mod.functions.items()) == 1
+        primfunc = mod.functions.items()[0][1]
+
+        ethosu_binary_elementwise_calls = list()
+
+        def populate_ethosu_binary_elementwise_calls(stmt):
+            if (
+                isinstance(stmt, tvm.tir.Call)
+                and stmt.op.name == "tir.call_extern"
+                and stmt.args[0] == "ethosu_binary_elementwise"
+            ):
+                ethosu_binary_elementwise_calls.append(stmt)
+
+        stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls)
+        return ethosu_binary_elementwise_calls[0]
+
+    if operator_type == "ADD":
+        binary_elementwise = SingleEthosuBinaryElementwiseAddBroadcasting
+    elif operator_type == "SUB":
+        binary_elementwise = SingleEthosuBinaryElementwiseSubBroadcasting
+    elif operator_type == "MUL":
+        binary_elementwise = SingleEthosuBinaryElementwiseMulBroadcasting
+    elif operator_type == "MIN":
+        binary_elementwise = SingleEthosuBinaryElementwiseMinBroadcasting
+    elif operator_type == "MAX":
+        binary_elementwise = SingleEthosuBinaryElementwiseMaxBroadcasting
+    elif operator_type == "SHR":
+        binary_elementwise = SingleEthosuBinaryElementwiseShrBroadcasting
+    elif operator_type == "SHL":
+        binary_elementwise = SingleEthosuBinaryElementwiseShlBroadcasting
+    binary_elementwise_call = extract_ethosu_binary_elementwise_broadcasting_call_extern(
+        binary_elementwise
+    )
+    npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call)
+
+    # Compare IFM
+    assert npu_op.ifm.data_type == data_type
+    assert npu_op.ifm.shape == vapi.NpuShape3D(2, 3, 4)
+    assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_0
+    assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_1
+    assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).width_0
+    assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ifm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ifm.strides == vapi.NpuShape3D(
+        12 * data_type_bytes, 4 * data_type_bytes, 1 * data_type_bytes
+    )
+    # Compare IFM2
+    assert npu_op.ifm2.data_type == data_type
+    assert npu_op.ifm2.shape == vapi.NpuShape3D(1, 3, 1)
+    assert npu_op.ifm2.tiles.height_0 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).height_0
+    assert npu_op.ifm2.tiles.height_1 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).height_1
+    assert npu_op.ifm2.tiles.width_0 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).width_0
+    assert npu_op.ifm2.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ifm2.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ifm2.strides == vapi.NpuShape3D(
+        1 * data_type_bytes, 1 * data_type_bytes, 1 * data_type_bytes
+    )
+    # Compare OFM
+    assert npu_op.ofm.data_type == data_type
+    assert npu_op.ofm.shape == vapi.NpuShape3D(2, 3, 4)
+    assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_0
+    assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_1
+    assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).width_0
+    assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ofm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ofm.strides == vapi.NpuShape3D(
+        12 * data_type_bytes, 4 * data_type_bytes, 1 * data_type_bytes
+    )
+    # Compare op type
+    if operator_type == "ADD":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.ADD
+    elif operator_type == "SUB":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SUB
+    elif operator_type == "MUL":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MUL
+    elif operator_type == "MIN":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MIN
+    elif operator_type == "MAX":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MAX
+    elif operator_type == "SHR":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHR
+    elif operator_type == "SHL":
+        assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHL
+    # Compare reversed_operands
+    assert npu_op.reversed_operands == True
+    # Compare activation
+    if operator_type == "SHR":
+        assert npu_op.activation is None
+    else:
+        assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
+        assert npu_op.activation.min == 10
+        assert npu_op.activation.max == 100
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py
index ecbe31b..e068439 100644
--- a/tests/python/contrib/test_ethosu/test_type_inference.py
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -24,6 +24,7 @@ from tvm.relay.testing import run_opt_pass
 from .infra import make_ethosu_conv2d
 from .infra import make_ethosu_depthwise_conv2d
 from .infra import make_ethosu_pooling
+from .infra import make_ethosu_binary_elementwise
 
 
 @pytest.mark.parametrize(
@@ -226,5 +227,120 @@ def test_ethosu_pooling_invalid_dtype():
         run_opt_pass(func, relay.transform.InferType())
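+
+# make_ethosu_binary_elementwise is a test helper added in infra.py that
+# builds the ethosu_binary_elementwise Relay call; the NHCWB16 shape
+# (1, 4, 3, 5, 16) used in the tests below is the 16-channel-block form of
+# the NHWC tensor (1, 4, 5, 33), since ceil(33 / 16) == 3.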
 
 
[email protected](
+    "ifm_shape, ifm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")]
+)
[email protected](
+    "ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")]
+)
+def test_ethosu_binary_elementwise_type_inference(
+    ifm_shape,
+    ifm_layout,
+    ofm_shape,
+    ofm_layout,
+):
+    dtype = "int8"
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype)
+    operator_type = "ADD"
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        dtype,
+        ifm_layout=ifm_layout,
+        ifm2_layout=ifm_layout,
+        ofm_layout=ofm_layout,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    func = run_opt_pass(func, relay.transform.InferType())
+    assert tuple(func.body.checked_type.shape) == ofm_shape
+    assert func.body.checked_type.dtype == dtype
+
+
+def test_ethosu_binary_elementwise_invalid_operator_type():
+    invalid_operator_type = "A"
+    ifm_shape = [1, 4, 5, 33]
+    dtype = "int8"
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype)
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        invalid_operator_type,
+        dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
+def test_ethosu_binary_elementwise_invalid_data_types():
+    dtype = "int8"
+    dtype2 = "int32"
+    operator_type = "ADD"
+    ifm_shape = [1, 4, 5, 33]
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype2)
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
[email protected]("operator_type", ["MIN", "MAX"])
+def test_ethosu_binary_elementwise_min_max_invalid_data_type(operator_type):
+    invalid_dtype = "int32"
+    ifm_shape = [1, 4, 5, 33]
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype)
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        invalid_dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
[email protected]("invalid_dtype", ["int8", "uint8"])
[email protected]("operator_type", ["SHL", "SHR"])
+def test_ethosu_binary_elementwise_shift_invalid_data_type(invalid_dtype, operator_type):
+    ifm_shape = [1, 4, 5, 33]
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype)
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        invalid_dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
