This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1566fb1  [microNPU] Update Conv2D Tests to Use TF API to Gen Test 
Cases (#9508)
1566fb1 is described below

commit 1566fb161446cd80552a29b100896175f8272ac0
Author: Dhruv Chauhan <[email protected]>
AuthorDate: Thu Dec 9 13:28:26 2021 +0000

    [microNPU] Update Conv2D Tests to Use TF API to Gen Test Cases (#9508)
    
    * Current conv2d tests compare the conv2d operator against TVM's execution 
of the default TOPI conv2d schedule, which is not bit-exact with the TFLite 
runtime's implementation. Therefore a tolerance of "1" in the quantized 
8-bit domain is used.
    
    * Converts the current conv2d tests to use TensorFlow APIs to create test 
cases for conv2D and compare against the TFLite runtime.
---
 .../tvm/relay/backend/contrib/ethosu/__init__.py   |   1 -
 python/tvm/relay/backend/contrib/ethosu/errors.py  |  35 ---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   |   3 -
 .../python/contrib/test_ethosu/relay_ir_builder.py | 295 --------------------
 tests/python/contrib/test_ethosu/test_codegen.py   | 297 +++++++++++++--------
 tests/python/contrib/test_ethosu/test_legalize.py  | 226 ++++++++--------
 6 files changed, 295 insertions(+), 562 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py 
b/python/tvm/relay/backend/contrib/ethosu/__init__.py
index ed04c20..c4948d5 100644
--- a/python/tvm/relay/backend/contrib/ethosu/__init__.py
+++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py
@@ -18,7 +18,6 @@
 from . import util
 from . import legalize
 from . import preprocess
-from . import errors
 from . import codegen
 from . import vela_api
 from . import tir_to_cs_translator
diff --git a/python/tvm/relay/backend/contrib/ethosu/errors.py 
b/python/tvm/relay/backend/contrib/ethosu/errors.py
deleted file mode 100644
index 65f3711..0000000
--- a/python/tvm/relay/backend/contrib/ethosu/errors.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=super-init-not-called
-"""This module defines all error types associated with the Arm(R) Ethos(TM)-U 
NPU code generator."""
-
-
-class EthosUCodegenError(Exception):
-    """Base class for all exceptions related to code generation"""
-
-    def __init__(self, data):
-        self.message = "EthosUCodegenError:" + data
-
-    def __str__(self):
-        return self.message
-
-
-class UnsupportedLayout(EthosUCodegenError):
-    """Raised when unsupported layout is encountered during code generation."""
-
-    def __init__(self, layout):
-        super().__init__(f"Unsupported Layout {layout}")
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py 
b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index b2264f3..0db8db9 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -30,7 +30,6 @@ from tvm.relay.dataflow_pattern import is_op
 from tvm.relay.dataflow_pattern import rewrite
 from tvm.relay.dataflow_pattern import CallPattern
 from tvm.relay.backend.contrib.ethosu import op as ethosu_ops  # type: ignore
-from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout  # type: 
ignore
 from tvm.relay.backend.contrib.ethosu import vela_api
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.op.contrib import ethosu as ethosu_patterns  # type: ignore
@@ -266,8 +265,6 @@ class Conv2DRewriter(DFPatternCallback):
         channels_map = {
             "NHWC": 3,
         }
-        if str(params.ofm.layout) not in channels_map.keys():
-            raise UnsupportedLayout(str(params.ofm.layout))
         kernel_size_map = {
             "HWIO": params.weights.shape[0:2],
             "OHWI": params.weights.shape[1:3],
diff --git a/tests/python/contrib/test_ethosu/relay_ir_builder.py 
b/tests/python/contrib/test_ethosu/relay_ir_builder.py
deleted file mode 100644
index 6169a3e..0000000
--- a/tests/python/contrib/test_ethosu/relay_ir_builder.py
+++ /dev/null
@@ -1,295 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Helper module to build relay operations for testing"""
-
-from pathlib import Path
-import numpy as np
-import math
-
-import tvm
-from tvm import relay
-from tvm.relay.op.contrib import get_pattern_table
-from tvm.relay import qnn
-from tvm.relay.backend.contrib.ethosu.util import get_range_for_dtype_str
-
-
-class TensorType:
-    """A data structure to capture tensor parameters"""
-
-    def __init__(self):
-        self.shape = None
-        self.dtype = None
-        self.zp = None
-        self.sc = None
-        self.layout = None
-
-    def get_dim_size(self, dim):
-        for idx, char in enumerate(self.layout):
-            if dim == char:
-                return self.shape[idx]
-        return None
-
-    def get_dim_index(self, dim):
-        for idx, char in enumerate(self.layout):
-            if dim == char:
-                return idx
-        return None
-
-
-class QnnConv2DParams:
-    """A data structure to capture relay.qnn.op.conv2D parameters"""
-
-    def __init__(self, dtype):
-        self.ifm = TensorType()
-        self.ofm = TensorType()
-        self.kernel = TensorType()
-
-        # default values
-        self.ifm.dtype = dtype
-        self.ifm.layout = "NHWC"
-        ifm_min, ifm_max = get_range_for_dtype_str(self.ifm.dtype)
-        self.ifm.zp = relay.const(np.random.randint(ifm_min, ifm_max), "int32")
-        self.ifm.sc = relay.const(np.random.random() * 2, "float32")
-        self.kernel.dtype = dtype
-        self.kernel.layout = "HWIO"
-        kernel_min, kernel_max = get_range_for_dtype_str(self.kernel.dtype)
-        self.kernel.zp = relay.const(np.random.randint(kernel_min, 
kernel_max), "int32")
-        self.kernel.sc = relay.const(np.random.random() * 2, "float32")
-        self.ofm.layout = "NHWC"
-        self.ofm.dtype = dtype
-        ofm_min, ofm_max = get_range_for_dtype_str(self.ofm.dtype)
-        self.ofm.zp = relay.const(np.random.randint(ofm_min, ofm_max), "int32")
-        self.ofm.sc = relay.const(np.random.random() * 2, "float32")
-        self.dilation = (1, 1)
-
-        self.strides = None
-        self.pad = None
-        self.activation = "NONE"
-        self.clip_min = 0
-        self.clip_max = 0
-
-    def update_output_qnn_params(
-        self, input_dtype="uint8", kernel_dtype="uint8", output_dtype="uint8"
-    ):
-        _, dtype_max = get_range_for_dtype_str(input_dtype)
-        input_max = self.ifm.sc.data.asnumpy() * (dtype_max - 
self.ifm.zp.data.asnumpy())
-        input_min = -self.ifm.sc.data.asnumpy() * self.ifm.zp.data.asnumpy()
-        _, dtype_max = get_range_for_dtype_str(kernel_dtype)
-        kernel_max = np.max(
-            self.kernel.sc.data.asnumpy() * (dtype_max - 
self.kernel.zp.data.asnumpy())
-        )
-        kernel_min = np.min(-self.kernel.sc.data.asnumpy() * 
self.kernel.zp.data.asnumpy())
-        kernel_h = self.kernel.get_dim_size("H")
-        kernel_w = self.kernel.get_dim_size("W")
-        channels = self.kernel.get_dim_size("I")
-        output_limits = [
-            kernel_max * kernel_h * kernel_w * channels * input_max,
-            kernel_min * kernel_h * kernel_w * channels * input_max,
-            kernel_min * kernel_h * kernel_w * channels * input_min,
-            kernel_max * kernel_h * kernel_w * channels * input_min,
-        ]
-        output_max = max(output_limits)
-        output_min = min(output_limits)
-        dtype_min, dtype_max = get_range_for_dtype_str(input_dtype)
-        self.ofm.sc = relay.const((output_max - output_min) / (dtype_max - 
dtype_min), "float32")
-        self.ofm.zp = relay.const(-int(output_min / 
self.ofm.sc.data.asnumpy()), "int32")
-
-
-class PoolingParams:
-    """A data structure to capture relay.op.max_pool2d /
-    relay.op.avg_pool2d parameters
-    """
-
-    def __init__(self, dtype):
-        self.type = None
-        self.size = None
-        self.strides = None
-        self.pad = None
-        self.layout = None
-        self.ifm = TensorType()
-        self.ofm = TensorType()
-
-        # default values
-        self.ifm.dtype = dtype
-        self.ifm.layout = "NHWC"
-        self.ifm.zp = relay.const(np.random.randint(0, 255), "int32")
-        self.ifm.sc = relay.const(np.random.random() * 2, "float32")
-        self.ofm.zp = relay.const(np.random.randint(0, 255), "int32")
-        self.ofm.sc = relay.const(np.random.random() * 2, "float32")
-        self.ofm.dtype = dtype
-        self.dilation = (1, 1)
-
-
-class AddParams:
-    """A data structure to capture relay.qnn.op.add parameters"""
-
-    def __init__(self, dtype):
-        self.ifm0 = TensorType()
-        self.ifm1 = TensorType()
-        self.ofm = TensorType()
-
-        # default values
-        self.ifm0.dtype = dtype
-        self.ifm0.zp = relay.const(np.random.randint(0, 255), "int32")
-        self.ifm0.sc = relay.const(np.random.random() * 2, "float32")
-        self.ifm1.dtype = dtype
-        self.ifm1.zp = relay.const(np.random.randint(0, 255), "int32")
-        self.ifm1.sc = relay.const(np.random.random() * 2, "float32")
-        self.update_output_qnn_params()
-        self.ofm.dtype = dtype
-
-    def update_output_qnn_params(self):
-        ti = np.iinfo(self.ifm0.dtype)
-        dtype_min, dtype_max = int(ti.min), int(ti.max)
-        input1_max = self.ifm0.sc.data.asnumpy() * (dtype_max - 
self.ifm0.zp.data.asnumpy())
-        input1_min = (dtype_min - self.ifm0.sc.data.asnumpy()) * 
self.ifm0.zp.data.asnumpy()
-        input2_max = self.ifm1.sc.data.asnumpy() * (dtype_max - 
self.ifm1.zp.data.asnumpy())
-        input2_min = (dtype_min - self.ifm1.sc.data.asnumpy()) * 
self.ifm1.zp.data.asnumpy()
-        output_max = input1_max + input2_max
-        output_min = input1_min + input2_min
-        self.ofm.sc = relay.const((output_max - output_min) / dtype_max, 
"float32")
-        self.ofm.zp = relay.const(
-            (dtype_min - int(output_min / self.ofm.sc.data.asnumpy())), "int32"
-        )
-
-
-def get_pad_value(data, kernel, stride):
-    """Get the pad tuple of value for SAME padding"""
-
-    out = int(math.ceil(float(data) / float(stride)))
-    pad = max(0, (out - 1) * stride + kernel - data)
-    pad_before = pad // 2
-    pad_after = pad - pad_before
-    return pad_before, pad_after
-
-
-def create_qnn_conv2d(qnn_conv2d_params, ifm_expr):
-    """Create a relay.Expr of relay.qnn.conv2D given the parameters"""
-    v_params = list()
-    params = {
-        "kernel_size": [
-            qnn_conv2d_params.kernel.get_dim_size("H"),
-            qnn_conv2d_params.kernel.get_dim_size("W"),
-        ],
-        "strides": [qnn_conv2d_params.strides[0], 
qnn_conv2d_params.strides[1]],
-        "dilation": [qnn_conv2d_params.dilation[0], 
qnn_conv2d_params.dilation[1]],
-        "padding": [0, 0, 0, 0],
-        "data_layout": qnn_conv2d_params.ifm.layout,
-    }
-    dilated_kernel_h = (
-        qnn_conv2d_params.dilation[0] * 
(qnn_conv2d_params.kernel.get_dim_size("H") - 1) + 1
-    )
-    dilated_kernel_w = (
-        qnn_conv2d_params.dilation[1] * 
(qnn_conv2d_params.kernel.get_dim_size("W") - 1) + 1
-    )
-    if qnn_conv2d_params.pad == "SAME":
-        pad_top, pad_bottom = get_pad_value(
-            qnn_conv2d_params.ifm.get_dim_size("H"), dilated_kernel_h, 
qnn_conv2d_params.strides[0]
-        )
-        pad_left, pad_right = get_pad_value(
-            qnn_conv2d_params.ifm.get_dim_size("W"), dilated_kernel_w, 
qnn_conv2d_params.strides[1]
-        )
-        do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and 
pad_right == 0)
-        if do_pad:
-            params["padding"] = [pad_top, pad_left, pad_bottom, pad_right]
-    qnn_conv2d_params.pad = params["padding"]
-    params["input_zero_point"] = qnn_conv2d_params.ifm.zp
-    params["kernel_zero_point"] = qnn_conv2d_params.kernel.zp
-    params["out_dtype"] = "int32"
-    params["input_scale"] = qnn_conv2d_params.ifm.sc
-    params["kernel_scale"] = qnn_conv2d_params.kernel.sc
-    params["channels"] = int(qnn_conv2d_params.kernel.get_dim_size("O"))
-    params["kernel_layout"] = qnn_conv2d_params.kernel.layout
-    k_shape = qnn_conv2d_params.kernel.shape
-    k_dtype = qnn_conv2d_params.kernel.dtype
-    w = tvm.nd.array(
-        np.random.randint(
-            np.iinfo(k_dtype).min, high=np.iinfo(k_dtype).max, size=k_shape, 
dtype=k_dtype
-        )
-    )
-    weight_expr = relay.const(w, k_dtype)
-    v_params.append(w)
-    qnn_conv2d_expr = qnn.op.conv2d(ifm_expr, weight_expr, **params)
-    b = tvm.nd.array(
-        np.random.randint(
-            0, high=10, size=(qnn_conv2d_params.kernel.get_dim_size("O")), 
dtype="int32"
-        )
-    )
-    v_params.append(b)
-    bias_expr = relay.const(b, "int32")
-    bias = relay.nn.bias_add(
-        qnn_conv2d_expr, bias_expr, 
axis=qnn_conv2d_params.ifm.get_dim_index("C")
-    )
-    bias_scale = relay.const(
-        qnn_conv2d_params.ifm.sc.data.asnumpy() * 
qnn_conv2d_params.kernel.sc.data.asnumpy(),
-        "float32",
-    )
-    req_expr = relay.qnn.op.requantize(
-        bias,
-        bias_scale,  # input zero scale
-        relay.const(0, "int32"),  # input zero point
-        qnn_conv2d_params.ofm.sc,  # output zero scale
-        qnn_conv2d_params.ofm.zp,  # output zero point
-        out_dtype=qnn_conv2d_params.ofm.dtype,
-    )
-    if qnn_conv2d_params.activation != "NONE":
-        assert qnn_conv2d_params.activation == "CLIP"
-        clip_expr = relay.clip(req_expr, qnn_conv2d_params.clip_min, 
qnn_conv2d_params.clip_max)
-        return clip_expr, v_params
-
-    return req_expr, v_params
-
-
-def create_pool2d(pooling_params, ifm_expr):
-    """Create a relay pooling operation"""
-    assert pooling_params.ifm.layout == "NHWC"
-    params = {
-        "pool_size": (pooling_params.size[0], pooling_params.size[1]),
-        "strides": (pooling_params.strides[0], pooling_params.strides[1]),
-        "padding": [0, 0],
-        "layout": "NHWC",
-    }
-    if pooling_params.pad == "SAME":
-        pad_top, pad_bottom = get_pad_value(
-            pooling_params.ifm.shape[1], pooling_params.size[0], 
pooling_params.strides[0]
-        )
-        pad_left, pad_right = get_pad_value(
-            pooling_params.ifm.shape[2], pooling_params.size[1], 
pooling_params.strides[1]
-        )
-        params["padding"] = [pad_top, pad_left, pad_bottom, pad_right]
-    if pooling_params.type == "MAX":
-        out = relay.op.nn.max_pool2d(ifm_expr, **params)
-    else:
-        assert pooling_params.type == "AVG"
-        out = relay.op.cast(ifm_expr, dtype="int32")
-        out = relay.op.nn.avg_pool2d(out, **params)
-        out = relay.op.cast(out, dtype=pooling_params.ofm.dtype)
-    return out
-
-
-def create_qnn_add(ifm0_expr, ifm1_expr, add_params):
-    add = relay.qnn.op.add(
-        lhs=ifm0_expr,
-        rhs=ifm1_expr,
-        lhs_scale=add_params.ifm0.sc,
-        lhs_zero_point=add_params.ifm0.zp,
-        rhs_scale=add_params.ifm1.sc,
-        rhs_zero_point=add_params.ifm1.zp,
-        output_scale=add_params.ofm.sc,
-        output_zero_point=add_params.ofm.zp,
-    )
-    return add
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py 
b/tests/python/contrib/test_ethosu/test_codegen.py
index 21e86c8..0707ec2 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -18,22 +18,25 @@
 import pytest
 
 pytest.importorskip("ethosu.vela")
+
 import numpy as np
 import tflite.Model
 
 import tvm
 import tensorflow as tf
 from tvm import relay
+
 from tvm.relay.expr_functor import ExprMutator
 from tvm.relay.op.annotation import compiler_begin, compiler_end
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.backend.contrib.ethosu import preprocess
+
 from tvm.relay.op.contrib.ethosu import partition_for_ethosu
 from tests.python.relay.aot.aot_test_utils import generate_ref_data
 
-from . import relay_ir_builder
 from . import infra
 
+
 ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", 
"ethos-u55-32"]
 
 
@@ -51,122 +54,192 @@ def get_shape_expr(in_expr, out_expr):
     return shape
 
 
[email protected](
-    "accel_type",
-    ACCEL_TYPES,
-)
-def test_ethosu_conv2d(accel_type):
-    def create_graph_single(input_tensor_name, input_tensor_shape, 
input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32)
-        c1_params.kernel.sc = relay.const(np.random.rand(32) * 2, "float32")
-        c1_params.strides = (1, 1)
-        c1_params.pad = "VALID"
-        c1_params.update_output_qnn_params(
-            input_tensor_dtype, input_tensor_dtype, input_tensor_dtype
-        )
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, 
dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        f = relay.Function([input0], c1)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c1_params]
-
-    def create_graph_double(input_tensor_name, input_tensor_shape, 
input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8)
-        c1_params.strides = (2, 2)
-        c1_params.pad = "VALID"
-        c1_params.update_output_qnn_params(
-            input_tensor_dtype, input_tensor_dtype, input_tensor_dtype
-        )
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, 
dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c2_params.ifm.shape = c1_params.ofm.shape
-        c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16)
-        c2_params.strides = (1, 1)
-        c2_params.pad = "SAME"
-        c2_params.update_output_qnn_params()
-        c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1)
-        c2_params.ofm.shape = get_shape_expr(input0, c2)
-
-        f = relay.Function([input0], c2)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c2_params, c1_params]
-
-    def create_graph_activation(input_tensor_name, input_tensor_shape, 
input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8)
-        c1_params.strides = (2, 2)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 90
-        c1_params.clip_max = 110
-        c1_params.update_output_qnn_params(
-            input_tensor_dtype, input_tensor_dtype, input_tensor_dtype
-        )
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, 
dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c2_params.ifm.shape = c1_params.ofm.shape
-        c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16)
-        c2_params.strides = (1, 1)
-        c2_params.pad = "SAME"
-        c2_params.update_output_qnn_params()
-        c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1)
-        c2_params.ofm.shape = get_shape_expr(input0, c2)
-
-        f = relay.Function([input0], c2)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c2_params, c1_params]
-
-    test_cases = [
-        (create_graph_single, ["input", (1, 300, 300, 3), "int8"]),
-        (create_graph_double, ["input", (1, 128, 256, 4), "int8"]),
-        (create_graph_activation, ["input", (1, 64, 100, 4), "int8"]),
-    ]
-    np.random.seed(42)
-    for test_case in test_cases:
-        relay_module, conv_params = test_case[0](*test_case[1])
-        input_tensor, input_shape, input_dtype = test_case[1]
-        mod = partition_for_ethosu(relay_module)
-
-        # Generate reference data
-        in_min, in_max = util.get_range_for_dtype_str(input_dtype)
-        input_data = {
-            input_tensor: np.random.randint(
-                in_min, high=in_max, size=input_shape, dtype=input_dtype
-            )
-        }
-        output_data = generate_ref_data(relay_module, input_data)
[email protected]("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)])
[email protected]("kernel_shape", [(3, 2), (1, 3)])
[email protected]("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 
1))])
[email protected]("padding", ["SAME", "VALID"])
[email protected]("accel_type", ACCEL_TYPES)
[email protected]("activation", ["NONE", "RELU"])
+def test_ethosu_conv2d_single(
+    ifm_shape,
+    kernel_shape,
+    strides,
+    dilation,
+    padding,
+    accel_type,
+    activation,
+):
+    dtype = "int8"
+
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                # Use tf.nn API to create the model; filters assume a 3-channel IFM
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.nn.conv2d(
+                    x,
+                    filters=tf.constant(
+                        np.random.uniform(size=[kernel_shape[0], 
kernel_shape[1], 3, 3]),
+                        dtype=tf.float32,
+                    ),
+                    strides=tf_strides,
+                    padding=padding,
+                    dilations=dilation,
+                )
+                if activation == "RELU":  # "NONE" is a truthy string too
+                    op = tf.nn.relu(op)
+                return op
 
-        compiled_models = infra.build_source(
-            mod, input_data, output_data, accel_type, output_tolerance=1
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
         )
 
-        # Assumes only two runtime.Modules are created -- i.e. single offload 
module
-        ethosu_module = (
-            
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload 
module
+    ethosu_module = 
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+
+    # Verify generated C source
+    get_artifacts = 
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
+
+
[email protected]("ifm_shape", [(1, 214, 227, 3), (1, 27, 42, 3)])
[email protected]("kernel_shape", [(3, 2), (1, 3)])
[email protected]("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 
1))])
[email protected]("padding", ["SAME", "VALID"])
[email protected]("accel_type", ACCEL_TYPES)
[email protected]("activation", ["NONE", "RELU"])
+def test_ethosu_conv2d_double(
+    ifm_shape,
+    kernel_shape,
+    strides,
+    dilation,
+    padding,
+    accel_type,
+    activation,
+):
+    dtype = "int8"
+
+    def create_tflite_graph_double():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function_double(self, x):
+                # Use tf.nn API to create the model with two convolutions
+                op = tf.nn.conv2d(
+                    x,
+                    filters=tf.constant(
+                        np.random.uniform(size=[kernel_shape[0], 
kernel_shape[1], 3, 3]),
+                        dtype=tf.float32,
+                    ),
+                    strides=strides,
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=dilation,
+                )
+                # Second convolution; consumes the 3 output channels of the first
+                op2 = tf.nn.conv2d(
+                    op,
+                    filters=tf.constant(
+                        np.random.uniform(size=(kernel_shape[0], 
kernel_shape[1], 3, 3)),
+                        dtype=tf.float32,
+                    ),
+                    strides=strides,
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=dilation,
+                )
+                if activation == "RELU":  # "NONE" is a truthy string too
+                    op2 = tf.nn.relu(op2)
+                return op2
+
+        model = Model()
+        concrete_func = model.tf_function_double.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
         )
 
-        # Verify generated C source
-        get_artifacts = 
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
-        compilation_artifacts = get_artifacts(ethosu_module)
-        cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
-        infra.print_payload(cmms)
-        infra.verify_source(compiled_models, accel_type)
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    tflite_graph = create_tflite_graph_double()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    relay_module, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+    mod = partition_for_ethosu(relay_module, params)
+
+    # Generate reference data
+    input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
+
+    compiled_models = infra.build_source(
+        mod,
+        input_data,
+        output_data,
+        accel_type,
+    )
+
+    # Assumes only two runtime.Modules are created -- i.e. single offload 
module
+    ethosu_module = 
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+
+    # Verify generated C source
+    get_artifacts = 
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+    compilation_artifacts = get_artifacts(ethosu_module)
+    cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
+    infra.print_payload(cmms)
+    infra.verify_source(compiled_models, accel_type)
 
 
 def _compare_ethosu_with_reference(
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py 
b/tests/python/contrib/test_ethosu/test_legalize.py
index 946aa95..9dc94d9 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -19,6 +19,8 @@
 import pytest
 
 pytest.importorskip("ethosu.vela")
+
+import math
 import numpy as np
 import tensorflow as tf
 import tflite.Model
@@ -31,7 +33,6 @@ from tvm.relay.op.contrib import ethosu
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.build_module import bind_params_by_name
 
-from . import relay_ir_builder
 from . import infra
 
 
@@ -229,128 +230,121 @@ INVERSE_LAYOUT_TRANSFORM_OHWI_MAP = {
 }
 
 
-def test_ethosu_conv2d_legalize():
-    def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[3], 32)
-        c1_params.strides = (1, 1)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 23
-        c1_params.clip_max = 180
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        f = relay.Function([input0], c1)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c1_params]
-
-    def create_graph_double(input_tensor_name, input_tensor_shape, input_tensor_dtype):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.kernel.shape = (7, 7, c1_params.ifm.shape[3], 8)
-        c1_params.strides = (2, 2)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 10
-        c1_params.clip_max = 240
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        c2_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c2_params.ifm.shape = c1_params.ofm.shape
-        c2_params.kernel.shape = (5, 5, c2_params.ifm.shape[3], 16)
-        c2_params.strides = (1, 1)
-        c2_params.pad = "SAME"
-        c2, new_params = relay_ir_builder.create_qnn_conv2d(c2_params, c1)
-        c2_params.ofm.shape = get_shape_expr(input0, c2)
-
-        f = relay.Function([input0], c2)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c2_params, c1_params]
+@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)])
+@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
+@pytest.mark.parametrize("activation", [None, "RELU"])
+def test_tflite_conv2d_legalize(ifm_shape, kernel_shape, padding, strides, dilation, activation):
+    dtype = "int8"
 
-    def verify_tensor(tensor_type, expr):
-        assert list(tensor_type.shape) == list(expr.checked_type.shape)
-        assert str(tensor_type.dtype) == str(expr.checked_type.dtype)
+    def create_tflite_graph_single():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, input_shape):
+                op = tf.nn.conv2d(
+                    input_shape,
+                    filters=tf.constant(
+                        np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 3, 3)),
+                        dtype=tf.float32,
+                    ),
+                    strides=strides,
+                    padding=padding,
+                    data_format="NHWC",
+                    dilations=dilation,
+                )
+                if activation:
+                    op = tf.nn.relu(op)
+                return op
 
-    def verify_linear(ext_func, conv2d_params):
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
         op = ext_func.body
-        for param in conv2d_params:
-            verify_tensor(param.ifm, op.args[0])
-            verify_tensor(param.ofm, op)
-
-            # This will be in OHWI layout
-            weights_ohwi = op.args[1].data.asnumpy()
-            weights_layout = str(param.kernel.layout)
-            weights = np.transpose(weights_ohwi, INVERSE_LAYOUT_TRANSFORM_OHWI_MAP[weights_layout])
-            assert weights.shape == param.kernel.shape
-            assert weights.dtype == param.kernel.dtype
-
-            assert list(op.args[2].checked_type.shape)[0] == weights_ohwi.shape[0]
-
-            assert float(op.attrs.ifm_scale) == float(param.ifm.sc.data.asnumpy())
-            assert int(op.attrs.ifm_zero_point) == int(param.ifm.zp.data.asnumpy())
-            assert int(op.attrs.weight_zero_point) == int(param.kernel.zp.data.asnumpy())
-            assert float(op.attrs.ofm_scale) == float(param.ofm.sc.data.asnumpy())
-            assert int(op.attrs.ofm_zero_point) == int(param.ofm.zp.data.asnumpy())
-            assert int(op.attrs.ofm_channels) == int(weights_ohwi.shape[0])
-            assert list(op.attrs.padding) == list(param.pad)
-            assert list(op.attrs.strides) == list(param.strides)
-            assert list(op.attrs.dilation) == list(param.dilation)
-            assert str(op.attrs.activation) == str(param.activation)
-            assert int(op.attrs.clip_min) == int(param.clip_min)
-            assert int(op.attrs.clip_max) == int(param.clip_max)
-            op = op.args[0]
+        ofm_channels = op.attrs.ofm_channels
 
-    test_cases = [
-        (create_graph_single, ["input", (1, 299, 299, 3), "uint8"]),
-        (create_graph_double, ["input", (1, 128, 256, 4), "uint8"]),
-    ]
-    for test_case in test_cases:
-        mod, conv_params = test_case[0](*test_case[1])
-        mod = ethosu.partition_for_ethosu(mod)
-        mod = legalize.LegalizeConv2D()(mod)
-        verify_linear(mod["tvmgen_default_ethos_u_main_0"], conv_params)
-
-
-def test_ethosu_conv2d_legalize_errors():
-    def create_graph_single_unsupported_ifm_layout(
-        input_tensor_name, input_tensor_shape, input_tensor_dtype
-    ):
-        c1_params = relay_ir_builder.QnnConv2DParams(input_tensor_dtype)
-        c1_params.ifm.shape = input_tensor_shape
-        c1_params.ifm.layout = "NCHW"
-        c1_params.kernel.shape = (3, 3, c1_params.ifm.shape[1], 32)
-        c1_params.strides = (1, 1)
-        c1_params.pad = "VALID"
-        c1_params.activation = "CLIP"
-        c1_params.clip_min = 23
-        c1_params.clip_max = 180
-        input0 = relay.var(input_tensor_name, shape=c1_params.ifm.shape, dtype=c1_params.ifm.dtype)
-        c1, new_params = relay_ir_builder.create_qnn_conv2d(c1_params, input0)
-        c1_params.ofm.shape = get_shape_expr(input0, c1)
-
-        f = relay.Function([input0], c1)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        return mod, [c1_params]
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = infra.compute_ofm_shape(
+            ifm_shape, padding, kernel_shape, strides, dilation
+        )
+        assert list(ofm.shape) == list(expected_ofm_shape)
+        assert str(ofm.dtype) == dtype
+        assert ofm.shape[3] == ofm_channels
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert weights_ohwi.shape[0] == ofm_channels
+        assert weights_ohwi.shape[1] == kernel_shape[0]
+        assert weights_ohwi.shape[2] == kernel_shape[1]
+        assert weights_ohwi.shape[3] == 3
 
-    test_cases = [
-        (create_graph_single_unsupported_ifm_layout, ["input", (1, 3, 299, 299), "uint8"]),
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        expected_padding = infra.compute_padding_shape(
+            ifm_shape,
+            expected_ofm_shape,
+            padding,
+            (kernel_shape[0], kernel_shape[1]),
+            strides,
+            dilation,
+        )
+        assert list(op.attrs.padding) == list(expected_padding)
+        assert list(op.attrs.strides) == list(strides)
+        assert list(op.attrs.dilation) == list(dilation)
+        if activation == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    conv2d_pattern_table = [
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        )
     ]
 
-    for test_case in test_cases:
-        mod, conv_params = test_case[0](*test_case[1])
-        mod = ethosu.partition_for_ethosu(mod)
-        with pytest.raises(
-            tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW"
-        ):
-            mod = legalize.LegalizeConv2D()(mod)
+    tflite_graph = create_tflite_graph_single()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
 @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])

Reply via email to