This is an automated email from the ASF dual-hosted git repository.
mbaret pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e81d391 Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support
(#9442)
e81d391 is described below
commit e81d391640ed8c536ecc584ab3b60d756cce47f5
Author: Nicola Lancellotti <[email protected]>
AuthorDate: Thu Nov 11 17:27:40 2021 +0000
Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support (#9442)
This commit adds support for the binary elementwise primitive operators for
the Arm(R) Ethos(TM)-U NPU and includes a few minor rewording changes.
---
.../tvm/relay/backend/contrib/ethosu/legalize.py | 227 +++++++++++
.../relay/backend/contrib/ethosu/op/__init__.py | 1 +
.../contrib/ethosu/op/binary_elementwise.py | 215 ++++++++++
.../relay/backend/contrib/ethosu/op/convolution.py | 4 +-
.../relay/backend/contrib/ethosu/op/depthwise.py | 7 +-
.../tvm/relay/backend/contrib/ethosu/op/pooling.py | 6 +-
.../relay/backend/contrib/ethosu/te/__init__.py | 1 +
.../contrib/ethosu/te/binary_elementwise.py | 184 +++++++++
.../contrib/ethosu/tir/binary_elementwise.py | 102 +++++
.../tvm/relay/backend/contrib/ethosu/tir/passes.py | 2 +
.../tvm/relay/backend/contrib/ethosu/tir/spec.py | 21 +
.../backend/contrib/ethosu/tir_to_cs_translator.py | 76 +++-
python/tvm/relay/backend/contrib/ethosu/util.py | 15 +
python/tvm/relay/op/contrib/ethosu.py | 352 ++++++++++++++++-
src/relay/op/contrib/ethosu/binary_elementwise.cc | 301 ++++++++++++++
src/relay/op/contrib/ethosu/common.cc | 18 +
src/relay/op/contrib/ethosu/common.h | 11 +
src/relay/op/contrib/ethosu/pooling.cc | 2 +-
tests/python/contrib/test_ethosu/infra.py | 53 +++
tests/python/contrib/test_ethosu/test_codegen.py | 252 +++++++++++-
tests/python/contrib/test_ethosu/test_legalize.py | 188 +++++++++
.../test_ethosu/test_replace_binary_elementwise.py | 335 ++++++++++++++++
.../test_ethosu/test_tir_to_cs_translator.py | 434 +++++++++++++++++++++
.../contrib/test_ethosu/test_type_inference.py | 116 ++++++
24 files changed, 2902 insertions(+), 21 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py
b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index c4b70c1..d0d04ce 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -413,6 +413,224 @@ class LegalizeAvgPooling:
pass
+class BinaryElementwiseRewriter(DFPatternCallback):
+ """Convert ethosu binary elementwise composite functions to
+ ethosu_binary_elementwise operators"""
+
+ def __init__(
+ self,
+ params_class: Type,
+ pattern: CallPattern,
+ ):
+ super().__init__(require_type=True)
+ self.params_class = params_class
+ self.pattern = pattern
+
+ def callback(
+ self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map:
tvm.ir.container.Map
+ ) -> tvm.relay.Expr:
+ params = self.params_class(post.op.body)
+ params.ifm.tensor = post.args[1] if params.reversed_operands else
post.args[0]
+ params.ifm2.tensor = post.args[0] if params.reversed_operands else
post.args[1]
+ channels_map = {
+ "NHWC": 3,
+ }
+ if str(params.ofm.layout) not in channels_map.keys():
+ raise UnsupportedLayout(str(params.ofm.layout))
+
+ activation_map = {"clip": "CLIP"}
+ if params.activation:
+ activation = activation_map[params.activation.op.name]
+ clip_min = int(params.activation.attrs.a_min)
+ clip_max = int(params.activation.attrs.a_max)
+ else:
+ activation = "NONE"
+ clip_min = 0
+ clip_max = 0
+
+ # We don't yet support activation functions that need to get legalized
to LUTs.
+ lut = relay.const([], dtype="int8")
+
+ return ethosu_ops.ethosu_binary_elementwise(
+ ifm=params.ifm.tensor,
+ ifm2=params.ifm2.tensor,
+ lut=lut,
+ operator_type=params.operator_type,
+ ifm_scale=float(params.ifm.q_params.scale_f32),
+ ifm_zero_point=int(params.ifm.q_params.zero_point),
+ ifm2_scale=float(params.ifm2.q_params.scale_f32),
+ ifm2_zero_point=int(params.ifm2.q_params.zero_point),
+ ofm_scale=float(params.ofm.q_params.scale_f32),
+ ofm_zero_point=int(params.ofm.q_params.zero_point),
+ ifm_channels=params.ifm.shape[3],
+ ifm2_channels=params.ifm2.shape[3],
+ reversed_operands=params.reversed_operands,
+ ofm_dtype=params.ofm.dtype,
+ activation=activation,
+ clip_min=clip_min,
+ clip_max=clip_max,
+ ifm_layout=str(params.ifm.layout),
+ ifm2_layout=str(params.ifm2.layout),
+ ofm_layout=str(params.ofm.layout),
+ )
+
+
+class AddRewriter(BinaryElementwiseRewriter):
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.AddParams,
+ pattern=(wildcard().has_attr({"Composite":
ethosu_patterns.AddParams.composite_name}))(
+ wildcard(), wildcard()
+ ),
+ )
+
+
[email protected]_pass(opt_level=1)
+class LegalizeAdd:
+ """This is the pass that wraps the AddRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(AddRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class SubRewriter(BinaryElementwiseRewriter):
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.SubParams,
+ pattern=(wildcard().has_attr({"Composite":
ethosu_patterns.SubParams.composite_name}))(
+ wildcard(), wildcard()
+ ),
+ )
+
+
[email protected]_pass(opt_level=1)
+class LegalizeSub:
+ """This is the pass that wraps the SubRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(SubRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class MulRewriter(BinaryElementwiseRewriter):
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.MulParams,
+ pattern=(wildcard().has_attr({"Composite":
ethosu_patterns.MulParams.composite_name}))(
+ wildcard(), wildcard()
+ ),
+ )
+
+
[email protected]_pass(opt_level=1)
+class LegalizeMul:
+ """This is the pass that wraps the MulRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(MulRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class MinRewriter(BinaryElementwiseRewriter):
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.MinParams,
+ pattern=(wildcard().has_attr({"Composite":
ethosu_patterns.MinParams.composite_name}))(
+ wildcard(), wildcard()
+ ),
+ )
+
+
[email protected]_pass(opt_level=1)
+class LegalizeMin:
+ """This is the pass that wraps the MinRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(MinRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class MaxRewriter(BinaryElementwiseRewriter):
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.MaxParams,
+ pattern=(wildcard().has_attr({"Composite":
ethosu_patterns.MaxParams.composite_name}))(
+ wildcard(), wildcard()
+ ),
+ )
+
+
[email protected]_pass(opt_level=1)
+class LegalizeMax:
+ """This is the pass that wraps the MaxRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(MaxRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
+class ShlRewriter(BinaryElementwiseRewriter):
+ def __init__(self):
+ super().__init__(
+ params_class=ethosu_patterns.ShlParams,
+ pattern=(wildcard().has_attr({"Composite":
ethosu_patterns.ShlParams.composite_name}))(
+ wildcard(), wildcard()
+ ),
+ )
+
+
[email protected]_pass(opt_level=1)
+class LegalizeShl:
+ """This is the pass that wraps the ShlRewriter"""
+
+ def transform_module(
+ self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
+ ) -> tvm.ir.IRModule:
+ for global_var, func in mod.functions.items():
+ func = rewrite(ShlRewriter(), func)
+ mod.update_func(global_var, func)
+ return mod
+
+ def __call__(self, *args, **kwargs):
+ pass
+
+
@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
@@ -423,11 +641,20 @@ class LegalizeEthosU:
def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
+ """This is the method that replaces the operations with
hardware/codegen supported
+ operations.
+ """
mod = LegalizeSplit()(mod)
mod = LegalizeConv2D()(mod)
mod = LegalizeDepthwiseConv2D()(mod)
mod = LegalizeMaxPooling()(mod)
mod = LegalizeAvgPooling()(mod)
+ mod = LegalizeAdd()(mod)
+ mod = LegalizeSub()(mod)
+ mod = LegalizeMul()(mod)
+ mod = LegalizeMin()(mod)
+ mod = LegalizeMax()(mod)
+ mod = LegalizeShl()(mod)
return mod
def __call__(self, *args, **kwargs):
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
index c9aa59b..05d4053 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py
@@ -19,3 +19,4 @@
from .convolution import ethosu_conv2d
from .depthwise import ethosu_depthwise_conv2d
from .pooling import ethosu_pooling
+from .binary_elementwise import ethosu_binary_elementwise
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
new file mode 100644
index 0000000..d4ae18b
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument
+"""Relay operators for binary elementwise operators for Arm(R) Ethos(TM)-U
NPU"""
+from typing import Optional
+import tvm
+from tvm.relay.op import _make
+from tvm.topi.generic import schedule_injective
+from tvm.relay.op.op import OpStrategy
+from tvm.relay.op import strategy as _strategy
+
+from ..te import binary_elementwise_compute
+
+
+def _extract_ethosu_binary_elementwise_params(attrs, args):
+ """Get the parameters necessary to construct a ethosu_binary_elementwise
compute TE
+ from a ethosu_binary_elementwise Relay call."""
+ ifm = args[0]
+ ifm2 = args[1]
+ lut = args[2]
+ operator_type = attrs.operator_type
+ ifm_scale = attrs.ifm_scale
+ ifm_zero_point = attrs.ifm_zero_point
+ ifm2_scale = attrs.ifm2_scale
+ ifm2_zero_point = attrs.ifm2_zero_point
+ ofm_scale = attrs.ofm_scale
+ ofm_zero_point = attrs.ofm_zero_point
+ ifm_channels = attrs.ifm_channels
+ ifm2_channels = attrs.ifm2_channels
+ reversed_operands = attrs.reversed_operands
+ activation = attrs.activation
+ clip_min = attrs.clip_min
+ clip_max = attrs.clip_max
+ ifm_layout = attrs.ifm_layout
+ ifm2_layout = attrs.ifm2_layout
+ ofm_layout = attrs.ofm_layout
+ ofm_dtype = attrs.ofm_dtype
+
+ return (
+ ifm,
+ ifm2,
+ lut,
+ operator_type,
+ ifm_scale,
+ ifm_zero_point,
+ ifm2_scale,
+ ifm2_zero_point,
+ ofm_scale,
+ ofm_zero_point,
+ ifm_channels,
+ ifm2_channels,
+ reversed_operands,
+ activation,
+ clip_min,
+ clip_max,
+ ifm_layout,
+ ifm2_layout,
+ ofm_layout,
+ ofm_dtype,
+ )
+
+
[email protected]_op_attr("contrib.ethosu.binary_elementwise", "FTVMCompute")
+def create_ethosu_binary_elementwise_compute(attrs, args, out_type):
+ """Create an ethosu_binary_elementwise compute op."""
+ params = _extract_ethosu_binary_elementwise_params(attrs, args)
+ op = binary_elementwise_compute(*params)
+ return [op]
+
+
[email protected]_op_attr("contrib.ethosu.binary_elementwise", "FTVMStrategy")
+def binary_elementwise_strategy_ethosu(attrs, inputs, out_type, target):
+ strategy = OpStrategy()
+ strategy.add_implementation(
+ create_ethosu_binary_elementwise_compute,
+ _strategy.wrap_topi_schedule(schedule_injective),
+ name="ethosu_binary_elementwise",
+ )
+ return strategy
+
+
+def ethosu_binary_elementwise(
+ ifm: tvm.relay.Expr,
+ ifm2: tvm.relay.Expr,
+ lut: tvm.relay.Expr,
+ operator_type: str,
+ ifm_scale: float,
+ ifm_zero_point: int,
+ ifm2_scale: float,
+ ifm2_zero_point: int,
+ ofm_scale: float,
+ ofm_zero_point: int,
+ ifm_channels: int,
+ ifm2_channels: int,
+ reversed_operands: bool,
+ ofm_dtype: str,
+ activation: Optional[str] = "NONE",
+ clip_min: Optional[int] = 0,
+ clip_max: Optional[int] = 0,
+ ifm_layout: Optional[str] = "NHWC",
+ ifm2_layout: Optional[str] = "NHWC",
+ ofm_layout: Optional[str] = "NHWC",
+) -> tvm.relay.Call:
+ """This is a quantized binary elementwise operation as supported by
+ the NPU. It accepts either NHWC or NHCWB16 format
+ for the input data.
+
+ Parameters
+ ----------
+ ifm : tvm.relay.Expr
+ The Input Feature Map tensor (IFM).
+ ifm2 : tvm.relay.Expr
+ The Input Feature Map tensor 2 (IFM2).
+ lut : tvm.relay.Expr
+ The look-up table of values to use if activation = "LUT".
+ operator_type: str
+ The type of the binary elementwise operator.
+ "ADD"
+ "SUB"
+ "MUL"
+ "MIN"
+ "MAX"
+ "SHR"
+ "SHL"
+ ifm_scale : float
+ The quantization scale for the Input Feature Map tensor.
+ ifm_zero_point : int
+ The quantization zero point for the Input Feature Map tensor.
+ ifm2_scale : float
+ The quantization scale for the Input Feature Map tensor 2.
+ ifm2_zero_point : int
+ The quantization zero point for the Input Feature Map tensor 2.
+ ofm_scale : float
+ The quantization scale for the Output Feature Map tensor.
+ ofm_zero_point : int
+ The quantization zero point for the Output Feature Map tensor.
+ ifm_channels : int
+ The number of the Input Feature Map channels.
+ ifm2_channels : int
+ The number of the Input Feature Map 2 channels.
+ reversed_operands : bool
+ True if IFM2 is the first operand and IFM is the second operand.
+ ofm_dtype: str
+ The Output Feature Map tensor type.
+ MUL, ADD, SUB {IFM}->{OFM}:
+ {uint8, int8, int32} -> {uint8, int8, int32}, any pairing
+ MAX, MIN:
+ IFM and OFM must be of the same type, one of:
+ {int8, uint8}
+ SHR {IFM}->{OFM}:
+ {int32}->{int8, uint8, int32}, any pairing
+ SHL:
+ {int32}->{int32} only
+ activation : str, optional
+ The activation function to use.
+ "NONE" - no activation function.
+ "CLIP" - clip the output between clip_min and clip_max.
+ "TANH" - tanh activation function.
+ "SIGMOID" - sigmoid activation function.
+ "LUT" - use a look-up table to perform the activation function.
+ Available activations for activation type:
+ {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT"
+ {int32}: "NONE"
+ clip_min : int, optional
+ The minimum clipping value if activation = "CLIP".
+ clip_max : int, optional
+ The maximum clipping value if activation = "CLIP".
+ ifm_layout : str, optional
+ The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+ ifm2_layout : str, optional
+ The layout of the Input Feature Map tensor 2. Can be "NHWC" or
"NHCWB16".
+ ofm_layout : str, optional
+ The layout of the Output Feature Map tensor. Can be "NHWC" or
"NHCWB16".
+
+ Returns
+ -------
+ out : tvm.relay.Call
+ A call to the ethosu_binary_elementwise op.
+ """
+ return _make.ethosu_binary_elementwise(
+ ifm,
+ ifm2,
+ lut,
+ operator_type,
+ ifm_scale,
+ ifm_zero_point,
+ ifm2_scale,
+ ifm2_zero_point,
+ ofm_scale,
+ ofm_zero_point,
+ ifm_channels,
+ ifm2_channels,
+ reversed_operands,
+ activation,
+ clip_min,
+ clip_max,
+ ifm_layout,
+ ifm2_layout,
+ ofm_layout,
+ ofm_dtype,
+ )
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
index 7fb054e..970e366 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py
@@ -112,8 +112,8 @@ def ethosu_conv2d(
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
- """This is a quantized 2D convolution operation as supported by the
- Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
+ """This is a quantized 2D convolution operation as supported by
+ the NPU. It accepts either NHWC or NHCWB16 format
for the input data and OHWI format for the kernel weights.
Reference: https://developer.arm.com/documentation/102420/0200/
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
index d1b49ef..d8f2e8b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
-"""Relay operator for depthwise convolution"""
+"""Relay operator for depthwise convolution for Arm(R) Ethos(TM)-U NPU"""
+
from typing import Tuple
import tvm
@@ -112,8 +113,8 @@ def ethosu_depthwise_conv2d(
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
- """This is a quantized 2D depthwise convolution operation as supported by
the
- Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
+ """This is a quantized 2D depthwise convolution operation as supported by
+ the NPU. It accepts either NHWC or NHCWB16 format
for the input data and OHWI format for the kernel weights.
Reference: https://developer.arm.com/documentation/102420/0200/
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
index f344f61..cc36373 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
-"""Relay operators for pooling"""
+"""Relay operators for pooling for Arm(R) Ethos(TM)-U NPU"""
from typing import Tuple
import tvm
@@ -107,8 +107,8 @@ def ethosu_pooling(
ifm_layout: str = "NHWC",
ofm_layout: str = "NHWC",
) -> tvm.relay.Call:
- """This is a quantized 2D pooling operation as supported by the
- Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format
+ """This is a quantized 2D pooling operation as supported by
+ the NPU. It accepts either NHWC or NHCWB16 format
for the input data.
Parameters
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
index e2eb28f..5c26236 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py
@@ -19,3 +19,4 @@
from .convolution import *
from .depthwise import *
from .pooling import *
+from .binary_elementwise import *
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
new file mode 100644
index 0000000..84d4e1b
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
+"""Tensor Expressions for binary_elementwise"""
+import operator
+from tvm import te
+from .dma import dma_ofm_compute, dma_ifm_compute
+
+
+def binary_elementwise_compute(
+ ifm: te.Tensor,
+ ifm2: te.Tensor,
+ lut: te.Tensor,
+ operator_type: str,
+ ifm_scale: float,
+ ifm_zero_point: int,
+ ifm2_scale: float,
+ ifm2_zero_point: int,
+ ofm_scale: float,
+ ofm_zero_point: int,
+ ifm_channels: int,
+ ifm2_channels: int,
+ reversed_operands: bool,
+ activation: str,
+ clip_min: int,
+ clip_max: int,
+ ifm_layout: str,
+ ifm2_layout: str,
+ ofm_layout: str,
+ ofm_dtype: str,
+) -> te.Tensor:
+ """A compute operator representing the capabilities of binary_elementwise
for the NPU.
+
+ Parameters
+ ----------
+ ifm : te.Tensor
+ The Input Feature Map tensor (IFM).
+ ifm2 : te.Tensor
+ The Input Feature Map tensor 2 (IFM2).
+ lut : te.Tensor
+ The look-up table values to use if activation = "LUT".
+ operator_type: str
+ The type of the binary elementwise operator.
+ "ADD"
+ "SUB"
+ "MUL"
+ "MIN"
+ "MAX"
+ "SHR"
+ "SHL"
+ ifm_scale : float
+ The quantization scale for the Input Feature Map tensor.
+ ifm_zero_point : int
+ The quantization zero point for the Input Feature Map tensor.
+ ifm2_scale : float
+ The quantization scale for the Input Feature Map tensor 2.
+ ifm2_zero_point : int
+ The quantization zero point for the Input Feature Map tensor 2.
+ ofm_scale : float
+ The quantization scale for the Output Feature Map tensor.
+ ofm_zero_point : int
+ The quantization zero point for the Output Feature Map tensor.
+ ifm_channels : int
+ The number of the Input Feature Map channels.
+ ifm2_channels : int
+ The number of the Input Feature Map 2 channels.
+ reversed_operands : bool
+ True if IFM2 is the first operand and IFM is the second operand.
+ activation : str
+ The activation function to use.
+ "NONE" - no activation function.
+ "CLIP" - clip the output between clip_min and clip_max.
+ "TANH" - tanh activation function.
+ "SIGMOID" - sigmoid activation function.
+ "LUT" - use a look-up table to perform the activation function.
+ Available activations for activation type:
+ {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT"
+ {int32}: "NONE"
+ clip_min : int
+ The minimum clipping value if activation = "CLIP".
+ clip_max : int
+ The maximum clipping value if activation = "CLIP".
+ ifm_layout : str, optional
+ The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16".
+ ifm2_layout : str, optional
+ The layout of the Input Feature Map tensor 2. Can be "NHWC" or
"NHCWB16".
+ ofm_layout : str, optional
+ The layout of the Output Feature Map tensor. Can be "NHWC" or
"NHCWB16".
+ ofm_dtype: str
+ The Output Feature Map tensor type.
+ MUL, ADD, SUB {IFM}->{OFM}:
+ {uint8, int8, int32} -> {uint8, int8, int32}, any pairing
+ MAX, MIN:
+ IFM and OFM must be of the same type, one of:
+ {int8, uint8}
+ SHR {IFM}->{OFM}:
+ {int32}->{int8, uint8, int32}, any pairing
+ SHL:
+ {int32}->{int32} only
+
+ Returns
+ -------
+ te.Tensor
+ The Output Feature Map tensor.
+ """
+ # Compute operation for the IFM DMA pipeline
+ dmaed_ifm = dma_ifm_compute(
+ ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0)
+ )
+ dmaed_ifm2 = dma_ifm_compute(
+ ifm2, ifm2_layout, ifm2_zero_point, ifm2_scale, ifm2_channels, (0, 0,
0, 0)
+ )
+
+ # Binary elementwise compute operation
+ ofm_height = dmaed_ifm.shape[1]
+ ofm_width = dmaed_ifm.shape[2]
+
+ binary_elementwise_attrs = {
+ "op": "ethosu_binary_elementwise",
+ "operator_type": operator_type,
+ "reversed_operands": reversed_operands,
+ "activation": activation,
+ "clip_min": clip_min,
+ "clip_max": clip_max,
+ }
+
+ operators = {
+ "ADD": operator.add,
+ "SUB": operator.sub,
+ "MUL": operator.mul,
+ "MIN": te.min,
+ "MAX": te.max,
+ "SHR": operator.add,
+ "SHL": operator.add,
+ }
+ broadcast = [value == 1 for value in dmaed_ifm2.shape]
+
+ if reversed_operands:
+ binary_elementwise = te.compute(
+ (1, ofm_height, ofm_width, ifm_channels),
+ lambda nn, hh, ww, cc: operators[operator_type](
+ dmaed_ifm2(
+ 0 if broadcast[0] else nn,
+ 0 if broadcast[1] else hh,
+ 0 if broadcast[2] else ww,
+ 0 if broadcast[3] else cc,
+ ).astype(ifm.dtype),
+ dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
+ ).astype(ofm_dtype),
+ name="ethosu_binary_elementwise",
+ attrs=binary_elementwise_attrs,
+ )
+ else:
+ binary_elementwise = te.compute(
+ (1, ofm_height, ofm_width, ifm_channels),
+ lambda nn, hh, ww, cc: operators[operator_type](
+ dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype),
+ dmaed_ifm2(
+ 0 if broadcast[0] else nn,
+ 0 if broadcast[1] else hh,
+ 0 if broadcast[2] else ww,
+ 0 if broadcast[3] else cc,
+ ).astype(ifm.dtype),
+ ).astype(ofm_dtype),
+ name="ethosu_binary_elementwise",
+ attrs=binary_elementwise_attrs,
+ )
+
+ # Compute operation for the OFM DMA pipeline
+ return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point,
ofm_scale, ifm_channels)
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
new file mode 100644
index 0000000..1ea24ed
--- /dev/null
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Extract information from the binary_elementwise operators in TIR."""
+from typing import Dict, Tuple
+import tvm
+from .utils import get_outer_loops, get_op_attrs
+from .dma import get_ifm_params, get_ofm_params
+from .spec import SerialActivation, SerialBinaryElementwise
+
+
+def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
+ """When the datatype of the ifm, ifm2 and ofm do not match,
+ casts are inserted in TE to handle the difference in these types.
+ Since TIR is not directly run on the NPU we can simply ignore
+ these, and allow the NPU to handle the difference in datatypes
+ itself.
+
+ Parameters
+ ----------
+ tir_load : tvm.tir.expr.Load
+
+ Returns
+ -------
+ tvm.tir.Var
+ """
+ return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load
+
+
+def get_binary_elementwise_params(
+ stmt: tvm.tir.AttrStmt,
+ producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
+ consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
+) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
+ """Get the parameters necessary to construct a call_extern for a
binary_elementwise.
+
+ Parameters
+ ----------
+ stmt : tvm.tir.AttrStmt
+ The outermost attribute statement of a binary elementwise loop nest.
+ producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
+ A dictionary to associate pointers with the loop nest
+ that produces their values.
+ consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
+ A dictionary to associate pointers with the loop nest
+ that consumes their values.
+
+ Returns
+ -------
+ SerialBinaryElementwise
+ The parameters needed to construct a binary elementwise operator.
+ output_pointer : tvm.tir.Var
+ The output pointer of the binary elementwise operation.
+ replace_pointer : tvm.tir.Var
+ The output pointer of the DMA write operation, which is to replace
+ the binary elementwise output pointer.
+ """
+ attrs, body = get_op_attrs(stmt)
+ reversed_operands = attrs["reversed_operands"]
+
+ _, _, _, _, _, inner = get_outer_loops(body, "NHWC")
+ op = ignore_cast(inner.value)
+ input_pointer = ignore_cast(op.a).buffer_var
+ input_pointer1 = ignore_cast(op.b).buffer_var
+
+ if reversed_operands:
+ input_pointer, input_pointer1 = input_pointer1, input_pointer
+ output_pointer = inner.buffer_var
+ # Get feature map info
+ serial_ifm, _ = get_ifm_params(input_pointer, producers)
+ serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
+ serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
+ # Get activation info
+ serial_activation = SerialActivation(
+ op=attrs["activation"], clip_min=attrs["clip_min"],
clip_max=attrs["clip_max"]
+ )
+ return (
+ SerialBinaryElementwise(
+ ifm=serial_ifm,
+ ifm2=serial_ifm2,
+ ofm=serial_ofm,
+ operator_type=attrs["operator_type"],
+ reversed_operands=reversed_operands,
+ activation=serial_activation,
+ ),
+ output_pointer,
+ replace_pointer,
+ )
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
index 2f5d7ab..a5678d1 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py
@@ -23,6 +23,7 @@ from tvm.relay.backend.contrib.ethosu import vela_api
from .convolution import get_conv2d_params
from .depthwise import get_depthwise_conv2d_params
from .pooling import get_pooling_params
+from .binary_elementwise import get_binary_elementwise_params
from .transform import get_copy_params
from .utils import get_weights_pointer, get_scale_bias_pointer
@@ -56,6 +57,7 @@ def ReplaceOperators():
"ethosu_copy": get_copy_params,
"ethosu_depthwise_conv2d": get_depthwise_conv2d_params,
"ethosu_pooling": get_pooling_params,
+ "ethosu_binary_elementwise": get_binary_elementwise_params,
}
pointer_to_producer = {}
pointer_to_consumer = {}
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
index ff019c7..269238a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py
@@ -261,3 +261,24 @@ class SerialPooling(SerializableFormat):
self.padding = padding
self.activation = activation
self.upscale = upscale
+
+
+class SerialBinaryElementwise(SerializableFormat):
+ """Specialization class to retrieve arguments of
+ a ethosu.binary_elementwise tir extern call on a predefined ordering"""
+
+ def __init__(
+ self,
+ ifm: SerialFeatureMap,
+ ifm2: SerialFeatureMap,
+ ofm: SerialFeatureMap,
+ operator_type: str,
+ reversed_operands: bool,
+ activation: SerialActivation,
+ ):
+ self.ifm = ifm
+ self.ifm2 = ifm2
+ self.ofm = ofm
+ self.operator_type = operator_type
+ self.reversed_operands = reversed_operands
+ self.activation = activation
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 8616695..f82d7bb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -213,7 +213,10 @@ def assign_addresses(buffer_info, npu_ops):
buffer = npu_fm.tiles.addresses[0].buffer_var
assert buffer in buffer_addresses.keys()
address, buffer_type = buffer_addresses[buffer]
- npu_fm.tiles.addresses[0] = address
+ index = npu_fm.tiles.addresses[0].index * (
+ np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8
+ )
+ npu_fm.tiles.addresses[0] = address + int(index)
npu_fm.region = _REGION_MAP[buffer_type]
return npu_fm
@@ -304,6 +307,7 @@ def translate_ethosu_tir_call_extern(tir_call_extern):
"ethosu_copy": translate_ethosu_copy,
"ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d,
"ethosu_pooling": translate_ethosu_pooling,
+ "ethosu_binary_elementwise": translate_ethosu_binary_elementwise,
}
ext_call_type = tir_call_extern.args[0].value
assert ext_call_type in supported_call_extern.keys(), f"{ext_call_type} is
not yet supported"
@@ -482,6 +486,7 @@ def _create_npu_feature_map(serial_feature_map:
spec.SerialFeatureMap) -> vapi.N
}
layout = str(serial_feature_map.layout.value)
data_type = str(serial_feature_map.data_type.value)
+    data_type_bytes = np.iinfo(np.dtype(data_type)).bits // 8
     assert layout in layout_map.keys()
     assert data_type in datatype_map.keys()
     nfm = vapi.NpuFeatureMap()
@@ -507,9 +512,9 @@ def _create_npu_feature_map(serial_feature_map:
spec.SerialFeatureMap) -> vapi.N
     )
     nfm.layout = layout_map[layout]
     nfm.strides = vapi.NpuShape3D(
-        int(serial_feature_map.stride_h),
-        int(serial_feature_map.stride_w),
-        int(serial_feature_map.stride_c),
+        int(serial_feature_map.stride_h.value) * data_type_bytes,
+        int(serial_feature_map.stride_w.value) * data_type_bytes,
+        int(serial_feature_map.stride_c.value) * data_type_bytes,
)
return nfm
@@ -677,3 +682,66 @@ def _create_npu_op_pooling(serial_pooling:
spec.SerialPooling):
npu_pooling_op.block_config = block_config
return npu_pooling_op
+
+
+def translate_ethosu_binary_elementwise(
+ tir_call_extern: tvm.tir.Call,
+) -> vapi.NpuElementWiseOperation:
+ """This function will translate a TIR call_extern
+ as produced by NPU Relay to TIR compilation.
+
+ Parameters
+ ----------
+ tir_call_extern : tvm.tir.Call
+ This should be a TIR call_extern that has agreed upon ordering
+ for TIR Compiler. See SerialBinaryElementwise in
+ tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering.
+
+ Returns
+ -------
+ ethosu.vela.api.NpuElementWiseOperation
+ The vela object containing the params of ethosu_binary_elementwise
+ """
+ serial_object = spec.create_serial_object(
+ spec.SerialBinaryElementwise, tir_call_extern.args[1:]
+ )
+ return _create_npu_op_binary_elementwise(serial_object)
+
+
+def _create_npu_op_binary_elementwise(serial_binary_elementwise:
spec.SerialBinaryElementwise):
+ operator_type = serial_binary_elementwise.operator_type
+ if operator_type == "ADD":
+ op = vapi.NpuElementWiseOp.ADD
+ elif operator_type == "SUB":
+ op = vapi.NpuElementWiseOp.SUB
+ elif operator_type == "MUL":
+ op = vapi.NpuElementWiseOp.MUL
+ elif operator_type == "MIN":
+ op = vapi.NpuElementWiseOp.MIN
+ elif operator_type == "MAX":
+ op = vapi.NpuElementWiseOp.MAX
+ elif operator_type == "SHR":
+ op = vapi.NpuElementWiseOp.SHR
+ elif operator_type == "SHL":
+ op = vapi.NpuElementWiseOp.SHL
+
+ npu_binary_elementwise_op = vapi.NpuElementWiseOperation(op)
+ npu_binary_elementwise_op.ifm =
_create_npu_feature_map(serial_binary_elementwise.ifm)
+ npu_binary_elementwise_op.ifm2 =
_create_npu_feature_map(serial_binary_elementwise.ifm2)
+ npu_binary_elementwise_op.ofm =
_create_npu_feature_map(serial_binary_elementwise.ofm)
+ npu_binary_elementwise_op.reversed_operands =
serial_binary_elementwise.reversed_operands
+
+ npu_binary_elementwise_op.activation = _create_npu_activation(
+ serial_binary_elementwise.activation
+ )
+ if (
+ npu_binary_elementwise_op.activation
+ and npu_binary_elementwise_op.activation.op_type ==
vapi.NpuActivationOp.NONE_OR_RELU
+ ):
+ _convert_clip_bounds(npu_binary_elementwise_op)
+
+ target_accel_config = vela_api.get_accelerator_config()
+ block_config =
vela_api.get_optimal_block_config(npu_binary_elementwise_op,
target_accel_config)
+ npu_binary_elementwise_op.block_config = block_config
+
+ return npu_binary_elementwise_op
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py
b/python/tvm/relay/backend/contrib/ethosu/util.py
index ee47e4a..8afb6eb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -75,6 +75,21 @@ class ClipArgs(Enum):
A_MAX = 2
+class BinaryElementwiseArgs(Enum):
+    """This is a helper enum to access the correct index
+    of binary elementwise arguments.
+ """
+
+ ifm = 0
+ ifm2 = 1
+ ifm_scale = 2
+ ifm_zero_point = 3
+ ifm2_scale = 4
+ ifm2_zero_point = 5
+ ofm_scale = 6
+ ofm_zero_point = 7
+
+
def is_composite_func(func: relay.Function, name: str) -> bool:
"""
This method checks whether the call is to
diff --git a/python/tvm/relay/op/contrib/ethosu.py
b/python/tvm/relay/op/contrib/ethosu.py
index a152235..25538ca 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -40,6 +40,7 @@ try:
from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type:
ignore
from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+ from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs
from tvm.relay.backend.contrib.ethosu.util import get_dim_value
except ImportError:
vapi = None
@@ -99,9 +100,8 @@ def check_strides(strides: List[int]) -> bool:
return True
-def check_valid_dtypes(tensor_params: List[TensorParams]) -> bool:
+def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes:
List[type]) -> bool:
"""This function checks whether dtypes are supported by the NPU"""
- supported_dtypes = (np.uint8, np.int8)
for tep in tensor_params:
# Check for dtypes
if np.dtype(tep.dtype) not in supported_dtypes:
@@ -248,7 +248,7 @@ class QnnConv2DParams:
This function checks whether QnnConv2D has compatible attributes with
the NPU
"""
tensor_params = [self.weights, self.ifm, self.ofm]
- if not check_valid_dtypes(tensor_params):
+ if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8,
np.int8]):
return False
if not check_weights(self.weights, self.dilation):
return False
@@ -287,7 +287,7 @@ class QnnDepthwiseConv2DParams(QnnConv2DParams):
Checks whether QnnDepthwiseConv2D + activation function has compatible
attributes with HW
"""
tensor_params = [self.weights, self.ifm, self.ofm]
- if not check_valid_dtypes(tensor_params):
+ if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8,
np.int8]):
return False
if not check_weights(self.weights, self.dilation):
return False
@@ -373,7 +373,7 @@ class MaxPool2DParams:
This function checks whether MaxPool2D has compatible attributes with
the NPU
"""
tensor_params = [self.ifm, self.ofm]
- if not check_valid_dtypes(tensor_params):
+ if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8,
np.int8]):
return False
if self.ifm.dtype != self.ofm.dtype:
return False
@@ -432,7 +432,7 @@ class AvgPool2DParams:
This function checks whether AvgPool2D has compatible attributes with
the NPU
"""
tensor_params = [self.ifm, self.ofm]
- if not check_valid_dtypes(tensor_params):
+ if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8,
np.int8]):
return False
if self.ifm.dtype != self.ofm.dtype:
return False
@@ -458,6 +458,316 @@ def qnn_avgpool2d_pattern() ->
tvm.relay.dataflow_pattern.DFPattern:
return pattern
+class BinaryElementwiseParams:
+ """
+ This class will parse a call to a ethosu.binary_elementwise composite
function
+ and extract the parameter information.
+ """
+
+ def __init__(self, func_body: Call, operator_type: str,
has_quantization_parameters: bool):
+ clip = None
+ if str(func_body.op) == "clip":
+ clip = func_body
+ binary_op = clip.args[0]
+ else:
+ binary_op = func_body
+
+ layout = "NHWC"
+
+ if has_quantization_parameters:
+ self.ifm = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm.value],
+ layout,
+ binary_op.args[BinaryElementwiseArgs.ifm_scale.value],
+ binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value],
+ )
+ self.ifm2 = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm2.value],
+ layout,
+ binary_op.args[BinaryElementwiseArgs.ifm2_scale.value],
+ binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value],
+ )
+ self.ofm = TensorParams(
+ binary_op,
+ layout,
+ binary_op.args[BinaryElementwiseArgs.ofm_scale.value],
+ binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value],
+ )
+ else:
+ self.ifm = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm.value],
+ layout,
+ )
+ self.ifm2 = TensorParams(
+ binary_op.args[BinaryElementwiseArgs.ifm2.value],
+ layout,
+ )
+ self.ofm = TensorParams(
+ binary_op,
+ layout,
+ )
+ self.activation = clip
+ self.operator_type = operator_type
+
+ def can_broadcast(x, y):
+ for i in range(1, 4):
+ if x.shape[i] == y.shape[i] or y.shape[i] == 1:
+ continue
+ return False
+ return True
+
+ if can_broadcast(self.ifm, self.ifm2):
+ self.reversed_operands = False
+ self.valid_broadcast = True
+ elif can_broadcast(self.ifm2, self.ifm):
+ self.reversed_operands = True
+ self.ifm, self.ifm2 = self.ifm2, self.ifm
+ self.valid_broadcast = True
+ else:
+ self.valid_broadcast = False
+
+ def is_valid(self):
+ """
+ This function checks whether BinaryElementwise has compatible
attributes with the NPU
+ """
+ if np.dtype(self.ofm) == np.int32 and self.activation is not None:
+ return False
+ if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4:
+ return False
+ if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1:
+ return False
+ if not self.valid_broadcast:
+ return False
+ return True
+
+
+class AddParams(BinaryElementwiseParams):
+ """
+ This class will parse a call to a ethosu.binary_elementwise Add composite
function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethosu.add"
+
+ def __init__(self, func_body: Call):
+ BinaryElementwiseParams.__init__(self, func_body, "ADD", True)
+
+ def is_valid(self):
+ """
+ This function checks whether Add has compatible attributes with the NPU
+ """
+ if not super().is_valid():
+ return False
+ if not check_valid_dtypes(
+ [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8,
np.int8, np.int32]
+ ):
+ return False
+ return True
+
+
+def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for qnn.add with optional fused RELU
activation.
+ """
+ pattern = is_op("qnn.add")(
+ wildcard(),
+ wildcard(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ )
+ pattern = pattern.optional(is_op("clip"))
+ return pattern
+
+
+class SubParams(BinaryElementwiseParams):
+ """
+ This class will parse a call to a ethosu.binary_elementwise Sub composite
function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethosu.sub"
+
+ def __init__(self, func_body: Call):
+ BinaryElementwiseParams.__init__(self, func_body, "SUB", True)
+
+ def is_valid(self):
+ """
+ This function checks whether Sub has compatible attributes with the NPU
+ """
+ if not super().is_valid():
+ return False
+ if not check_valid_dtypes(
+ [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8,
np.int8, np.int32]
+ ):
+ return False
+ return True
+
+
+def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for qnn.subtract with optional fused
RELU activation.
+ """
+ pattern = is_op("qnn.subtract")(
+ wildcard(),
+ wildcard(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ )
+ pattern = pattern.optional(is_op("clip"))
+ return pattern
+
+
+class MulParams(BinaryElementwiseParams):
+ """
+ This class will parse a call to a ethosu.binary_elementwise Mul composite
function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethosu.mul"
+
+ def __init__(self, func_body: Call):
+ BinaryElementwiseParams.__init__(self, func_body, "MUL", True)
+
+ def is_valid(self):
+ """
+ This function checks whether Mul has compatible attributes with the NPU
+ """
+ if not super().is_valid():
+ return False
+ if not check_valid_dtypes(
+ [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8,
np.int8, np.int32]
+ ):
+ return False
+ return True
+
+
+def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for qnn.mul with optional fused RELU
activation.
+ """
+ pattern = is_op("qnn.mul")(
+ wildcard(),
+ wildcard(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ is_constant(),
+ )
+ pattern = pattern.optional(is_op("clip"))
+ return pattern
+
+
+class MinParams(BinaryElementwiseParams):
+ """
+ This class will parse a call to a ethosu.binary_elementwise Min composite
function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethosu.min"
+
+ def __init__(self, func_body: Call):
+ BinaryElementwiseParams.__init__(self, func_body, "MIN", False)
+
+ def is_valid(self):
+ """
+ This function checks whether Min has compatible attributes with the NPU
+ """
+ if not super().is_valid():
+ return False
+ if self.ifm.dtype != self.ifm2.dtype:
+ return False
+ if not check_valid_dtypes(
+ [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8,
np.int8]
+ ):
+ return False
+ return True
+
+
+def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for minimum with optional fused RELU
activation.
+ """
+ pattern = is_op("minimum")(wildcard(), wildcard())
+ pattern = pattern.optional(is_op("clip"))
+ return pattern
+
+
+class MaxParams(BinaryElementwiseParams):
+ """
+ This class will parse a call to a ethosu.binary_elementwise Max composite
function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethosu.max"
+
+ def __init__(self, func_body: Call):
+ BinaryElementwiseParams.__init__(self, func_body, "MAX", False)
+
+ def is_valid(self):
+ """
+ This function checks whether Max has compatible attributes with the NPU
+ """
+ if not super().is_valid():
+ return False
+ if self.ifm.dtype != self.ifm2.dtype:
+ return False
+ if not check_valid_dtypes(
+ [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8,
np.int8]
+ ):
+ return False
+ return True
+
+
+def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for maximum with optional fused RELU
activation.
+ """
+ pattern = is_op("maximum")(wildcard(), wildcard())
+ pattern = pattern.optional(is_op("clip"))
+ return pattern
+
+
+class ShlParams(BinaryElementwiseParams):
+ """
+ This class will parse a call to a ethosu.binary_elementwise Shl composite
function
+ and extract the parameter information.
+ """
+
+ composite_name = "ethosu.shl"
+
+ def __init__(self, func_body: Call):
+ BinaryElementwiseParams.__init__(self, func_body, "SHL", False)
+
+ def is_valid(self):
+ """
+ This function checks whether Shl has compatible attributes with the NPU
+ """
+ if not super().is_valid():
+ return False
+ if not check_valid_dtypes([self.ifm, self.ifm2, self.ofm],
supported_dtypes=[np.int32]):
+ return False
+ return True
+
+
+def shl_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+ """
+ This function creates the pattern for left_shift with optional fused RELU
activation.
+ """
+ pattern = is_op("left_shift")(wildcard(), wildcard())
+ pattern = pattern.optional(is_op("clip"))
+ return pattern
+
+
@register_pattern_table("ethosu")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern,
Callable]]:
return [
@@ -481,6 +791,36 @@ def pattern_table() -> List[Tuple[str,
tvm.relay.dataflow_pattern.DFPattern, Cal
qnn_avgpool2d_pattern(),
lambda pat: AvgPool2DParams(pat).is_valid(),
),
+ (
+ AddParams.composite_name,
+ qnn_add_pattern(),
+ lambda pat: AddParams(pat).is_valid(),
+ ),
+ (
+ SubParams.composite_name,
+ qnn_subtract_pattern(),
+ lambda pat: SubParams(pat).is_valid(),
+ ),
+ (
+ MulParams.composite_name,
+ qnn_mul_pattern(),
+ lambda pat: MulParams(pat).is_valid(),
+ ),
+ (
+ MinParams.composite_name,
+ minimum_pattern(),
+ lambda pat: MinParams(pat).is_valid(),
+ ),
+ (
+ MaxParams.composite_name,
+ maximum_pattern(),
+ lambda pat: MaxParams(pat).is_valid(),
+ ),
+ (
+ ShlParams.composite_name,
+ shl_pattern(),
+ lambda pat: ShlParams(pat).is_valid(),
+ ),
]
diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc
b/src/relay/op/contrib/ethosu/binary_elementwise.cc
new file mode 100644
index 0000000..5b4900e
--- /dev/null
+++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/op/contrib/ethosu/binary_elementwise.cc
+ * \brief Binary elementwise operators definitions for the Arm(R) Ethos(TM)-U
NPU.
+ */
+#include <tvm/relay/op.h>
+
+#include "common.h"
+
+namespace tvm {
+namespace relay {
+namespace op {
+namespace contrib {
+namespace ethosu {
+
+/*! \brief Attributes used by the Ethos(TM)-U NPU binary elementwise operators
*/
+struct EthosuBinaryElementwiseAttrs : public
tvm::AttrsNode<EthosuBinaryElementwiseAttrs> {
+ String operator_type;
+ double ifm_scale;
+ int ifm_zero_point;
+ double ifm2_scale;
+ int ifm2_zero_point;
+ double ofm_scale;
+ int ofm_zero_point;
+ IndexExpr ifm_channels;
+ IndexExpr ifm2_channels;
+ bool reversed_operands;
+ String activation;
+ int clip_min;
+ int clip_max;
+ String ifm_layout;
+ String ifm2_layout;
+ String ofm_layout;
+ String ofm_dtype;
+
+ TVM_DECLARE_ATTRS(EthosuBinaryElementwiseAttrs,
"relay.attrs.EthosuBinaryElementwiseAttrs") {
+ TVM_ATTR_FIELD(operator_type)
+ .describe(
+ "The type of the binary elementwise operator."
+ "'ADD'"
+ "'SUB'"
+ "'MUL'"
+ "'MIN'"
+ "'MAX'"
+ "'SHR'"
+ "'SHL'");
+ TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input
Feature Map tensor.");
+ TVM_ATTR_FIELD(ifm_zero_point)
+ .describe("The quantization zero point for the Input Feature Map
tensor.");
+ TVM_ATTR_FIELD(ifm2_scale)
+ .describe("The quantization scale for the Input Feature Map tensor
2.");
+ TVM_ATTR_FIELD(ifm2_zero_point)
+ .describe("The quantization zero point for the Input Feature Map
tensor 2.");
+ TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output
Feature Map tensor.");
+ TVM_ATTR_FIELD(ofm_zero_point)
+ .describe("The quantization zero point for the Output Feature Map
tensor.");
+ TVM_ATTR_FIELD(ifm_channels).describe("The number of the Input Feature Map
channels.");
+ TVM_ATTR_FIELD(ifm2_channels).describe("The number of the Input Feature
Map 2 channels.");
+ TVM_ATTR_FIELD(reversed_operands)
+ .describe("True if IFM2 is the first operand and IFM is the second
operand.")
+ .set_default(false);
+ TVM_ATTR_FIELD(activation)
+ .describe(
+ "The activation function to use. "
+ "'NONE' - no activation function. "
+ "'CLIP' - clip the output between clip_min and clip_max. "
+ "'TANH' - tanh activation function. "
+ "'SIGMOID' - sigmoid activation function. "
+ "'LUT' - use a look-up table to perform the activation function."
+ "Available activations for activation type:"
+ "{int8, uint8}: 'NONE', 'CLIP', 'TANH', 'SIGMOID', 'LUT'"
+ "{int32}: 'NONE'")
+ .set_default("NONE");
+ TVM_ATTR_FIELD(clip_min)
+ .describe("The minimum clipping value if activation = 'CLIP'.")
+ .set_default(0);
+ TVM_ATTR_FIELD(clip_max)
+ .describe("The maximum clipping value if activation = 'CLIP'.")
+ .set_default(0);
+ TVM_ATTR_FIELD(ifm_layout)
+ .describe("The layout of the Input Feature Map tensor. Can be 'NHWC'
or 'NHCWB16'.")
+ .set_default("NHWC");
+ TVM_ATTR_FIELD(ifm2_layout)
+ .describe("The layout of the Input Feature Map tensor 2. Can be 'NHWC'
or 'NHCWB16'.")
+ .set_default("NHWC");
+ TVM_ATTR_FIELD(ofm_layout)
+ .describe("The layout of the Output Feature Map tensor. Can be 'NHWC'
or 'NHCWB16'.")
+ .set_default("NHWC");
+ TVM_ATTR_FIELD(ofm_dtype).describe(
+ "The Output Feature Map tensor type."
+ "MUL, ADD, SUB {IFM}->{OFM}:"
"  {uint8, int8, int32} -> {uint8, int8, int32}, any pairing"
+ "MAX, MIN:"
+ " IFM and OFM must be of the same type, one of:"
+ " {int8, uint8}"
+ "SHR {IFM}->{OFM}:"
+ " {int32}->{int8, uint8, int32}, any pairing"
+ "SHL:"
+ " {int32}->{int32} only");
+ }
+};
+
+TVM_REGISTER_NODE_TYPE(EthosuBinaryElementwiseAttrs);
+
+bool EthosuBinaryElementwiseRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
+ const TypeReporter& reporter) {
+ const int ifm_index = 0;
+ const int ifm2_index = 1;
+ const int result_index = 3;
+ ICHECK_EQ(types.size(), result_index + 1);
+
+ const auto* ifm = types[ifm_index].as<TensorTypeNode>();
+ const auto* ifm2 = types[ifm2_index].as<TensorTypeNode>();
+ if (ifm == nullptr) return false;
+ if (ifm2 == nullptr) return false;
+
+ const auto* param = attrs.as<EthosuBinaryElementwiseAttrs>();
+ ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be
nullptr.";
+
+ String operator_type = param->operator_type;
+ auto ifm_dtype = ifm->dtype;
+ auto ifm2_dtype = ifm2->dtype;
+ DataType ofm_dtype;
+
+ if (param->ofm_dtype == "int8") {
+ ofm_dtype = DataType::Int(8);
+ } else if (param->ofm_dtype == "uint8") {
+ ofm_dtype = DataType::UInt(8);
+ } else if (param->ofm_dtype == "int32") {
+ ofm_dtype = DataType::Int(32);
+ }
+
+ if (ifm_dtype != ifm2_dtype) {
+ reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected
ethosu_binary_elementwise "
+                                     << "type for ifm2 to be the same as ifm but
was " << ifm2_dtype
+ << " instead of " << ifm_dtype);
+ return false;
+ }
+
+ if (operator_type == "ADD" || operator_type == "SUB" || operator_type ==
"MUL") {
+ if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) &&
+ ifm_dtype != DataType::Int(32)) {
+ reporter->GetDiagCtx().EmitFatal(
+ Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected ethosu_binary_elementwise " <<
operator_type
+ << " type(uint8) or type(int8) or type(int32) for ifm but was " <<
ifm_dtype);
+ return false;
+ }
+ if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
+ ofm_dtype != DataType::Int(32)) {
+ reporter->GetDiagCtx().EmitFatal(
+ Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected ethosu_binary_elementwise " <<
operator_type
+ << " type(uint8) or type(int8) or type(int32) for ofm but was " <<
ofm_dtype);
+ return false;
+ }
+ } else if (operator_type == "MIN" || operator_type == "MAX") {
+ if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) {
+ reporter->GetDiagCtx().EmitFatal(
+ Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected ethosu_binary_elementwise " <<
operator_type
+ << " type(uint8) or type(int8) for ifm but was " << ifm_dtype);
+ return false;
+ }
+ if (ifm_dtype != ofm_dtype) {
+ reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected
ethosu_binary_elementwise "
+ << operator_type
+                                     << " type for ofm to be the same as ifm
but was " << ofm_dtype
+ << " instead of " << ifm_dtype);
+ return false;
+ }
+ } else if (operator_type == "SHR") {
+ if (ifm_dtype != DataType::Int(32)) {
+ reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected
ethosu_binary_elementwise "
+ << operator_type << " type(int32) for
ifm but was "
+ << ifm_dtype);
+ return false;
+ }
+ if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) &&
+ ofm_dtype != DataType::Int(32)) {
+ reporter->GetDiagCtx().EmitFatal(
+ Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected ethosu_binary_elementwise " <<
operator_type
+ << " type(uint8) or type(int8) or type(int32) for ofm but was " <<
ofm_dtype);
+ return false;
+ }
+ } else if (operator_type == "SHL") {
+ if (ifm_dtype != DataType::Int(32)) {
+ reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected
ethosu_binary_elementwise "
+ << operator_type << " type(int32) for
ifm but was "
+ << ifm_dtype);
+
+ return false;
+ }
+ if (ofm_dtype != DataType::Int(32)) {
+ reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected
ethosu_binary_elementwise "
+ << operator_type << " type(int32) for
ofm but was "
+ << ofm_dtype);
+ return false;
+ }
+ } else {
+ reporter->GetDiagCtx().EmitFatal(
+ Diagnostic::Error(reporter->GetSpan())
+ << "Invalid operator: expected ethosu_binary_elementwise 'ADD' or
'SUB' or 'MUL' or "
+ << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " <<
param->operator_type);
+ return false;
+ }
+
+ // Assign ofm type
+ auto ofm_shape = EthosuInferBinaryElementwiseOutputShape(ifm->shape,
param->ifm_layout,
+ param->ofm_layout,
param->ifm_channels);
+ reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype));
+ return true;
+}
+
+Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String
operator_type,
+ double ifm_scale, int ifm_zero_point, double
ifm2_scale,
+ int ifm2_zero_point, double ofm_scale, int
ofm_zero_point,
+ IndexExpr ifm_channels, IndexExpr
ifm2_channels,
+ bool reversed_operands, String activation,
int clip_min,
+ int clip_max, String ifm_layout, String
ifm2_layout,
+ String ofm_layout, String ofm_dtype) {
+ auto attrs = make_object<EthosuBinaryElementwiseAttrs>();
+
+ attrs->operator_type = std::move(operator_type);
+ attrs->ifm_scale = ifm_scale;
+ attrs->ifm_zero_point = ifm_zero_point;
+ attrs->ifm2_scale = ifm2_scale;
+ attrs->ifm2_zero_point = ifm2_zero_point;
+ attrs->ofm_scale = ofm_scale;
+ attrs->ofm_zero_point = ofm_zero_point;
+ attrs->ifm_channels = std::move(ifm_channels);
+ attrs->ifm2_channels = std::move(ifm2_channels);
+ attrs->reversed_operands = reversed_operands;
+ attrs->activation = std::move(activation);
+ attrs->clip_min = clip_min;
+ attrs->clip_max = clip_max;
+ attrs->ifm_layout = std::move(ifm_layout);
+ attrs->ifm2_layout = std::move(ifm2_layout);
+ attrs->ofm_layout = std::move(ofm_layout);
+ attrs->ofm_dtype = std::move(ofm_dtype);
+
+ static const Op& op = Op::Get("contrib.ethosu.binary_elementwise");
+ return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.ethosu_binary_elementwise")
+ .set_body_typed(MakeEthosuBinaryElementwise);
+
+RELAY_REGISTER_OP("contrib.ethosu.binary_elementwise")
+ .describe(R"code(Arm(R) Ethos(TM)-U NPU quantized binary elementwise
operator.
+
+This Relay operator corresponds to the hardware-implemented quantized
+binary elementwise operation found on Ethos(TM)-U NPU. It accepts either NHWC
+or NHCWB16 format for the inputs data (input feature maps, or IFMs).
+
+Reference: https://developer.arm.com/documentation/102420/0200/
+
+- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
+ NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
+- **ifm2**: NHWC - (1, ifm_height, ifm_width, ifm_channels)
+ NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16)
+- **ofm**: (1, ofm_height, ofm_width, ifm_channels)
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<EthosuBinaryElementwiseAttrs>()
+ .set_num_inputs(3)
+ .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).")
+ .add_argument("ifm2", "Tensor", "The Input Feature Map tensor 2 (IFM2).")
+ .add_argument("lut", "Tensor", "The look-up table of values to use if
activation = 'LUT'")
+ .set_support_level(11)
+ .add_type_rel("EthosuBinaryElementwise", EthosuBinaryElementwiseRel);
+
+} // namespace ethosu
+} // namespace contrib
+} // namespace op
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/op/contrib/ethosu/common.cc
b/src/relay/op/contrib/ethosu/common.cc
index bdda81b..bdaa9da 100644
--- a/src/relay/op/contrib/ethosu/common.cc
+++ b/src/relay/op/contrib/ethosu/common.cc
@@ -32,6 +32,24 @@ namespace op {
namespace contrib {
namespace ethosu {
+Array<IndexExpr> EthosuInferBinaryElementwiseOutputShape(Array<IndexExpr>
ifm_shape,
+ String ifm_layout,
String ofm_layout,
+ IndexExpr
ofm_channels) {
+ // In the case of NHCWB16, convert the ifm shape to NHW (C not required for
this function)
+ if (ifm_layout == "NHCWB16") {
+ ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]};
+ }
+ Array<IndexExpr> oshape({ifm_shape[0], ifm_shape[1], ifm_shape[2],
ofm_channels});
+
+ // If the ofm is NHCWB16, convert the layout
+ if (ofm_layout == "NHCWB16") {
+ int channel_bricks = 1 + (oshape[3].as<IntImmNode>()->value - 1) / 16;
+ oshape = {oshape[0], oshape[1], channel_bricks, oshape[2], 16};
+ }
+
+ return oshape;
+}
+
Array<IndexExpr> EthosuInferKernelOutput(Array<IndexExpr> ifm_shape, String
ifm_layout,
String ofm_layout, Array<IndexExpr>
kernel_shape,
IndexExpr ofm_channels,
Array<IndexExpr> dilation,
diff --git a/src/relay/op/contrib/ethosu/common.h
b/src/relay/op/contrib/ethosu/common.h
index b5377e6..574fb91 100644
--- a/src/relay/op/contrib/ethosu/common.h
+++ b/src/relay/op/contrib/ethosu/common.h
@@ -33,6 +33,17 @@ namespace op {
namespace contrib {
namespace ethosu {
+/*! \brief Infer the output tensor shape for binary elementwise operators.
+ * \param ifm_shape The shape of Input Feature Map.
+ * \param ifm_layout The layout of the IFM (NHWC or NHCWB16).
+ * \param ofm_layout The layout of the OFM (NHWC or NHCWB16).
+ * \param ofm_channels The number of Output Feature Map channels.
+ * \return The shape of the output tensor.
+ */
+Array<IndexExpr> EthosuInferBinaryElementwiseOutputShape(Array<IndexExpr>
ifm_shape,
+ String ifm_layout,
String ofm_layout,
+ IndexExpr
ofm_channels);
+
/*! \brief Infer the output tensor shape for convolution and pooling operators.
* \param ifm_shape The shape of Input Feature Map.
* \param ifm_layout The layout of the IFM (NHWC or NHCWB16).
diff --git a/src/relay/op/contrib/ethosu/pooling.cc
b/src/relay/op/contrib/ethosu/pooling.cc
index 86f14f3..bcf54fb 100644
--- a/src/relay/op/contrib/ethosu/pooling.cc
+++ b/src/relay/op/contrib/ethosu/pooling.cc
@@ -19,7 +19,7 @@
/*!
* \file src/relay/op/contrib/ethosu/pooling.cc
- * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU
convolution ops.
+ * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU.
*/
#include <tvm/relay/op.h>
diff --git a/tests/python/contrib/test_ethosu/infra.py
b/tests/python/contrib/test_ethosu/infra.py
index 58862c5..17d3fad 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -509,3 +509,56 @@ def make_ethosu_pooling(
ofm_layout=ofm_layout,
)
return pooling
+
+
+def get_binary_elementwise_args(call, include_buffers=False):
+ args = call.args
+ binary_elementwise_args = []
+
+ for i, arg in enumerate(args):
+ if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg,
tvm.tir.expr.FloatImm):
+ binary_elementwise_args.append(arg.value)
+ elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers:
+ binary_elementwise_args.append(arg.index)
+ else:
+ binary_elementwise_args.append(arg)
+
+ return binary_elementwise_args
+
+
+def make_ethosu_binary_elementwise(
+ ifm,
+ ifm2,
+ ifm_channels,
+ ifm2_channels,
+ operator_type,
+ ofm_dtype,
+ reversed_operands=False,
+ activation="NONE",
+ ifm_layout="NHWC",
+ ifm2_layout="NHWC",
+ ofm_layout="NHWC",
+):
+ ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise(
+ ifm=ifm,
+ ifm2=ifm2,
+ lut=relay.const([], dtype="int8"),
+ operator_type=operator_type,
+ ifm_scale=1,
+ ifm_zero_point=0,
+ ifm2_scale=1,
+ ifm2_zero_point=0,
+ ofm_scale=1,
+ ofm_zero_point=0,
+ ifm_channels=ifm_channels,
+ ifm2_channels=ifm2_channels,
+ reversed_operands=reversed_operands,
+ activation=activation,
+ ofm_dtype=ofm_dtype,
+ clip_min=10 if activation == "CLIP" else 0,
+ clip_max=100 if activation == "CLIP" else 0,
+ ifm_layout=ifm_layout,
+ ifm2_layout=ifm2_layout,
+ ofm_layout=ofm_layout,
+ )
+ return ethosu_binary_elementwise
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py
b/tests/python/contrib/test_ethosu/test_codegen.py
index 478a3c2..a5686c8 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -276,8 +276,6 @@ def test_ethosu_pooling(
dtype = "int8"
def create_tflite_graph():
- tf.config.run_functions_eagerly(True)
-
class Model(tf.Module):
@tf.function
def tf_function(self, x):
@@ -343,5 +341,255 @@ def test_ethosu_pooling(
infra.verify_source(compiled_models, accel_type)
[email protected]("accel_type", ACCEL_TYPES)
[email protected]("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
[email protected](
+ "ifm_shape, ifm2_shape",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4]),
+ ([1, 2, 3, 4], [1, 1, 1, 1]),
+ ([1, 1, 1, 1], [1, 2, 3, 4]),
+ ],
+)
[email protected]("activation_function", ["NONE", "RELU"])
+def test_ethosu_binary_elementwise(
+ accel_type,
+ operator_type,
+ ifm_shape,
+ ifm2_shape,
+ activation_function,
+):
+ dtype = "int8"
+
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, lhs, rhs):
+ if operator_type == "ADD":
+ op = tf.math.add(lhs, rhs)
+ elif operator_type == "SUB":
+ op = tf.math.subtract(lhs, rhs)
+ elif operator_type == "MUL":
+ op = tf.math.multiply(lhs, rhs)
+ elif operator_type == "MIN":
+ op = tf.math.minimum(lhs, rhs)
+ elif operator_type == "MAX":
+ op = tf.math.maximum(lhs, rhs)
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ tf.TensorSpec(ifm_shape, dtype=tf.float32),
tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+ )
+
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ data = np.random.rand(*tuple(ifm_shape))
+ data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+ yield [data.astype(np.float32), data2.astype(np.float32)]
+
+ converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops =
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+ return tflite_model
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ mod, params = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape},
+ dtype_dict={"ifm": dtype, "ifm2": dtype},
+ )
+ mod = partition_for_ethosu(mod, params)
+
+ # Generate reference data
+ input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+ compiled_models = infra.build_source(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ output_tolerance=1 if operator_type == "MAX" else 0,
+ )
+
+ # Assumes only two runtime.Modules are created -- i.e. single offload
module
+ imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+ assert len(imported_modules) == 2
+ ethosu_module = imported_modules[0]
+
+ # Verify generated C source
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ cmms = get_cs(ethosu_module)
+ cmms = bytes.fromhex(cmms)
+
+ infra.print_payload(cmms)
+ infra.verify_source(compiled_models, accel_type)
+
+
[email protected]("accel_type", ACCEL_TYPES)
[email protected](
+ "ifm_shape, ifm2_shape",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4]),
+ ([1, 2, 3, 4], [1, 1, 3, 1]),
+ ([1, 1, 3, 1], [1, 2, 3, 4]),
+ ],
+)
+def test_ethosu_left_shift_binary_elemwise(
+ accel_type,
+ ifm_shape,
+ ifm2_shape,
+):
+ dtype = "int32"
+
+ def create_model():
+ ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+ ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+ c1 = relay.left_shift(ifm, ifm2)
+ f = relay.Function([ifm, ifm2], c1)
+ mod = tvm.IRModule()
+ mod["main"] = f
+ return mod
+
+ relay_mod = create_model()
+ mod = partition_for_ethosu(relay_mod)
+
+ # Generate reference data
+ in_min, in_max = util.get_range_for_dtype_str(dtype)
+ input_data = {
+ "ifm": np.random.randint(in_min, high=in_max, size=ifm_shape,
dtype=dtype),
+ "ifm2": np.random.randint(0, high=32, size=ifm2_shape, dtype=dtype),
+ }
+ output_data = generate_ref_data(relay_mod, input_data)
+
+ compiled_models = infra.build_source(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ )
+
+ # Assumes only two runtime.Modules are created -- i.e. single offload
module
+ imported_modules = compiled_models[0].executor_factory.lib.imported_modules
+ assert len(imported_modules) == 2
+ ethosu_module = imported_modules[0]
+
+ # Verify generated C source
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ cmms = get_cs(ethosu_module)
+ cmms = bytes.fromhex(cmms)
+
+ infra.print_payload(cmms)
+ infra.verify_source(compiled_models, accel_type)
+
+
[email protected]("accel_type", ACCEL_TYPES)
[email protected](
+ "ifm_shape, ifm2_shape, reversed_operands, ofm_dtype",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4], False, "int8"),
+ ([1, 2, 3, 1], [1, 1, 3, 1], False, "int32"),
+ ([1, 1, 3, 1], [1, 2, 3, 1], True, "int32"),
+ ],
+)
+def test_ethosu_right_shift_binary_elemwise(
+ ifm_shape, ifm2_shape, reversed_operands, accel_type, ofm_dtype
+):
+ dtype = "int32"
+
+ def create_model():
+ ifm_count = int(np.prod(ifm_shape))
+ ifm2_count = int(np.prod(ifm2_shape))
+
+ # Create a "partitioned" Relay function
+ ifms = relay.var("ifms", shape=[ifm_count + ifm2_count], dtype=dtype)
+ split = relay.split(ifms, [ifm_count])
+ ifm = relay.reshape(split[0], newshape=ifm_shape)
+ ifm2 = relay.reshape(split[1], newshape=ifm2_shape)
+ shr_op = infra.make_ethosu_binary_elementwise(
+ ifm, ifm2, ifm_shape[3], ifm2_shape[3], "SHR", ofm_dtype,
reversed_operands
+ )
+
+ glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0")
+ func = (
+ relay.Function([ifms], shr_op)
+ .with_attr("Inline", 1)
+ .with_attr("Compiler", "ethosu")
+ .with_attr("global_symbol", "tvmgen_default_ethosu_main_0")
+ .with_attr("Primitive", 1)
+ )
+ mod = tvm.IRModule()
+ mod[glb_ethosu] = func
+ mod = relay.transform.InferType()(mod)
+
+ # Main
+ ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+ ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+ call = relay.Call(
+ glb_ethosu,
+ [
+ relay.concatenate(
+ data=(
+ relay.reshape(ifm, newshape=ifm_count),
+ relay.reshape(ifm2, newshape=ifm2_count),
+ ),
+ axis=0,
+ )
+ ],
+ )
+ mod["main"] = relay.Function([ifm, ifm2], call)
+ mod = relay.transform.InferType()(mod)
+ return mod
+
+ mod = create_model()
+
+ # Generate reference data
+ in_min, in_max = util.get_range_for_dtype_str(dtype)
+ in_min, in_max = 18, 19
+ lhs = np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype)
+ rhs = np.random.randint(1, high=2, size=ifm2_shape, dtype=dtype)
+ input_data = {
+ "ifm": lhs,
+ "ifm2": rhs,
+ }
+
+ if reversed_operands:
+ lhs = np.broadcast_to(lhs, ifm2_shape)
+ lhs, rhs = rhs, lhs
+ else:
+ rhs = np.broadcast_to(rhs, ifm_shape)
+
+ def rounding_right_shift(lhs, rhs):
+ r = 1 << (rhs - 1)
+ return (lhs + r) >> rhs
+
+ output_data = np.array(
+ [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)]
+ ).astype(ofm_dtype)
+
+ compiled_model = infra.build_source(mod, input_data, [output_data],
accel_type)
+
+ imported_modules = compiled_model[0].executor_factory.lib.imported_modules
+ assert len(imported_modules) == 2
+ ethosu_module = imported_modules[0]
+
+ # Verify generated C source
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ cmms = get_cs(ethosu_module)
+ cmms = bytes.fromhex(cmms)
+
+ infra.print_payload(cmms)
+ infra.verify_source(compiled_model, accel_type)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py
b/tests/python/contrib/test_ethosu/test_legalize.py
index fc03a98..2a84a23 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -558,5 +558,193 @@ def test_tflite_pool2d_legalize(
verify(mod["tvmgen_default_ethosu_main_0"])
[email protected]("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
[email protected](
+ "ifm_shape, ifm2_shape, reversed_operands",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4], False),
+ ([1, 2, 3, 4], [1, 1, 3, 1], False),
+ ([1, 1, 3, 1], [1, 2, 3, 4], True),
+ ],
+)
[email protected]("activation_function", ["NONE", "RELU"])
+def test_tflite_binary_elemwise_legalize(
+ operator_type,
+ ifm_shape,
+ ifm2_shape,
+ reversed_operands,
+ activation_function,
+):
+ dtype = "int8"
+
+ def create_tflite_graph():
+ class Model(tf.Module):
+ @tf.function
+ def tf_function(self, x, y):
+ if operator_type == "ADD":
+ op = tf.math.add(x, y)
+ elif operator_type == "SUB":
+ op = tf.math.subtract(x, y)
+ elif operator_type == "MUL":
+ op = tf.math.multiply(x, y)
+ elif operator_type == "MIN":
+ op = tf.math.minimum(x, y)
+ elif operator_type == "MAX":
+ op = tf.math.maximum(x, y)
+ if activation_function == "RELU":
+ op = tf.nn.relu(op)
+ return op
+
+ model = Model()
+ concrete_func = model.tf_function.get_concrete_function(
+ tf.TensorSpec(ifm_shape, dtype=tf.float32),
tf.TensorSpec(ifm2_shape, dtype=tf.float32)
+ )
+
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ data = np.random.rand(*tuple(ifm_shape))
+ data2 = np.random.rand(*tuple(ifm2_shape)) * 2
+ yield [data.astype(np.float32), data2.astype(np.float32)]
+
+ converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops =
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_model = converter.convert()
+ return tflite_model
+
+ def verify(ext_func):
+ out_shape = ifm2_shape if reversed_operands else ifm_shape
+ shapes = [ifm_shape, ifm2_shape]
+ ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+ op = ext_func.body
+ assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+ assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+ assert op.args[0].checked_type.dtype == dtype
+ assert list(op.checked_type.shape) == out_shape
+ assert op.checked_type.dtype == dtype
+ assert op.attrs.operator_type == operator_type
+ assert op.attrs.reversed_operands == reversed_operands
+ if activation_function == "RELU":
+ assert str(op.attrs.activation) == "CLIP"
+
+ if operator_type == "ADD":
+ rewriter = legalize.AddRewriter()
+ pattern_table = [
+ (
+ ethosu.AddParams.composite_name,
+ ethosu.qnn_add_pattern(),
+ lambda pat: ethosu.AddParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "SUB":
+ rewriter = legalize.SubRewriter()
+ pattern_table = [
+ (
+ ethosu.SubParams.composite_name,
+ ethosu.qnn_subtract_pattern(),
+ lambda pat: ethosu.SubParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "MUL":
+ rewriter = legalize.MulRewriter()
+ pattern_table = [
+ (
+ ethosu.MulParams.composite_name,
+ ethosu.qnn_mul_pattern(),
+ lambda pat: ethosu.MulParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "MIN":
+ rewriter = legalize.MinRewriter()
+ pattern_table = [
+ (
+ ethosu.MinParams.composite_name,
+ ethosu.minimum_pattern(),
+ lambda pat: ethosu.MinParams(pat).is_valid(),
+ ),
+ ]
+ elif operator_type == "MAX":
+ rewriter = legalize.MaxRewriter()
+ pattern_table = [
+ (
+ ethosu.MaxParams.composite_name,
+ ethosu.maximum_pattern(),
+ lambda pat: ethosu.MaxParams(pat).is_valid(),
+ ),
+ ]
+
+ tflite_graph = create_tflite_graph()
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ mod, _ = relay.frontend.from_tflite(
+ tflite_model,
+ shape_dict={"x": ifm_shape, "y": ifm2_shape},
+ dtype_dict={"x": dtype, "y": dtype},
+ )
+ mod = partition_ethosu_by_table(mod, pattern_table)
+
+ mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethosu_main_0"]
+ )
+ verify(mod["tvmgen_default_ethosu_main_0"])
+
+
[email protected](
+ "ifm_shape, ifm2_shape, reversed_operands",
+ [
+ ([1, 2, 3, 4], [1, 2, 3, 4], False),
+ ([1, 2, 3, 4], [1, 1, 3, 1], False),
+ ([1, 1, 3, 1], [1, 2, 3, 4], True),
+ ],
+)
+def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape,
reversed_operands):
+ dtype = "int32"
+ operator_type = "SHL"
+
+ def create_graph():
+ input1 = relay.var("x1", shape=ifm_shape, dtype=dtype)
+ input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype)
+ c1 = relay.left_shift(input1, input2)
+ f = relay.Function([input1, input2], c1)
+ mod = tvm.IRModule()
+ mod["main"] = f
+ return mod
+
+ def verify(ext_func):
+ out_shape = ifm2_shape if reversed_operands else ifm_shape
+ shapes = [ifm_shape, ifm2_shape]
+ ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1)
+ op = ext_func.body
+ assert list(op.args[0].checked_type.shape) == shapes[ifm_index]
+ assert list(op.args[1].checked_type.shape) == shapes[ifm2_index]
+ assert op.args[0].checked_type.dtype == dtype
+ assert list(op.checked_type.shape) == out_shape
+ assert op.checked_type.dtype == dtype
+ assert op.attrs.operator_type == operator_type
+ assert op.attrs.reversed_operands == reversed_operands
+ assert str(op.attrs.activation) == "NONE"
+
+ rewriter = legalize.ShlRewriter()
+ pattern_table = [
+ (
+ ethosu.ShlParams.composite_name,
+ ethosu.shl_pattern(),
+ lambda pat: ethosu.ShlParams(pat).is_valid(),
+ ),
+ ]
+
+ mod = create_graph()
+ mod = partition_ethosu_by_table(mod, pattern_table)
+
+ mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethosu_main_0"]
+ )
+ verify(mod["tvmgen_default_ethosu_main_0"])
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git
a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
new file mode 100644
index 0000000..6dcd9da
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py
@@ -0,0 +1,335 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import tvm
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+from tvm.relay.backend.contrib.ethosu.tir import spec
+from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from .infra import make_ethosu_binary_elementwise, get_binary_elementwise_args
+
+
[email protected](
+ "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout,
ofm_layout",
+ [
+ ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"),
+ ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"),
+ ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"),
+ ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"),
+ # Broadcast
+ ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"),
+ ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"),
+ ],
+)
[email protected]("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
[email protected]("activation", ["NONE", "CLIP"])
+def test_binary_elementwise_single(
+ ifm_shape,
+ ifm2_shape,
+ ifm_channels,
+ ifm2_channels,
+ ifm_layout,
+ ofm_layout,
+ operator_type,
+ activation,
+):
+ dtype = "int8"
+ ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+ ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+
+ binary_elementwise = make_ethosu_binary_elementwise(
+ ifm,
+ ifm2,
+ ifm_channels,
+ ifm2_channels,
+ operator_type,
+ dtype,
+ False,
+ activation,
+ ifm_layout,
+ ifm_layout,
+ ofm_layout,
+ )
+ func = relay.Function(relay.analysis.free_vars(binary_elementwise),
binary_elementwise)
+ func = run_opt_pass(func, relay.transform.InferType())
+ mod, _ = lower_to_tir(func)
+ data = []
+
+ def _visit(stmt):
+ if isinstance(stmt, tvm.tir.Call):
+ data.append(get_binary_elementwise_args(stmt))
+
+ tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+ if ifm_layout == "NHWC":
+ ifm_stride_c = 1
+ ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1
+ ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1
+
+ ifm2_stride_c = 1
+ ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1
+ ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1
else 1
+
+ ofm_height = ifm_shape[1]
+ ofm_width = ifm_shape[2]
+ else:
+ ifm_stride_w = 16
+ ifm_stride_c = 16 * ifm_shape[3]
+ ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
+
+ ifm2_stride_w = 16
+ ifm2_stride_c = 16 * ifm2_shape[3]
+ ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3]
+
+ ofm_height = ifm_shape[1]
+ ofm_width = ifm_shape[3]
+
+ if ofm_layout == "NHWC":
+ ofm_stride_c = 1
+ ofm_stride_w = ifm_channels if ofm_width > 1 else 1
+ ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1
+ else:
+ ofm_stride_w = 16
+ ofm_stride_c = 16 * ofm_width
+ ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1)
+
+ serial_binary_elementwise = spec.SerialBinaryElementwise(
+ ifm=spec.SerialFeatureMap(
+ data_type=dtype,
+ height=ifm_shape[1],
+ width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+ channels=ifm_channels,
+ tile_height_0=ifm_shape[1],
+ tile_height_1=0,
+ tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else
ifm_shape[3],
+ tile_address_0=0,
+ tile_address_1=0,
+ tile_address_2=0,
+ tile_address_3=0,
+ scale=1.0,
+ zero_point=0,
+ layout=ifm_layout,
+ stride_h=ifm_stride_h,
+ stride_w=ifm_stride_w,
+ stride_c=ifm_stride_c,
+ ),
+ ifm2=spec.SerialFeatureMap(
+ data_type=dtype,
+ height=ifm2_shape[1],
+ width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3],
+ channels=ifm2_channels,
+ tile_height_0=ifm2_shape[1],
+ tile_height_1=0,
+ tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else
ifm2_shape[3],
+ tile_address_0=0,
+ tile_address_1=0,
+ tile_address_2=0,
+ tile_address_3=0,
+ scale=1.0,
+ zero_point=0,
+ layout=ifm_layout,
+ stride_h=ifm2_stride_h,
+ stride_w=ifm2_stride_w,
+ stride_c=ifm2_stride_c,
+ ),
+ ofm=spec.SerialFeatureMap(
+ data_type=dtype,
+ height=ofm_height,
+ width=ofm_width,
+ channels=ifm_channels,
+ tile_height_0=ofm_height,
+ tile_height_1=0,
+ tile_width_0=ofm_width,
+ tile_address_0=0,
+ tile_address_1=0,
+ tile_address_2=0,
+ tile_address_3=0,
+ scale=1.0,
+ zero_point=0,
+ layout=ofm_layout,
+ stride_h=ofm_stride_h,
+ stride_w=ofm_stride_w,
+ stride_c=ofm_stride_c,
+ ),
+ operator_type=operator_type,
+ reversed_operands=False,
+ activation=spec.SerialActivation(
+ op=activation,
+ clip_min=10 if activation == "CLIP" else 0,
+ clip_max=100 if activation == "CLIP" else 0,
+ ),
+ )
+
+ assert data[0] == ["ethosu_binary_elementwise"] +
list(serial_binary_elementwise)
+
+
[email protected](
+ "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout,
ofm_layout",
+ [
+ ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"),
+ ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"),
+ ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"),
+ ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"),
+ # Broadcast
+ ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"),
+ ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"),
+ ],
+)
[email protected]("operator_type", ["SHR", "SHL"])
+def test_shift_binary_elementwise_single(
+ ifm_shape,
+ ifm2_shape,
+ ifm_channels,
+ ifm2_channels,
+ ifm_layout,
+ ofm_layout,
+ operator_type,
+):
+ dtype = "int32"
+ activation = "NONE" # Only NONE is available if the activation type is
int32
+ ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+ ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype)
+
+ binary_elementwise = make_ethosu_binary_elementwise(
+ ifm,
+ ifm2,
+ ifm_channels,
+ ifm2_channels,
+ operator_type,
+ dtype,
+ False,
+ "NONE",
+ ifm_layout,
+ ifm_layout,
+ ofm_layout,
+ )
+ func = relay.Function(relay.analysis.free_vars(binary_elementwise),
binary_elementwise)
+ func = run_opt_pass(func, relay.transform.InferType())
+ mod, _ = lower_to_tir(func)
+ data = []
+
+ def _visit(stmt):
+ if isinstance(stmt, tvm.tir.Call):
+ data.append(get_binary_elementwise_args(stmt))
+
+ tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+ if ifm_layout == "NHWC":
+ ifm_stride_c = 1
+ ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1
+ ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1
+
+ ifm2_stride_c = 1
+ ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1
+ ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1
else 1
+
+ ofm_height = ifm_shape[1]
+ ofm_width = ifm_shape[2]
+ else:
+ ifm_stride_w = 16
+ ifm_stride_c = 16 * ifm_shape[3]
+ ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
+
+ ifm2_stride_w = 16
+ ifm2_stride_c = 16 * ifm2_shape[3]
+ ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3]
+
+ ofm_height = ifm_shape[1]
+ ofm_width = ifm_shape[3]
+
+ if ofm_layout == "NHWC":
+ ofm_stride_c = 1
+ ofm_stride_w = ifm_channels if ofm_width > 1 else 1
+ ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1
+ else:
+ ofm_stride_w = 16
+ ofm_stride_c = 16 * ofm_width
+ ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1)
+
+ serial_binary_elementwise = spec.SerialBinaryElementwise(
+ ifm=spec.SerialFeatureMap(
+ data_type=dtype,
+ height=ifm_shape[1],
+ width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+ channels=ifm_channels,
+ tile_height_0=ifm_shape[1],
+ tile_height_1=0,
+ tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else
ifm_shape[3],
+ tile_address_0=0,
+ tile_address_1=0,
+ tile_address_2=0,
+ tile_address_3=0,
+ scale=1.0,
+ zero_point=0,
+ layout=ifm_layout,
+ stride_h=ifm_stride_h,
+ stride_w=ifm_stride_w,
+ stride_c=ifm_stride_c,
+ ),
+ ifm2=spec.SerialFeatureMap(
+ data_type=dtype,
+ height=ifm2_shape[1],
+ width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3],
+ channels=ifm2_channels,
+ tile_height_0=ifm2_shape[1],
+ tile_height_1=0,
+ tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else
ifm2_shape[3],
+ tile_address_0=0,
+ tile_address_1=0,
+ tile_address_2=0,
+ tile_address_3=0,
+ scale=1.0,
+ zero_point=0,
+ layout=ifm_layout,
+ stride_h=ifm2_stride_h,
+ stride_w=ifm2_stride_w,
+ stride_c=ifm2_stride_c,
+ ),
+ ofm=spec.SerialFeatureMap(
+ data_type=dtype,
+ height=ofm_height,
+ width=ofm_width,
+ channels=ifm_channels,
+ tile_height_0=ofm_height,
+ tile_height_1=0,
+ tile_width_0=ofm_width,
+ tile_address_0=0,
+ tile_address_1=0,
+ tile_address_2=0,
+ tile_address_3=0,
+ scale=1.0,
+ zero_point=0,
+ layout=ofm_layout,
+ stride_h=ofm_stride_h,
+ stride_w=ofm_stride_w,
+ stride_c=ofm_stride_c,
+ ),
+ operator_type=operator_type,
+ reversed_operands=False,
+ activation=spec.SerialActivation(
+ op=activation,
+ clip_min=0,
+ clip_max=0,
+ ),
+ )
+
+ assert data[0] == ["ethosu_binary_elementwise"] +
list(serial_binary_elementwise)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index f4b83a4..ab1bad2 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -913,5 +913,439 @@ def test_translate_ethosu_pooling():
assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE
+# fmt: off
+"""A ethosu_binary_elementwise ADD tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseAdd:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single flat 270-element input buffer holding both operands.
        placeholder_2 = T.match_buffer(
            placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1
        )
        ethosu_write_2 = T.match_buffer(
            ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1
        )
        # body
        # IFM reads at offset 0, IFM2 at offset 135 (each 5*9*3 = 135 int8
        # elements); OFM is (5, 9, 3); op "ADD", CLIP activation range [10, 100].
        T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, dtype="int8"))

    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""A ethosu_binary_elementwise SUB tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseSub:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single flat 270-element input buffer holding both operands.
        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM at offset 0, IFM2 at offset 135; op "SUB", CLIP activation [10, 100].
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""A ethosu_binary_elementwise MUL tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseMul:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single flat 270-element input buffer holding both operands.
        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM at offset 0, IFM2 at offset 135; op "MUL", CLIP activation [10, 100].
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise MIN tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseMin:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single flat 270-element input buffer holding both operands.
        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM at offset 0, IFM2 at offset 135; op "MIN", CLIP activation [10, 100].
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise Max tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseMax:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Single flat 270-element input buffer holding both operands.
        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM at offset 0, IFM2 at offset 135; op "MAX", CLIP activation [10, 100].
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise SHR tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseShr:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Shift operands are int32; single flat 270-element input buffer.
        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM at offset 0, IFM2 at offset 135; op "SHR" with "NONE" activation (no clipping).
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, dtype="int32"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise SHL tir testcase for the translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseShl:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Shift operands are int32; single flat 270-element input buffer.
        placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM at offset 0, IFM2 at offset 135; op "SHL", CLIP activation [10, 100].
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, dtype="int32"))
    __tvm_meta__ = None
+# fmt: on
+
+
@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"])
def test_translate_ethosu_binary_elementwise(operator_type):
    """Translate a single ethosu_binary_elementwise TIR extern call into an NPU
    op and check every serialised field (IFM, IFM2, OFM, op type, activation).

    Improvements over the original: the two 7-way if/elif chains are replaced
    by a dict lookup / getattr (the parametrized name maps 1:1 onto both the
    testcase module and the vapi.NpuElementWiseOp member), the triplicated
    feature-map assertions are factored into a helper, and the non-idiomatic
    `== False` comparison (E712) is removed.
    """
    # The SHR/SHL testcases operate on int32 data, all others on int8.
    if operator_type in ("SHR", "SHL"):
        data_type = vapi.NpuDataType.INT32
        data_type_bytes = 4
    else:
        data_type = vapi.NpuDataType.INT8
        data_type_bytes = 1

    def extract_ethosu_binary_elementwise_call_extern(mod):
        """Return the single ethosu_binary_elementwise call_extern in *mod*."""
        # There should only be a single function
        assert len(mod.functions.items()) == 1
        primfunc = mod.functions.items()[0][1]

        ethosu_binary_elementwise_calls = []

        def populate_ethosu_binary_elementwise_calls(stmt):
            if (
                isinstance(stmt, tvm.tir.Call)
                and stmt.op.name == "tir.call_extern"
                and stmt.args[0] == "ethosu_binary_elementwise"
            ):
                ethosu_binary_elementwise_calls.append(stmt)

        stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls)
        return ethosu_binary_elementwise_calls[0]

    # Dispatch table instead of a 7-way if/elif chain.
    test_modules = {
        "ADD": SingleEthosuBinaryElementwiseAdd,
        "SUB": SingleEthosuBinaryElementwiseSub,
        "MUL": SingleEthosuBinaryElementwiseMul,
        "MIN": SingleEthosuBinaryElementwiseMin,
        "MAX": SingleEthosuBinaryElementwiseMax,
        "SHR": SingleEthosuBinaryElementwiseShr,
        "SHL": SingleEthosuBinaryElementwiseShl,
    }
    binary_elementwise_call = extract_ethosu_binary_elementwise_call_extern(
        test_modules[operator_type]
    )
    npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call)

    def check_feature_map(fm):
        """IFM, IFM2 and OFM of these testcases all share the same geometry."""
        assert fm.data_type == data_type
        assert fm.shape == vapi.NpuShape3D(5, 9, 3)
        # NpuTileBox(5, 0, 9, [0, 0, 0, 0]) reduces to the scalar checks below.
        assert fm.tiles.height_0 == 5
        assert fm.tiles.height_1 == 0
        assert fm.tiles.width_0 == 9
        assert fm.quantization == vapi.NpuQuantization(1.0, 0)
        assert fm.layout == vapi.NpuLayout.NHWC
        assert fm.strides == vapi.NpuShape3D(
            27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes
        )

    # Compare IFM, IFM2 and OFM
    check_feature_map(npu_op.ifm)
    check_feature_map(npu_op.ifm2)
    check_feature_map(npu_op.ofm)
    # Compare op type: the parametrized name is exactly the enum member name.
    assert npu_op.sub_op_type == getattr(vapi.NpuElementWiseOp, operator_type)
    # Compare reversed_operands
    assert not npu_op.reversed_operands
    # Compare activation: the SHR testcase uses "NONE", so no activation is set.
    if operator_type == "SHR":
        assert npu_op.activation is None
    else:
        assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
        assert npu_op.activation.min == 10
        assert npu_op.activation.max == 100
+
+
+# fmt: off
+"""A ethosu_binary_elementwise ADD with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseAddBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand read from the
        # same buffer at offset 0 with unit strides; reversed_operands flag is 1.
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""A ethosu_binary_elementwise SUB with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseSubBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand with unit
        # strides; reversed_operands flag is 1; op "SUB".
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+# fmt: off
+"""A ethosu_binary_elementwise MUL with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseMulBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand with unit
        # strides; reversed_operands flag is 1; op "MUL".
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise MIN with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseMinBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand with unit
        # strides; reversed_operands flag is 1; op "MIN".
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise MAX with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseMaxBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand with unit
        # strides; reversed_operands flag is 1; op "MAX".
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, dtype="int8"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise SHR with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseShrBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Shift operands are int32.
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand with unit strides;
        # reversed_operands flag is 1; op "SHR" with "NONE" activation (no clipping).
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, dtype="int32"))
    __tvm_meta__ = None
+# fmt: on
+
+
+# fmt: off
+"""A ethosu_binary_elementwise SHL with broadcasting tir testcase for the
translator"""
@tvm.script.ir_module
class SingleEthosuBinaryElementwiseShlBroadcasting:
    @T.prim_func
    def main(placeholder: T.handle, ethosu_write: T.handle) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Shift operands are int32.
        placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1)
        # body
        # IFM is (2, 3, 4); IFM2 is a broadcast (1, 3, 1) operand with unit
        # strides; reversed_operands flag is 1; op "SHL", CLIP [10, 100].
        T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, dtype="int32"))
    __tvm_meta__ = None
+# fmt: on
+
+
@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"])
def test_translate_ethosu_binary_elementwise_broadcasting(operator_type):
    """As test_translate_ethosu_binary_elementwise, but IFM2 is a broadcast
    (1, 3, 1) operand and the operands are reversed.

    Improvements over the original: dict/getattr dispatch instead of two 7-way
    if/elif chains, a parameterised feature-map checker instead of three
    near-identical assertion blocks, and removal of the non-idiomatic
    `== True` comparison (E712).
    """
    # The SHR/SHL testcases operate on int32 data, all others on int8.
    if operator_type in ("SHR", "SHL"):
        data_type = vapi.NpuDataType.INT32
        data_type_bytes = 4
    else:
        data_type = vapi.NpuDataType.INT8
        data_type_bytes = 1

    def extract_ethosu_binary_elementwise_broadcasting_call_extern(mod):
        """Return the single ethosu_binary_elementwise call_extern in *mod*."""
        # There should only be a single function
        assert len(mod.functions.items()) == 1
        primfunc = mod.functions.items()[0][1]

        ethosu_binary_elementwise_calls = []

        def populate_ethosu_binary_elementwise_calls(stmt):
            if (
                isinstance(stmt, tvm.tir.Call)
                and stmt.op.name == "tir.call_extern"
                and stmt.args[0] == "ethosu_binary_elementwise"
            ):
                ethosu_binary_elementwise_calls.append(stmt)

        stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls)
        return ethosu_binary_elementwise_calls[0]

    # Dispatch table instead of a 7-way if/elif chain.
    test_modules = {
        "ADD": SingleEthosuBinaryElementwiseAddBroadcasting,
        "SUB": SingleEthosuBinaryElementwiseSubBroadcasting,
        "MUL": SingleEthosuBinaryElementwiseMulBroadcasting,
        "MIN": SingleEthosuBinaryElementwiseMinBroadcasting,
        "MAX": SingleEthosuBinaryElementwiseMaxBroadcasting,
        "SHR": SingleEthosuBinaryElementwiseShrBroadcasting,
        "SHL": SingleEthosuBinaryElementwiseShlBroadcasting,
    }
    binary_elementwise_call = extract_ethosu_binary_elementwise_broadcasting_call_extern(
        test_modules[operator_type]
    )
    npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call)

    def check_feature_map(fm, shape, strides):
        """Check one feature map against its expected (H, W, C) and strides."""
        assert fm.data_type == data_type
        assert fm.shape == vapi.NpuShape3D(*shape)
        # NpuTileBox(height, 0, width, [0, 0, 0, 0]) reduces to these scalars.
        assert fm.tiles.height_0 == shape[0]
        assert fm.tiles.height_1 == 0
        assert fm.tiles.width_0 == shape[1]
        assert fm.quantization == vapi.NpuQuantization(1.0, 0)
        assert fm.layout == vapi.NpuLayout.NHWC
        assert fm.strides == vapi.NpuShape3D(*strides)

    full_strides = (12 * data_type_bytes, 4 * data_type_bytes, 1 * data_type_bytes)
    broadcast_strides = (1 * data_type_bytes,) * 3
    # Compare IFM, IFM2 (the broadcast operand) and OFM
    check_feature_map(npu_op.ifm, (2, 3, 4), full_strides)
    check_feature_map(npu_op.ifm2, (1, 3, 1), broadcast_strides)
    check_feature_map(npu_op.ofm, (2, 3, 4), full_strides)
    # Compare op type: the parametrized name is exactly the enum member name.
    assert npu_op.sub_op_type == getattr(vapi.NpuElementWiseOp, operator_type)
    # Compare reversed_operands: the broadcast testcases set the flag to 1.
    assert npu_op.reversed_operands
    # Compare activation: the SHR testcase uses "NONE", so no activation is set.
    if operator_type == "SHR":
        assert npu_op.activation is None
    else:
        assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
        assert npu_op.activation.min == 10
        assert npu_op.activation.max == 100
+
+
if __name__ == "__main__":
    # Allow running this test file directly, outside of the pytest CLI.
    pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py
b/tests/python/contrib/test_ethosu/test_type_inference.py
index ecbe31b..e068439 100644
--- a/tests/python/contrib/test_ethosu/test_type_inference.py
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -24,6 +24,7 @@ from tvm.relay.testing import run_opt_pass
from .infra import make_ethosu_conv2d
from .infra import make_ethosu_depthwise_conv2d
from .infra import make_ethosu_pooling
+from .infra import make_ethosu_binary_elementwise
@pytest.mark.parametrize(
@@ -226,5 +227,120 @@ def test_ethosu_pooling_invalid_dtype():
run_opt_pass(func, relay.transform.InferType())
@pytest.mark.parametrize(
    "ifm_shape, ifm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")]
)
@pytest.mark.parametrize(
    "ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")]
)
def test_ethosu_binary_elementwise_type_inference(
    ifm_shape,
    ifm_layout,
    ofm_shape,
    ofm_layout,
):
    """The inferred output type must match the requested OFM layout/shape and dtype."""
    dtype = "int8"
    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype)
    op = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        33,  # ifm_channels
        33,  # ifm2_channels
        "ADD",
        dtype,
        ifm_layout=ifm_layout,
        ifm2_layout=ifm_layout,
        ofm_layout=ofm_layout,
    )
    typed_func = run_opt_pass(relay.Function([ifm, ifm2], op), relay.transform.InferType())
    out_type = typed_func.body.checked_type
    assert tuple(out_type.shape) == ofm_shape
    assert out_type.dtype == dtype
+
+
def test_ethosu_binary_elementwise_invalid_operator_type():
    """Type inference must reject an unknown operator_type string."""
    dtype = "int8"
    shape = [1, 4, 5, 33]
    ifm = relay.var("ifm", shape=shape, dtype=dtype)
    ifm2 = relay.var("ifm2", shape=shape, dtype=dtype)
    op = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        33,  # ifm_channels
        33,  # ifm2_channels
        "A",  # not a valid operator_type
        dtype,
    )
    func = relay.Function([ifm, ifm2], op)
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())
+
+
def test_ethosu_binary_elementwise_invalid_data_types():
    """Type inference must reject mismatched IFM/IFM2 dtypes (int8 vs int32)."""
    shape = [1, 4, 5, 33]
    ifm = relay.var("ifm", shape=shape, dtype="int8")
    ifm2 = relay.var("ifm2", shape=shape, dtype="int32")
    op = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        33,  # ifm_channels
        33,  # ifm2_channels
        "ADD",
        "int8",
    )
    func = relay.Function([ifm, ifm2], op)
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())
+
+
@pytest.mark.parametrize("operator_type", ["MIN", "MAX"])
def test_ethosu_binary_elementwise_min_max_invalid_data_type(operator_type):
    """Type inference must reject MIN/MAX with int32 operands."""
    invalid_dtype = "int32"
    shape = [1, 4, 5, 33]
    ifm = relay.var("ifm", shape=shape, dtype=invalid_dtype)
    ifm2 = relay.var("ifm2", shape=shape, dtype=invalid_dtype)
    op = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        33,  # ifm_channels
        33,  # ifm2_channels
        operator_type,
        invalid_dtype,
    )
    func = relay.Function([ifm, ifm2], op)
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())
+
+
@pytest.mark.parametrize("invalid_dtype", ["int8", "uint8"])
@pytest.mark.parametrize("operator_type", ["SHL", "SHR"])
def test_ethosu_binary_elementwise_shift_invalid_data_type(invalid_dtype, operator_type):
    """Type inference must reject the shift operators with (u)int8 operands.

    Fix: the operator list previously contained "RHS", which is not a valid
    operator_type (the valid set is ADD/SUB/MUL/MIN/MAX/SHR/SHL), so that case
    exercised the invalid-operator rejection path rather than the intended
    shift-with-invalid-dtype path. Both shift ops, SHL and SHR, are now tested.
    """
    ifm_shape = [1, 4, 5, 33]
    ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype)
    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype)
    ifm_channels, ifm2_channels = 33, 33
    binary_elementwise = make_ethosu_binary_elementwise(
        ifm,
        ifm2,
        ifm_channels,
        ifm2_channels,
        operator_type,
        invalid_dtype,
    )
    func = relay.Function([ifm, ifm2], binary_elementwise)
    with pytest.raises(TVMError):
        run_opt_pass(func, relay.transform.InferType())
+
+
if __name__ == "__main__":
    # Allow running this test file directly, outside of the pytest CLI.
    pytest.main([__file__])