This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 815422cfc0 [microNPU] Add support for MEAN with uint8 ifm (#14353)
815422cfc0 is described below

commit 815422cfc04426912b1bc70e4dbac7a6814ea073
Author: Ilya Gozman <[email protected]>
AuthorDate: Thu Apr 13 14:09:59 2023 +0400

    [microNPU] Add support for MEAN with uint8 ifm (#14353)
    
    This PR adds support for the MEAN legalization case where axis == [1, 2], keep_dims == True and the input dtype is 'uint8'.
---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   | 62 +++-------------------
 .../tvm/relay/backend/contrib/ethosu/op/pooling.py | 10 +++-
 .../tvm/relay/backend/contrib/ethosu/te/pooling.py |  6 ++-
 python/tvm/relay/op/contrib/ethosu.py              | 42 ++++++++++-----
 src/relay/op/contrib/ethosu/op_attrs.h             |  5 ++
 src/relay/op/contrib/ethosu/pooling.cc             | 19 +++++--
 .../cascader/test_ethosu_pooling_matcher.py        |  1 +
 tests/python/contrib/test_ethosu/infra.py          |  2 +
 tests/python/contrib/test_ethosu/test_codegen.py   | 31 +++++------
 .../contrib/test_ethosu/test_identity_optimizer.py | 14 +++--
 .../contrib/test_ethosu/test_layout_optimizer.py   | 53 +++++++++++++-----
 tests/python/contrib/test_ethosu/test_legalize.py  | 40 +++++++-------
 .../contrib/test_ethosu/test_replace_pooling.py    | 11 +++-
 .../contrib/test_ethosu/test_type_inference.py     |  9 +++-
 14 files changed, 177 insertions(+), 128 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py 
b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 24dd9afd7b..5aaa1417ae 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -626,6 +626,7 @@ class PoolingRewriter(DFPatternCallback):
             ofm_zero_point=params.ofm.q_params.zero_point,
             pool_shape=params.pool_shape,
             
ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]],
+            ofm_dtype=params.ofm.dtype,
             strides=params.strides,
             padding=params.padding,
             activation=activation,
@@ -975,10 +976,8 @@ class AbsRewriter(UnaryElementwiseRewriter):
 
 class MeanRewriter(DFPatternCallback):
     """Convert ethosu.mean composite functions to an equivalent legalization:
-    - Case 1 (axis == [1, 2] and keepsdims == True):
-        ethosu_depthwise_conv2d + ethosu_binary_elementwise
-    - Case 2 (ifm qparams == ofm qparams): ethosu_pooling
-    - Case 3 (else): ethosu_depthwise_conv2d
+    - Case 1 (ifm qparams == ofm qparams): ethosu_pooling
+    - Case 2 (else): ethosu_depthwise_conv2d
     """
 
     def __init__(self):
@@ -1021,56 +1020,7 @@ class MeanRewriter(DFPatternCallback):
             filter_height = 1
             reduced_op = relay.reshape(reduced_op, ifm_shape)
 
-        if axis == [1, 2] and params.keepdims:
-            weight_scale = 1
-            weight_values = np.ones([out_channels, filter_height, 
filter_width, 1])
-            scale_bias = vela_api.pack_biases(
-                biases=np.zeros(ifm_shape[-1]),
-                ifm_scale=params.ifm.q_params.scale_f32,
-                ifm_dtype=np.dtype(params.ifm.dtype),
-                weight_scales=np.array([weight_scale], dtype=np.float),
-                ofm_scale=params.ofm.q_params.scale_f32,
-                is_activation_tanh_or_sigmoid=False,
-            )
-
-            reduced_op = ethosu_ops.ethosu_depthwise_conv2d(
-                ifm=reduced_op,
-                weight=relay.const(weight_values, params.ifm.dtype),
-                scale_bias=relay.const(scale_bias, "uint8"),
-                lut=lut,
-                ifm_scale=float(params.ifm.q_params.scale_f32),
-                ifm_zero_point=int(params.ifm.q_params.zero_point),
-                weight_zero_point=0,
-                ofm_scale=float(params.ofm.q_params.scale_f32),
-                ofm_zero_point=int(params.ofm.q_params.zero_point),
-                kernel_shape=(filter_height, filter_width),
-                ofm_channels=out_channels,
-                ofm_dtype="int16",
-            )
-
-            n = int(filter_height * filter_width)
-            eps = 1 / (256 * (n + 1)) if n % 2 == 0 else 0
-
-            scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int16"), 
dtype="int16")
-
-            reduced_op = ethosu_ops.ethosu_binary_elementwise(
-                ifm=reduced_op,
-                ifm2=scalar_tensor,
-                lut=lut,
-                operator_type="MUL",
-                ifm_scale=float(params.ofm.q_params.scale_f32),
-                ifm_zero_point=int(params.ofm.q_params.zero_point),
-                ifm2_scale=1 / (n - eps),
-                ifm2_zero_point=0,
-                ofm_scale=float(params.ofm.q_params.scale_f32),
-                ofm_zero_point=int(params.ofm.q_params.zero_point),
-                ifm_channels=out_channels,
-                ifm2_channels=out_channels,
-                reversed_operands=False,
-                ofm_dtype="int8",
-                rounding_mode="NATURAL",
-            )
-        elif (
+        if (
             params.ifm.q_params.scale_f32 == params.ofm.q_params.scale_f32
             and params.ifm.q_params.zero_point == 
params.ofm.q_params.zero_point
         ):
@@ -1084,6 +1034,7 @@ class MeanRewriter(DFPatternCallback):
                 ofm_zero_point=0,
                 pool_shape=(filter_height, filter_width),
                 ofm_channels=out_channels,
+                ofm_dtype=params.ofm.dtype,
                 rounding_mode="TRUNCATE",
             )
         else:
@@ -1112,6 +1063,7 @@ class MeanRewriter(DFPatternCallback):
                 kernel_shape=(filter_height, filter_width),
                 ofm_channels=out_channels,
                 rounding_mode="NATURAL",
+                ofm_dtype=params.ofm.dtype,
             )
 
         # Reshape to original ofm shape
@@ -1168,6 +1120,7 @@ class SumRewriter(DFPatternCallback):
             ofm_zero_point=0,
             pool_shape=(1, 1),
             ofm_channels=1,
+            ofm_dtype="int32",
             activation=activation,
             clip_min=clip_min,
             clip_max=clip_max,
@@ -1319,6 +1272,7 @@ class Resize2dRewriter(DFPatternCallback):
             ofm_zero_point=int(params.ofm.q_params.zero_point),
             pool_shape=pool_shape,
             ofm_channels=in_channels,
+            ofm_dtype=params.ofm.dtype,
             strides=[1, 1],
             padding=padding,
             upscale="NEAREST",
diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py 
b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
index 2d5aff9bec..4d12704acb 100644
--- a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py
@@ -39,6 +39,7 @@ def _extract_ethosu_pooling_params(attrs, args):
     ofm_zero_point = attrs.ofm_zero_point
     pool_shape = attrs.pool_shape
     ofm_channels = attrs.ofm_channels
+    ofm_dtype = attrs.ofm_dtype
     strides = attrs.strides
     padding = attrs.padding
     activation = attrs.activation
@@ -59,6 +60,7 @@ def _extract_ethosu_pooling_params(attrs, args):
         ofm_zero_point,
         pool_shape,
         ofm_channels,
+        ofm_dtype,
         strides,
         padding,
         activation,
@@ -100,6 +102,7 @@ def ethosu_pooling(
     ofm_zero_point: int,
     pool_shape: Tuple[int, int],
     ofm_channels: int,
+    ofm_dtype: str,
     strides: Tuple[int, int] = (1, 1),
     padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
     activation: str = "NONE",
@@ -121,7 +124,7 @@ def ethosu_pooling(
     lut : tvm.relay.Expr
          The look-up table of values to use if activation = "LUT".
     pooling_type: str
-        The type of the pooling. "AVG" - average pool,   "MAX" - max pool.
+        The type of the pooling. "AVG" - average pool, "MAX" - max pool, "SUM" 
- reduce sum pool.
     ifm_scale : float
         The quantization scale for the Input Feature Map tensor.
     ifm_zero_point : int
@@ -134,6 +137,10 @@ def ethosu_pooling(
         The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).
     ofm_channels : int
         The number of the Output Feature Map channels
+    ofm_dtype : str
+        The Output Feature Map tensor data type.
+            "AVG" or "MAX" pooling - can be "int8", "uint8", or "int16".
+            "SUM" pooling - can be "int32".
     strides : tuple of int, optional
         The 2 dimensional strides as (stride_height, stride_width).
     padding : tuple of int, optional
@@ -179,6 +186,7 @@ def ethosu_pooling(
         ofm_zero_point,
         pool_shape,
         ofm_channels,
+        ofm_dtype,
         strides,
         padding,
         activation,
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py 
b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
index 6843046fd0..7308103240 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -36,6 +36,7 @@ def pooling_compute(
     ofm_zero_point: int,
     pool_shape: Tuple[int, int],
     ofm_channels: int,
+    ofm_dtype: str,
     strides: Tuple[int, int],
     padding: Tuple[int, int, int, int],
     activation: str,
@@ -68,6 +69,10 @@ def pooling_compute(
         The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).
     ofm_channels : int
         The number of the Output Feature Map channels
+    ofm_dtype : str
+        The Output Feature Map tensor data type.
+            "AVG" or "MAX" pooling - can be "int8", "uint8", or "int16".
+            "SUM" pooling - can be "int32".
     strides : Tuple[int, int]
         The 2 dimensional strides as (stride_height, stride_width).
     padding : Tuple[int, int, int, int]
@@ -124,7 +129,6 @@ def pooling_compute(
     rh = te.reduce_axis((0, pool_shape_h), name="ry")
     rw = te.reduce_axis((0, pool_shape_w), name="rx")
     rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), 
name="rc")
-    ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"
 
     pooling_attrs = {
         "op": "ethosu_pooling",
diff --git a/python/tvm/relay/op/contrib/ethosu.py 
b/python/tvm/relay/op/contrib/ethosu.py
index d74140da5d..8ec06d3a92 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1336,30 +1336,46 @@ class MeanParams:
                 return axis in ([0], [1], [0, 1])
             return axis in ([1], [2], [1, 2])
 
-        tensor_params = [self.ifm, self.ofm]
-        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+        def check_single_axis_across_height(num_dims, axis):
+            return len(axis) == 1 and (num_dims in (2, 3) and axis == [0] or 
axis == [1])
+
+        same_quantization = (
+            self.ifm.q_params.scale_f32 == self.ofm.q_params.scale_f32
+            and self.ifm.q_params.zero_point == self.ofm.q_params.zero_point
+        )
+
+        # IFM must be int8 or uint8
+        if not check_valid_dtypes([self.ifm], [np.int8, np.uint8]):
             return False
-        if self.ifm.dtype != self.ofm.dtype:
+        # OFM must be int8, uint8 or int16
+        if not check_valid_dtypes([self.ofm], [np.int8, np.uint8, np.int16]):
             return False
+        # Input tensor must be at least 2D
         if not len(self.ifm.shape) in [2, 3, 4]:
             return False
+        # Axis indices must correspond to height and width axes
         if not check_axis(len(self.ifm.shape), self.axis):
             return False
 
-        # MEAN has further restrictions on the input size, depending on 
legalization method.
         input_size = self.height * self.width
+
+        # Product of height and width must be no greater than 65536
         if input_size > 65536:
             return False
-        if (
-            self.ifm.q_params.scale_f32 != self.ofm.q_params.scale_f32
-            or self.ifm.q_params.zero_point != self.ofm.q_params.zero_point
-        ) and input_size > 4096:
-            return False
-        if self.axis == [1, 2] and self.keepdims and self.ifm.dtype == "int8" 
and input_size > 256:
-            return False
-        # Large kernel height reshape only when axis is [1, 2]
-        if self.axis != [1, 2] and self.height > 64:
+        # Product of height and width must be no greater than 4096 when:
+        #   IFM and OFM have different scale or zero point; or
+        #   'keep_dims' is True
+        if input_size > 4096 and (not same_quantization or self.keepdims):
             return False
+        # For single axis averages across the height dimension:
+        if check_single_axis_across_height(len(self.ifm.shape), self.axis):
+            # IFM height must be no greater than 256 if the IFM and OFM scale 
and zero point match
+            if self.height > 256 and same_quantization:
+                return False
+            # IFM height must be no greater than 64 if the IFM and OFM scale 
or zero point
+            # do not match
+            if self.height > 64 and not same_quantization:
+                return False
         return True
 
 
diff --git a/src/relay/op/contrib/ethosu/op_attrs.h 
b/src/relay/op/contrib/ethosu/op_attrs.h
index e4ba2cfb9b..74e7fe856e 100644
--- a/src/relay/op/contrib/ethosu/op_attrs.h
+++ b/src/relay/op/contrib/ethosu/op_attrs.h
@@ -349,6 +349,7 @@ struct EthosuPoolingAttrs : public 
tvm::AttrsNode<EthosuPoolingAttrs> {
   int ofm_zero_point;
   Array<IndexExpr> pool_shape;
   IndexExpr ofm_channels;
+  String ofm_dtype;
   Array<IndexExpr> strides;
   Array<IndexExpr> padding;
   String activation;
@@ -376,6 +377,10 @@ struct EthosuPoolingAttrs : public 
tvm::AttrsNode<EthosuPoolingAttrs> {
     TVM_ATTR_FIELD(ofm_channels)
         .describe(" The number of the Output Feature Map channels.")
         .set_default(NullValue<IndexExpr>());
+    TVM_ATTR_FIELD(ofm_dtype).describe(
+        "The Output Feature Map tensor data type. "
+        "'AVG' or 'MAX' pooling - can be 'int8', 'uint8', or 'int16'. "
+        "'SUM' pooling - can be 'int32'.");
     TVM_ATTR_FIELD(strides)
         .set_default(Array<IndexExpr>({1, 1}))
         .describe("The 2 dimensional strides as (stride_height, 
stride_width).");
diff --git a/src/relay/op/contrib/ethosu/pooling.cc 
b/src/relay/op/contrib/ethosu/pooling.cc
index a9c072a011..92e704f667 100644
--- a/src/relay/op/contrib/ethosu/pooling.cc
+++ b/src/relay/op/contrib/ethosu/pooling.cc
@@ -61,15 +61,27 @@ bool EthosuPoolingRel(const Array<Type>& types, int 
num_inputs, const Attrs& att
                                                             DataType::Int(16), 
DataType::Int(32)};
 
   std::initializer_list<DataType>& allowed_ifm_dtypes = 
max_avg_pooling_ifm_dtypes;
-  auto ofm_dtype = ifm->dtype;
   if (param->pooling_type == "SUM") {
     allowed_ifm_dtypes = sum_pooling_ifm_dtypes;
-    ofm_dtype = DataType::Int(32);
   }
 
   CheckDataType(reporter, ifm->dtype, allowed_ifm_dtypes, operator_name, "ifm",
                 param->pooling_type);
 
+  DataType ofm_dtype = DataTypeFromString(param->ofm_dtype);
+
+  std::initializer_list<DataType> max_avg_pooling_ofm_dtypes = 
{DataType::Int(8), DataType::UInt(8),
+                                                                
DataType::Int(16)};
+  if (param->pooling_type == "AVG" || param->pooling_type == "MAX") {
+    CheckDataType(reporter, ofm_dtype, max_avg_pooling_ofm_dtypes, 
operator_name, "ofm",
+                  param->pooling_type);
+    CheckDataTypeMatch(reporter, ofm_dtype, ifm->dtype, operator_name, "ifm", 
"ofm",
+                       param->pooling_type);
+  } else {
+    CheckDataType(reporter, ofm_dtype, {DataType::Int(32)}, operator_name, 
"ofm",
+                  param->pooling_type);
+  }
+
   CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, 
operator_name);
 
   Array<IndexExpr> ifm_shape = ifm->shape;
@@ -88,7 +100,7 @@ bool EthosuPoolingRel(const Array<Type>& types, int 
num_inputs, const Attrs& att
 
 Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double 
ifm_scale,
                        int ifm_zero_point, double ofm_scale, int 
ofm_zero_point,
-                       Array<IndexExpr> pool_shape, IndexExpr ofm_channels,
+                       Array<IndexExpr> pool_shape, IndexExpr ofm_channels, 
String ofm_dtype,
                        Array<IndexExpr> strides, Array<IndexExpr> padding, 
String activation,
                        int clip_min, int clip_max, String rounding_mode, 
String upscale,
                        String ifm_layout, String ofm_layout) {
@@ -100,6 +112,7 @@ Expr MakeEthosuPooling(Expr ifm, Expr lut, String 
pooling_type, double ifm_scale
   attrs->ofm_zero_point = ofm_zero_point;
   attrs->pool_shape = std::move(pool_shape);
   attrs->ofm_channels = std::move(ofm_channels);
+  attrs->ofm_dtype = std::move(ofm_dtype);
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->activation = std::move(activation);
diff --git 
a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py 
b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py
index 38aeee05f9..1faec87ba2 100644
--- a/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py
+++ b/tests/python/contrib/test_ethosu/cascader/test_ethosu_pooling_matcher.py
@@ -49,6 +49,7 @@ def test_ethosu_pooling_matcher(pool_shape, stride, padding, 
ifm_layout, ofm_lay
         ofm_zero_point=0,
         pool_shape=pool_shape,
         ofm_channels=ofm_channels,
+        ofm_dtype="int8",
         strides=stride,
         padding=padding,
         activation="NONE",
diff --git a/tests/python/contrib/test_ethosu/infra.py 
b/tests/python/contrib/test_ethosu/infra.py
index b205a6d335..c621155827 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -639,6 +639,7 @@ def make_ethosu_pooling(
     pooling_type,
     pool_shape,
     ofm_channels,
+    ofm_dtype,
     strides,
     padding,
     activation="NONE",
@@ -657,6 +658,7 @@ def make_ethosu_pooling(
         ofm_zero_point=0,
         pool_shape=pool_shape,
         ofm_channels=ofm_channels,
+        ofm_dtype=ofm_dtype,
         strides=strides,
         padding=padding,
         activation=activation,
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py 
b/tests/python/contrib/test_ethosu/test_codegen.py
index 1df9e88914..14441d8e93 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -391,31 +391,29 @@ def test_binary_add_with_non_4d_shapes(
     )
 
 
-@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/12634")
 @pytest.mark.parametrize(
     "accel_type",
     ACCEL_TYPES,
 )
 @pytest.mark.parametrize(
-    "ifm_shape, axis, keep_dims, use_same_quantization",
+    "ifm_shape, axis, keep_dims, use_same_quantization, dtype",
     [
-        # mean to depthwise + multiply
-        [(1, 8, 16, 16), (1, 2), True, False],
-        [(1, 3, 4), (0, 1), True, False],
-        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
         # mean to average pool
-        [(1, 8, 16, 16), (2,), False, True],
-        [(3, 3, 4), (0,), True, True],
-        [(8, 5), (0,), False, True],
+        [(1, 8, 16, 16), (2,), False, True, "int8"],
+        [(1, 8, 16, 16), (2,), False, True, "uint8"],
+        [(3, 3, 4), (0,), True, True, "int8"],
+        [(8, 5), (0,), False, True, "int8"],
         # mean to depthwise
-        [(1, 8, 16, 16), (2,), True, False],
-        [(1, 8, 16, 16), (2, 1), False, False],
-        [(8, 4), (0,), False, False],
+        [(1, 8, 16, 16), (2,), True, False, "int8"],
+        [(1, 8, 16, 16), (2,), True, False, "uint8"],
+        [(1, 8, 16, 16), (2, 1), False, False, "int8"],
+        [(8, 4), (0,), False, False, "int8"],
+        [(1, 65, 2, 1), (1, 2), True, False, "int8"],  # special case when h > 
64
+        [(1, 65, 2, 1), (1, 2), True, False, "uint8"],  # special case when h 
> 64
     ],
 )
-def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization):
+def test_mean(accel_type, ifm_shape, axis, keep_dims, use_same_quantization, 
dtype):
     np.random.seed(0)
-    dtype = "int8"
 
     def create_mod_from_tflite():
         class Model(tf.Module):
@@ -462,12 +460,14 @@ def test_mean(accel_type, ifm_shape, axis, keep_dims, 
use_same_quantization):
             input_zero_point=relay.const(0, dtype="int32"),
             output_scale=relay.const(1.0, dtype="float32"),
             output_zero_point=relay.const(0, dtype="int32"),
+            out_dtype=dtype,
         )
 
         func = relay.Function(relay.analysis.free_vars(requantize), requantize)
         mod = tvm.IRModule.from_expr(func)
 
-        input_data = {"input": np.random.randint(low=-127, high=128, 
size=ifm_shape, dtype=dtype)}
+        low, high = (0, 256) if dtype == "uint8" else (-127, 128)
+        input_data = {"input": np.random.randint(low=low, high=high, 
size=ifm_shape, dtype=dtype)}
         output_data = generate_ref_data(mod, input_data)
         return mod, input_data, output_data
 
@@ -546,6 +546,7 @@ def test_add_reduce_sum(dtype):
             pooling_type="SUM",
             pool_shape=(1, 1),
             ofm_channels=1,
+            ofm_dtype="int32",
             strides=(1, 1),
             padding=(0, 0, 0, 0),
             rounding_mode="NATURAL",
diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py 
b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
index f90f0f2e62..3ae58dfc81 100644
--- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
@@ -78,12 +78,14 @@ def test_simple_strided_slice_identity_removal():
     in the graph and a compute operation follows."""
 
     def get_graph(get_expected=False):
-        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
-        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, (1, 1), (0, 0))
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype=dtype)
+        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, dtype, (1, 1), (0, 
0))
         x = relay.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 2, 2, 2])
         if not get_expected:
             x = infra.make_ethosu_identity(x)
-        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 2, (1, 1), (0, 0))
+        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 2, dtype, (1, 1), (0, 
0))
         return relay.Function(relay.analysis.free_vars(x), x)
 
     actual = _optimize(get_graph())
@@ -95,9 +97,11 @@ def test_no_identity():
     """Check the graph is not affected when there is no identity in the 
graph."""
 
     def get_graph():
-        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 2, 2, 4), dtype=dtype)
         x = infra.make_ethosu_conv2d(x, 4, 4, (1, 1), (0, 0), (1, 1), (1, 1))
-        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, (1, 1), (0, 0))
+        x = infra.make_ethosu_pooling(x, "MAX", (1, 1), 4, dtype, (1, 1), (0, 
0))
         x = infra.make_ethosu_depthwise_conv2d(x, 4, (1, 1), (0, 0), (1, 1), 
(1, 1))
         x = infra.make_ethosu_unary_elementwise(x, 4, "ABS")
         return relay.Function(relay.analysis.free_vars(x), x)
diff --git a/tests/python/contrib/test_ethosu/test_layout_optimizer.py 
b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
index 05b9dce4c9..69d549acbb 100644
--- a/tests/python/contrib/test_ethosu/test_layout_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_layout_optimizer.py
@@ -147,6 +147,7 @@ def test_add_reduce_sum(dtype):
             pooling_type="SUM",
             pool_shape=(1, 1),
             ofm_channels=1,
+            ofm_dtype="int32",
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout=layout,
@@ -330,13 +331,16 @@ def test_ignore_concatnate_with_layout_transform():
     """
 
     def get_graph():
-        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype="int8")
-        in_2 = relay.var("y", shape=(1, 16, 16, 8), dtype="int8")
+        dtype = "int8"
+
+        in_1 = relay.var("x", shape=(1, 16, 16, 8), dtype=dtype)
+        in_2 = relay.var("y", shape=(1, 16, 16, 8), dtype=dtype)
         pool_1 = infra.make_ethosu_pooling(
             in_1,
             "MAX",
             (1, 1),
             ofm_channels=8,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -347,6 +351,7 @@ def test_ignore_concatnate_with_layout_transform():
             "MAX",
             (1, 1),
             ofm_channels=8,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -358,6 +363,7 @@ def test_ignore_concatnate_with_layout_transform():
             "MAX",
             (1, 1),
             ofm_channels=8,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -385,12 +391,15 @@ def test_multiple_inputs():
     def get_graph():
         poolings = []
         for _ in range(3):
-            inp = relay.var("x", shape=(1, 3, 3, 4), dtype="int8")
+            dtype = "int8"
+
+            inp = relay.var("x", shape=(1, 3, 3, 4), dtype=dtype)
             pool = infra.make_ethosu_pooling(
                 inp,
                 "MAX",
                 (1, 1),
                 ofm_channels=4,
+                ofm_dtype=dtype,
                 strides=(1, 1),
                 padding=(0, 0),
                 ifm_layout="NHWC",
@@ -428,12 +437,15 @@ def test_multiple_outputs():
     """
 
     def get_graph(get_expected=False):
-        in_1 = relay.var("x", shape=(1, 4, 4, 8), dtype="int8")
+        dtype = "int8"
+
+        in_1 = relay.var("x", shape=(1, 4, 4, 8), dtype=dtype)
         pool_1 = infra.make_ethosu_pooling(
             in_1,
             "MAX",
             (1, 1),
             ofm_channels=4,
+            ofm_dtype=dtype,
             strides=(1, 1),
             padding=(0, 0),
             ifm_layout="NHWC",
@@ -447,6 +459,7 @@ def test_multiple_outputs():
                     "MAX",
                     (1, 1),
                     ofm_channels=4,
+                    ofm_dtype=dtype,
                     strides=(1, 1),
                     padding=(0, 0),
                     ifm_layout="NHCWB16" if get_expected else "NHWC",
@@ -527,7 +540,9 @@ def test_multiple_pooling():
     """
 
     def get_graph(get_expected=False):
-        x = relay.var("x", shape=(1, 8, 8, 4), dtype="int8")
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 8, 8, 4), dtype=dtype)
         for i in range(3):
             ifm_layout = "NHCWB16" if get_expected and i != 0 else "NHWC"
             ofm_layout = "NHCWB16" if get_expected and i != 2 else "NHWC"
@@ -536,6 +551,7 @@ def test_multiple_pooling():
                 "MAX",
                 (1, 1),
                 ofm_channels=4,
+                ofm_dtype=dtype,
                 strides=(1, 1),
                 padding=(0, 0),
                 ifm_layout=ifm_layout,
@@ -594,8 +610,9 @@ def test_op_without_ethosu_consumer():
 
     def get_graph(get_expected=False):
         exp_layout = "NHCWB16" if get_expected else "NHWC"
+        dtype = "int8"
 
-        x = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
+        x = relay.var("x", shape=(1, 2, 2, 2), dtype=dtype)
         depthwise = infra.make_ethosu_depthwise_conv2d(
             x, 2, (1, 1), (0, 0), (1, 1), (0, 0), ofm_layout=exp_layout
         )
@@ -609,7 +626,7 @@ def test_op_without_ethosu_consumer():
             (0, 0),
             ifm_layout=exp_layout,
         )
-        pool = infra.make_ethosu_pooling(conv, "MAX", (1, 1), 2, (1, 1), (0, 
0))
+        pool = infra.make_ethosu_pooling(conv, "MAX", (1, 1), 2, dtype, (1, 
1), (0, 0))
         concat = relay.concatenate([conv, pool], axis=0)
         return relay.Function(relay.analysis.free_vars(concat), concat)
 
@@ -639,21 +656,31 @@ def test_diamond_graph():
 
     def get_graph(get_expected=False):
         exp_layout = "NHCWB16" if get_expected else "NHWC"
-        x = relay.var("x", shape=(1, 2, 2, 2), dtype="int8")
+        dtype = "int8"
+
+        x = relay.var("x", shape=(1, 2, 2, 2), dtype=dtype)
         pool_1 = infra.make_ethosu_pooling(
-            x, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=exp_layout
+            x, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), ofm_layout=exp_layout
         )
         pool_2 = infra.make_ethosu_pooling(
-            pool_1, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout
+            pool_1, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), 
ifm_layout=exp_layout
         )
         pool_3 = infra.make_ethosu_pooling(
-            pool_2, "MAX", (1, 1), 2, (1, 1), (0, 0), ofm_layout=exp_layout
+            pool_2, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), 
ofm_layout=exp_layout
         )
         pool_4 = infra.make_ethosu_pooling(
-            pool_3, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout, 
ofm_layout=exp_layout
+            pool_3,
+            "MAX",
+            (1, 1),
+            2,
+            dtype,
+            (1, 1),
+            (0, 0),
+            ifm_layout=exp_layout,
+            ofm_layout=exp_layout,
         )
         pool_5 = infra.make_ethosu_pooling(
-            pool_4, "MAX", (1, 1), 2, (1, 1), (0, 0), ifm_layout=exp_layout
+            pool_4, "MAX", (1, 1), 2, dtype, (1, 1), (0, 0), 
ifm_layout=exp_layout
         )
         concat = relay.concatenate([pool_2, pool_5], axis=0)
         return relay.Function(relay.analysis.free_vars(concat), concat)
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py 
b/tests/python/contrib/test_ethosu/test_legalize.py
index 594f4a0e2a..6330930fa5 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1535,15 +1535,10 @@ def test_tflite_tanh_legalize():
     assert tuple(func_body.args[1].checked_type.shape) == (256,)
 
 
+@pytest.mark.parametrize("dtype", ["int8", "uint8"])
 @pytest.mark.parametrize(
     "ifm_shape, axis, keep_dims, use_same_quantization",
     [
-        # mean to depthwise + multiply
-        [(1, 8, 16, 16), (1, 2), True, False],
-        [(1, 8, 16, 16), (2, 1), True, False],
-        [(1, 3, 4), (0, 1), True, False],
-        [(8, 5), (1, 0), True, False],
-        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
         # mean to average pool
         [(1, 8, 16, 16), (1,), True, True],
         [(1, 8, 16, 16), (2,), False, True],
@@ -1557,11 +1552,10 @@ def test_tflite_tanh_legalize():
         [(1, 8, 16, 16), (2,), True, False],
         [(1, 8, 16, 16), (1, 2), False, False],
         [(8, 4), (0,), False, False],
+        [(1, 65, 2, 1), (1, 2), True, False],  # special case when h > 64
     ],
 )
-def test_mean(ifm_shape, axis, keep_dims, use_same_quantization):
-    dtype = "int8"
-
+def test_mean(ifm_shape, axis, keep_dims, use_same_quantization, dtype):
     def create_tflite_graph():
         class Model(tf.Module):
             @tf.function
@@ -1606,6 +1600,7 @@ def test_mean(ifm_shape, axis, keep_dims, 
use_same_quantization):
             input_zero_point=relay.const(0, dtype="int32"),
             output_scale=relay.const(1.0, dtype="float32"),
             output_zero_point=relay.const(0, dtype="int32"),
+            out_dtype=dtype,
         )
 
         func = relay.Function(relay.analysis.free_vars(requantize), requantize)
@@ -1616,7 +1611,6 @@ def test_mean(ifm_shape, axis, keep_dims, 
use_same_quantization):
         out_var = ext_func.body
 
         next_op = out_var
-        mul_op = None
         pooling_op = None
         depthwise_op = None
         if (
@@ -1625,9 +1619,6 @@ def test_mean(ifm_shape, axis, keep_dims, 
use_same_quantization):
             and next_op.op.name == "reshape"
         ):
             next_op = next_op.args[0]
-        if util.is_named_ethosu_op(next_op, "binary_elementwise"):
-            mul_op = next_op
-            next_op = next_op.args[0]
         if util.is_named_ethosu_op(next_op, "pooling"):
             pooling_op = next_op
             next_op = next_op.args[0]
@@ -1654,24 +1645,33 @@ def test_mean(ifm_shape, axis, keep_dims, 
use_same_quantization):
 
         # check IFM
         assert tuple(in_var.checked_type.shape) == ifm_shape
-        assert in_var.checked_type.dtype == dtype
+
+        if use_same_quantization:
+            assert in_var.checked_type.dtype == dtype
+        else:
+            # in_var's dtype is equal to int8 due to TFLite's requantize
+            assert in_var.checked_type.dtype == "int8"
 
         # check OFM
         assert tuple(out_var.checked_type.shape) == out_shape
-        assert out_var.checked_type.dtype == dtype
+        if use_same_quantization:
+            assert out_var.checked_type.dtype == dtype
+        else:
+            # out_var's dtype is equal to int8 due to TFLite's requantize
+            assert out_var.checked_type.dtype == "int8"
 
         # check expected legalization case
-        if axis in [(1, 2), (2, 1), (0, 1), (1, 0)] and keep_dims and dtype == 
"int8":
-            assert depthwise_op and mul_op
-            assert mul_op.attrs.operator_type == "MUL"
-        elif pooling_op:
+        if pooling_op:
             attrs = pooling_op.attrs
             assert (
                 attrs.ifm_scale == attrs.ofm_scale and attrs.ifm_zero_point == 
attrs.ofm_zero_point
             )
         else:
             assert depthwise_op
-            assert not mul_op
+            attrs = depthwise_op.attrs
+            assert (
+                attrs.ifm_scale != attrs.ofm_scale or attrs.ifm_zero_point != 
attrs.ofm_zero_point
+            )
 
     rewriter = legalize.MeanRewriter()
     pattern_table = [
diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py 
b/tests/python/contrib/test_ethosu/test_replace_pooling.py
index 1ef59e0b9b..e4438eb62a 100644
--- a/tests/python/contrib/test_ethosu/test_replace_pooling.py
+++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py
@@ -169,12 +169,15 @@ def test_avg_max_pooling_single(
     # hardcoded padding values are used for each case.
     padding = (1, 1, 1, 0) if upscale == "NONE" else (0, 0, 0, 0)
 
-    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    dtype = "int8"
+
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
     pooling = make_ethosu_pooling(
         ifm,
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         activation,
@@ -232,6 +235,7 @@ def test_sum_pooling_single(
         pooling_type="SUM",
         pool_shape=(1, 1),
         ofm_channels=1,
+        ofm_dtype="int32",
         strides=(1, 1),
         padding=(0, 0, 0, 0),
         activation=activation,
@@ -276,13 +280,15 @@ def test_correct_stride_with_multiple_pooling():
     pool_shape = (1, 1)
     strides = (1, 1)
     padding = (0, 0, 0, 0)
+    dtype = "int8"
 
-    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
     op = make_ethosu_pooling(
         ifm,
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         ifm_layout="NHWC",
@@ -293,6 +299,7 @@ def test_correct_stride_with_multiple_pooling():
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         ifm_layout="NHCWB16",
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py 
b/tests/python/contrib/test_ethosu/test_type_inference.py
index 380d7532f8..48a4dbde81 100644
--- a/tests/python/contrib/test_ethosu/test_type_inference.py
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -201,6 +201,7 @@ def test_ethosu_pooling_type_inference(
         pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
         ifm_layout=ifm_layout,
@@ -215,6 +216,7 @@ def test_ethosu_pooling_type_inference(
 def test_ethosu_pooling_invalid_pooling_type():
     invalid_pooling_type = "A"
     dtype = "int8"
+
     ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=dtype)
     pool_shape = (3, 2)
     ofm_channels = 55
@@ -225,6 +227,7 @@ def test_ethosu_pooling_invalid_pooling_type():
         invalid_pooling_type,
         pool_shape,
         ofm_channels,
+        dtype,
         strides,
         padding,
     )
@@ -246,6 +249,7 @@ def test_ethosu_pooling_invalid_dtype():
         pooling_type,
         pool_shape,
         ofm_channels,
+        "int8",
         strides,
         padding,
     )
@@ -256,12 +260,15 @@ def test_ethosu_pooling_invalid_dtype():
 
 def test_ethosu_pooling_invalid_upscale_method():
     invalid_upscale_method = "FOO"
-    ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype="int8")
+    dtype = "int8"
+
+    ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=dtype)
     pooling = make_ethosu_pooling(
         ifm,
         "MAX",
         (3, 2),
         55,
+        dtype,
         (1, 2),
         (0, 1, 2, 3),
         upscale=invalid_upscale_method,


Reply via email to