This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8226bd0 [microNPU] Some housekeeping in the test_ethosu folder (#10824)
8226bd0 is described below
commit 8226bd0af01aa35d6f9ca07a862ebe98b32fd337
Author: Elen Kalda <[email protected]>
AuthorDate: Thu Mar 31 16:57:54 2022 +0100
[microNPU] Some housekeeping in the test_ethosu folder (#10824)
* [microNPU] Some housekeeping in the test_ethosu folder
* Move the utility functions from test_codegen.py into infra.py for
wider accessibility
* Remove some unused code
* Make the conv2d codegen tests more general
* Update test_identity_optimizer.py
* Update test_lut_optimizer.py
---
tests/python/contrib/test_ethosu/infra.py | 141 ++++++--
tests/python/contrib/test_ethosu/test_codegen.py | 393 +++++----------------
.../contrib/test_ethosu/test_identity_optimizer.py | 7 +-
tests/python/contrib/test_ethosu/test_legalize.py | 14 -
.../contrib/test_ethosu/test_lut_optimizer.py | 3 +-
5 files changed, 199 insertions(+), 359 deletions(-)
diff --git a/tests/python/contrib/test_ethosu/infra.py
b/tests/python/contrib/test_ethosu/infra.py
index 00af0a8..25b4b1b 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -28,7 +28,8 @@ from typing import List
import os
import struct
-import numpy
+import numpy as np
+import tflite.Model
import math
from enum import IntEnum
import tensorflow as tf
@@ -41,7 +42,11 @@ import tvm
from tvm import relay
import tvm.relay.backend.contrib.ethosu.op as ethosu_ops
from tvm.topi.nn.utils import get_pad_tuple
+from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay.backend.contrib.ethosu import preprocess
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tests.python.relay.aot.aot_test_utils import (
AOTCompiledTestModel,
AOTDataLinkage,
@@ -180,13 +185,13 @@ class InputGenerator:
self._random_state = random_state
def generate(self, size, dtype):
- if dtype == numpy.float32:
+ if dtype == np.float32:
print("random float32")
return self._random_state.uniform(-1, 1, size).astype(dtype)
else:
- print("random (u)int min=%d max=%d", numpy.iinfo(dtype).min,
numpy.iinfo(dtype).max)
- low = numpy.iinfo(dtype).min
- high = numpy.iinfo(dtype).max + 1
+ print("random (u)int min=%d max=%d", np.iinfo(dtype).min,
np.iinfo(dtype).max)
+ low = np.iinfo(dtype).min
+ high = np.iinfo(dtype).max + 1
return self._random_state.randint(low, high, size, dtype)
@@ -213,7 +218,7 @@ def generate_ref_data_tflite(model):
# Initialize random generators with a fixed seed to get deterministic
results
seed = 0
- random_state = numpy.random.RandomState(seed)
+ random_state = np.random.RandomState(seed)
inputgen = InputGenerator(random_state)
@@ -237,31 +242,117 @@ def generate_ref_data_tflite(model):
return input_data, expected_output_data
-def make_partitioned_function(relay_op):
+def get_tflite_graph(tf_func, shapes, ranges=None):
+ tensor_specs = [tf.TensorSpec(shape, dtype=tf.float32) for shape in shapes]
+ if not ranges:
+ ranges = [(0, 1) for _ in shapes]
+ concrete_func = tf_func.get_concrete_function(*tensor_specs)
- ifm0 = relay.analysis.free_vars(relay_op)
- ifm_shape = ifm0[0].type_annotation.shape
- ifm_dtype = ifm0[0].type_annotation.dtype
+ # Convert the model
+ def representative_dataset():
+ for _ in range(100):
+ inputs = []
+ for i, shape in enumerate(shapes):
+ data = np.random.uniform(
+ low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
+ ).astype("float32")
+ inputs.append(data)
- ifm = relay.var("ifm", shape=ifm_shape, dtype=ifm_dtype)
+ yield inputs
- glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0")
+ converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
+ converter.representative_dataset = representative_dataset
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+ converter.inference_input_type = tf.int8
+ converter.inference_output_type = tf.int8
+ tflite_graph = converter.convert()
- func = (
- relay.Function(ifm0, relay_op)
- .with_attr("Inline", 1)
- .with_attr("Compiler", "ethos-u")
- .with_attr("global_symbol", "tvmgen_default_ethosu_main_0")
- .with_attr("Primitive", 1)
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+ relay_module, params = relay.frontend.from_tflite(tflite_model)
+ mod = partition_for_ethosu(relay_module, params)
+ return mod, tflite_graph
+
+
+def compare_ethosu_with_reference(
+ mod, input_data, output_data, accel_type, output_tolerance=0,
print_cmm=False
+):
+ compiled_models = build_source(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ output_tolerance=output_tolerance,
)
- mod = tvm.IRModule()
- mod[glb_ethosu] = func
- mod = relay.transform.InferType()(mod)
- call = relay.Call(glb_ethosu, [ifm])
- mod["main"] = relay.Function([ifm], call)
- mod = relay.transform.InferType()(mod)
+ # Assumes only two runtime.Modules are created -- i.e. single offload
module
+ ethosu_module =
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
+
+ # Verify generated C source
+ if print_cmm:
+ get_artifacts =
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
+ compilation_artifacts = get_artifacts(ethosu_module)
+ cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
+ print_payload(cmms)
+
+ verify_source(compiled_models, accel_type)
+
+
+def compare_tvm_with_tflite(
+ tf_func, shapes, accel_type, ranges=None, output_tolerance=0,
print_cmm=False
+):
+ mod, tflite_graph = get_tflite_graph(tf_func, shapes, ranges)
+
+ # Generate reference data
+ input_data, output_data = generate_ref_data_tflite(tflite_graph)
+
+ compare_ethosu_with_reference(
+ mod,
+ input_data,
+ output_data,
+ accel_type,
+ output_tolerance=output_tolerance,
+ print_cmm=print_cmm,
+ )
+
+
+class EthosUAnnotator(ExprMutator):
+ """Annotate entire graph for Ethos-U offload"""
+
+ def __init__(self):
+ super(EthosUAnnotator, self).__init__()
+ self.compiler = "ethos-u"
+ self.last_call = True
+
+ def visit_call(self, call):
+ curr_last = self.last_call
+ self.last_call = False
+
+ params = []
+ for arg in call.args:
+ param = super().visit(arg)
+ if isinstance(param, relay.expr.Var):
+ param = compiler_begin(param, self.compiler)
+ params.append(param)
+ new_call = relay.Call(call.op, params, call.attrs)
+ if curr_last:
+ new_call = compiler_end(new_call, self.compiler)
+ return new_call
+
+ def visit_constant(self, constant):
+ new_constant = compiler_begin(constant, self.compiler)
+ return new_constant
+
+
+def create_ethosu_partition(mod):
+ mod["main"] = EthosUAnnotator().visit(mod["main"])
+ mod = relay.transform.MergeCompilerRegions()(mod)
+ mod = relay.transform.InferType()(mod)
+ mod = relay.transform.PartitionGraph()(mod)
+ mod = relay.transform.InferType()(mod)
+ mod = preprocess.preprocess_ext_io()(mod)
return mod
@@ -269,7 +360,7 @@ def generate_weights_data(shape, dtype):
size = 1
for dim in shape:
size *= dim
- return (numpy.arange(size) % 255).reshape(shape).astype(dtype)
+ return (np.arange(size) % 255).reshape(shape).astype(dtype)
def get_convolutional_args(call, include_buffers=False,
remove_constants=False):
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py
b/tests/python/contrib/test_ethosu/test_codegen.py
index 4934920..2e378aa 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -26,10 +26,7 @@ import tvm
import tensorflow as tf
from tvm import relay
-from tvm.relay.expr_functor import ExprMutator
-from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend.contrib.ethosu import util
-from tvm.relay.backend.contrib.ethosu import preprocess
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
from tests.python.relay.aot.aot_test_utils import generate_ref_data
@@ -40,21 +37,7 @@ from . import infra
ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64",
"ethos-u55-32", "ethos-u65-256"]
-def infer_type_function_pass(func):
- mod = tvm.IRModule()
- mod["test"] = func
- mod = relay.transform.InferType()(mod)
- return mod["test"]
-
-
-def get_shape_expr(in_expr, out_expr):
- main_f = relay.Function([in_expr], out_expr)
- main_f = infer_type_function_pass(main_f)
- shape = [int(i) for i in main_f.body.checked_type.shape]
- return shape
-
-
[email protected]("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)])
[email protected]("ifm_shape", [(1, 299, 299, 2), (1, 55, 55, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1,
1))])
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@@ -70,80 +53,29 @@ def test_ethosu_conv2d_single(
activation,
):
np.random.seed(0)
- dtype = "int8"
-
- def create_tflite_graph_single():
- class Model(tf.Module):
- @tf.function
- def tf_function(self, x):
- # Use tf.nn API to create the model
- tf_strides = [1, strides[0], strides[1], 1]
- op = tf.nn.conv2d(
- x,
- filters=tf.constant(
- np.random.uniform(size=[kernel_shape[0],
kernel_shape[1], 3, 3]),
- dtype=tf.float32,
- ),
- strides=tf_strides,
- padding=padding,
- dilations=dilation,
- )
- if activation:
- op = tf.nn.relu(op)
- return op
- model = Model()
- concrete_func = model.tf_function.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32)
+ @tf.function
+ def conv2d(x):
+ # Use tf.nn API to create the model
+ tf_strides = [1, strides[0], strides[1], 1]
+ op = tf.nn.conv2d(
+ x,
+ filters=tf.constant(
+ np.random.uniform(size=[kernel_shape[0], kernel_shape[1],
ifm_shape[3], 3]),
+ dtype=tf.float32,
+ ),
+ strides=tf_strides,
+ padding=padding,
+ dilations=dilation,
)
+ if activation:
+ op = tf.nn.relu(op)
+ return op
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- yield [data.astype(np.float32)]
-
- converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops =
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = create_tflite_graph_single()
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- relay_module, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"input": ifm_shape},
- dtype_dict={"input": dtype},
- )
- mod = partition_for_ethosu(relay_module, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- )
-
- # Assumes only two runtime.Modules are created -- i.e. single offload
module
- ethosu_module =
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts =
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+ infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type)
[email protected]("ifm_shape", [(1, 214, 227, 3), (1, 27, 42, 3)])
[email protected]("ifm_shape", [(1, 214, 227, 2), (1, 27, 42, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1,
1))])
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
@@ -159,89 +91,36 @@ def test_ethosu_conv2d_double(
activation,
):
np.random.seed(0)
- dtype = "int8"
- def create_tflite_graph_double():
- class Model(tf.Module):
- @tf.function
- def tf_function_double(self, x):
- # Use tf.nn API to create the model with two convolutions
- op = tf.nn.conv2d(
- x,
- filters=tf.constant(
- np.random.uniform(size=[kernel_shape[0],
kernel_shape[1], 3, 3]),
- dtype=tf.float32,
- ),
- strides=strides,
- padding=padding,
- data_format="NHWC",
- dilations=dilation,
- )
- # Second convolution
- op2 = tf.nn.conv2d(
- op,
- filters=tf.constant(
- np.random.uniform(size=(kernel_shape[0],
kernel_shape[1], 3, 3)),
- dtype=tf.float32,
- ),
- strides=strides,
- padding=padding,
- data_format="NHWC",
- dilations=dilation,
- )
- if activation:
- op2 = tf.nn.relu(op2)
- return op2
-
- model = Model()
- concrete_func = model.tf_function_double.get_concrete_function(
- tf.TensorSpec(ifm_shape, dtype=tf.float32)
+ @tf.function
+ def conv2d_double(x):
+ # Use tf.nn API to create the model with two convolutions
+ op = tf.nn.conv2d(
+ x,
+ filters=tf.constant(
+ np.random.uniform(size=[kernel_shape[0], kernel_shape[1],
ifm_shape[3], 5]),
+ dtype=tf.float32,
+ ),
+ strides=strides,
+ padding=padding,
+ dilations=dilation,
)
+ # Second convolution
+ op2 = tf.nn.conv2d(
+ op,
+ filters=tf.constant(
+ np.random.uniform(size=(kernel_shape[0], kernel_shape[1], 5,
3)),
+ dtype=tf.float32,
+ ),
+ strides=strides,
+ padding=padding,
+ dilations=dilation,
+ )
+ if activation:
+ op2 = tf.nn.relu(op2)
+ return op2
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
- yield [data.astype(np.float32)]
-
- converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops =
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_model = converter.convert()
- return tflite_model
-
- tflite_graph = create_tflite_graph_double()
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- relay_module, params = relay.frontend.from_tflite(
- tflite_model,
- shape_dict={"input": ifm_shape},
- dtype_dict={"input": dtype},
- )
- mod = partition_for_ethosu(relay_module, params)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- )
-
- # Assumes only two runtime.Modules are created -- i.e. single offload
module
- ethosu_module =
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- get_artifacts =
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
- infra.verify_source(compiled_models, accel_type)
+ infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type)
@pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10,
1e10)])
@@ -277,121 +156,7 @@ def test_out_of_range_scaling(weight_min, weight_max):
op = tf.nn.relu(op)
return op
- _compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type)
-
-
-def _compare_ethosu_with_reference(
- mod, input_data, output_data, accel_type, output_tolerance=0,
print_cmm=False
-):
- compiled_models = infra.build_source(
- mod,
- input_data,
- output_data,
- accel_type,
- output_tolerance=output_tolerance,
- )
-
- # Assumes only two runtime.Modules are created -- i.e. single offload
module
- ethosu_module =
compiled_models[0].executor_factory.lib.imported_modules[0].imported_modules[0]
-
- # Verify generated C source
- if print_cmm:
- get_artifacts =
tvm._ffi.get_global_func("runtime.module.ethos-u.get_artifacts")
- compilation_artifacts = get_artifacts(ethosu_module)
- cmms = bytes.fromhex(compilation_artifacts[0].command_stream)
- infra.print_payload(cmms)
-
- infra.verify_source(compiled_models, accel_type)
-
-
-def _compare_tvm_with_tflite(
- tf_func, shapes, accel_type, ranges=None, output_tolerance=0,
print_cmm=False
-):
- mod, tflite_graph = _get_tflite_graph(tf_func, shapes, ranges)
-
- # Generate reference data
- input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
-
- _compare_ethosu_with_reference(
- mod,
- input_data,
- output_data,
- accel_type,
- output_tolerance=output_tolerance,
- print_cmm=print_cmm,
- )
-
-
-def _get_tflite_graph(tf_func, shapes, ranges=None):
- tensor_specs = [tf.TensorSpec(shape, dtype=tf.float32) for shape in shapes]
- if not ranges:
- ranges = [(0, 1) for _ in shapes]
- concrete_func = tf_func.get_concrete_function(*tensor_specs)
-
- # Convert the model
- def representative_dataset():
- for _ in range(100):
- inputs = []
- for i, shape in enumerate(shapes):
- data = np.random.uniform(
- low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
- ).astype("float32")
- inputs.append(data)
-
- yield inputs
-
- converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
- converter.representative_dataset = representative_dataset
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.int8
- converter.inference_output_type = tf.int8
- tflite_graph = converter.convert()
-
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
-
- relay_module, params = relay.frontend.from_tflite(tflite_model)
- mod = partition_for_ethosu(relay_module, params)
- return mod, tflite_graph
-
-
-class EthosUAnnotator(ExprMutator):
- """Annotate entire graph for Ethos-U offload"""
-
- def __init__(self):
- super(EthosUAnnotator, self).__init__()
- self.compiler = "ethos-u"
- self.last_call = True
-
- def visit_call(self, call):
- curr_last = self.last_call
- self.last_call = False
-
- params = []
- for arg in call.args:
- param = super().visit(arg)
- if isinstance(param, relay.expr.Var):
- param = compiler_begin(param, self.compiler)
- params.append(param)
-
- new_call = relay.Call(call.op, params, call.attrs)
- if curr_last:
- new_call = compiler_end(new_call, self.compiler)
- return new_call
-
- def visit_constant(self, constant):
- new_constant = compiler_begin(constant, self.compiler)
- return new_constant
-
-
-def _create_ethosu_partition(mod):
- mod["main"] = EthosUAnnotator().visit(mod["main"])
- mod = relay.transform.MergeCompilerRegions()(mod)
- mod = relay.transform.InferType()(mod)
- mod = relay.transform.PartitionGraph()(mod)
- mod = relay.transform.InferType()(mod)
- mod = preprocess.preprocess_ext_io()(mod)
- return mod
+ infra.compare_tvm_with_tflite(conv_invalid_scale, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -426,7 +191,7 @@ def test_tflite_depthwise_conv2d(
op = tf.nn.relu(op)
return op
- _compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)
@pytest.mark.parametrize(
@@ -460,7 +225,7 @@ def test_ethosu_pooling(
op = tf.nn.relu(op)
return op
- _compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(pooling, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -500,7 +265,7 @@ def test_ethosu_binary_elementwise(
op = tf.nn.relu(op)
return op
- _compare_tvm_with_tflite(
+ infra.compare_tvm_with_tflite(
binary_elementwise,
shapes=[ifm_shape, ifm2_shape],
ranges=[(0, 1), (0, 2)],
@@ -529,7 +294,7 @@ def test_binary_add_with_non_4d_shapes(
def binary_elementwise(lhs, rhs):
return tf.math.add(lhs, rhs)
- _compare_tvm_with_tflite(
+ infra.compare_tvm_with_tflite(
binary_elementwise,
shapes=[ifm_shape, ifm2_shape],
ranges=[(0, 1), (0, 2)],
@@ -673,7 +438,7 @@ def test_elementwise_add_from_constant_scalar(accel_type,
dtype, constant):
}
output_data = generate_ref_data(cpu_mod, input_data)
- _compare_ethosu_with_reference(
+ infra.compare_ethosu_with_reference(
ethosu_mod, input_data, output_data, accel_type, output_tolerance=0
)
@@ -712,7 +477,7 @@ def test_ethosu_left_shift_binary_elemwise(
output_data = generate_ref_data(cpu_mod, input_data)
ethosu_mod = partition_for_ethosu(cpu_mod)
- _compare_ethosu_with_reference(
+ infra.compare_ethosu_with_reference(
ethosu_mod, input_data, output_data, accel_type, output_tolerance=0
)
@@ -771,9 +536,9 @@ def test_ethosu_right_shift_binary_elemwise(
"ifm2": rhs,
}
output_data = {"output": generate_output_data(input_data)[0]}
- ethosu_mod = _create_ethosu_partition(cpu_mod)
+ ethosu_mod = infra.create_ethosu_partition(cpu_mod)
- _compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
+ infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -800,9 +565,9 @@ def test_ethosu_identity_codegen(ifm_shape, ifm_scale,
ifm_zp, ofm_scale, ofm_zp
cpu_mod = create_model()
input_data = {"ifm": np.random.randint(-120, high=120, size=ifm_shape,
dtype="int8")}
output_data = {"output": generate_output_data(input_data)[0]}
- ethosu_mod = _create_ethosu_partition(cpu_mod)
+ ethosu_mod = infra.create_ethosu_partition(cpu_mod)
- _compare_ethosu_with_reference(
+ infra.compare_ethosu_with_reference(
ethosu_mod, input_data, output_data, accel_type, output_tolerance=1
)
@@ -831,9 +596,9 @@ def test_relay_reshape_codegen(ifm_shape, new_shape,
accel_type):
cpu_mod = create_model()
input_data = {"ifm": np.random.randint(-128, high=127, size=ifm_shape,
dtype="int8")}
output_data = generate_ref_data(cpu_mod, input_data)
- ethosu_mod = _create_ethosu_partition(cpu_mod)
+ ethosu_mod = infra.create_ethosu_partition(cpu_mod)
- _compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
+ infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -853,7 +618,7 @@ def test_tflite_slice(accel_type, ifm_shape, begin, size):
def slice_func(x):
return tf.slice(x, begin, size)
- _compare_tvm_with_tflite(slice_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(slice_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -868,7 +633,7 @@ def test_tflite_strided_slice(accel_type, ifm_shape, begin,
end):
def strided_slice_func(x):
return tf.strided_slice(x, begin, end)
- _compare_tvm_with_tflite(strided_slice_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(strided_slice_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -890,7 +655,7 @@ def test_ethosu_unary_elementwise(
op = tf.math.abs(x)
return op
- _compare_tvm_with_tflite(abs_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(abs_func, [ifm_shape], accel_type)
def test_ethosu_section_name():
@@ -904,7 +669,7 @@ def test_ethosu_section_name():
op = tf.nn.depthwise_conv2d(x, weight, strides=tf_strides,
padding="SAME", dilations=(2, 2))
return op
- mod, tflite_graph = _get_tflite_graph(depthwise_conv2d, [(1, 55, 55, 3)])
+ mod, tflite_graph = infra.get_tflite_graph(depthwise_conv2d, [(1, 55, 55,
3)])
# Generate reference data
input_data, output_data = infra.generate_ref_data_tflite(tflite_graph)
@@ -953,9 +718,9 @@ def test_ethosu_clz(accel_type):
cpu_mod = create_model()
input_data = {"ifm": np.random.randint(-500000, high=500000,
size=ifm_shape, dtype="int32")}
output_data = {"output": generate_output_data(input_data)[0]}
- ethosu_mod = _create_ethosu_partition(cpu_mod)
+ ethosu_mod = infra.create_ethosu_partition(cpu_mod)
- _compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
+ infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -968,7 +733,7 @@ def test_tflite_tanh(accel_type):
op = tf.nn.tanh(x)
return op
- _compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(tanh_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -992,7 +757,7 @@ def test_tflite_concat(shapes, axis, accel_type):
# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
- _compare_tvm_with_tflite(concat_func, shapes, accel_type,
output_tolerance=1)
+ infra.compare_tvm_with_tflite(concat_func, shapes, accel_type,
output_tolerance=1)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1005,7 +770,7 @@ def test_tflite_sigmoid(accel_type):
op = tf.nn.sigmoid(x)
return op
- _compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(sigmoid_function, [ifm_shape], accel_type)
# This codegen test checks both, split and split_v
@@ -1029,7 +794,7 @@ def test_tflite_split(accel_type, ifm_shape,
num_or_size_splits, axis):
op = tf.split(x, num_or_size_splits, axis=axis)
return op
- _compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(split_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1060,7 +825,7 @@ def test_ethosu_requantize(accel_type, ifm_shape,
ifm_scale, ifm_zp, ofm_scale,
output_data = generate_ref_data(cpu_mod, input_data)
ethosu_mod = partition_for_ethosu(cpu_mod)
- _compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
+ infra.compare_ethosu_with_reference(ethosu_mod, input_data, output_data,
accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1072,7 +837,7 @@ def test_tflite_expand_dims(accel_type, ifm_shape, axis):
def expand_dims_func(x):
return tf.expand_dims(x, axis=axis)
- _compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1086,7 +851,7 @@ def test_tflite_squeeze(accel_type, ifm_shape, axis):
def squeeze_func(x):
return tf.squeeze(x, axis=axis)
- _compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1104,7 +869,7 @@ def test_tflite_resize2d_nearest_neighbor(accel_type,
ifm_shape, size):
x, size, align_corners=align_corners, half_pixel_centers=False
)
- _compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1130,7 +895,7 @@ def test_tflite_resize2d_bilinear(accel_type, ifm_shape,
size, align_corners):
# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
- _compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type,
output_tolerance=1)
+ infra.compare_tvm_with_tflite(resize_model, [ifm_shape], accel_type,
output_tolerance=1)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1170,7 +935,7 @@ def test_tflite_transpose_convolution(
op = tf.nn.bias_add(op, bias)
return op
- _compare_tvm_with_tflite(conv2d_transpose, [ifm_shape],
accel_type=accel_type)
+ infra.compare_tvm_with_tflite(conv2d_transpose, [ifm_shape],
accel_type=accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1193,7 +958,7 @@ def test_tflite_pack(accel_type, ifm_shapes, axis):
# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
- _compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type,
output_tolerance=1)
+ infra.compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type,
output_tolerance=1)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1208,7 +973,7 @@ def test_tflite_unpack(accel_type, ifm_shape, axis):
def unpack_func(x):
return tf.unstack(x, axis=axis)
- _compare_tvm_with_tflite(unpack_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(unpack_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1221,7 +986,7 @@ def test_tflite_leaky_relu(accel_type, ifm_shape, alpha):
def leaky_relu_func(x):
return tf.nn.leaky_relu(x, alpha=alpha)
- _compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type)
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@@ -1253,7 +1018,7 @@ def test_tflite_fully_connected(
x = tf.nn.relu(x)
return x
- _compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type)
+ infra.compare_tvm_with_tflite(fully_connected, [ifm_shape], accel_type)
if __name__ == "__main__":
diff --git a/tests/python/contrib/test_ethosu/test_identity_optimizer.py
b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
index 8a42fe8..f37509e 100644
--- a/tests/python/contrib/test_ethosu/test_identity_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_identity_optimizer.py
@@ -32,7 +32,6 @@ from tvm.relay.backend.contrib.ethosu.codegen import
relay_to_tir
from tvm.relay.backend.contrib.ethosu.codegen import IdentityOptimizer
from . import infra
-from .test_codegen import _compare_tvm_with_tflite
def _optimize(func, optimize=True):
@@ -323,7 +322,7 @@ def test_same_output():
z = tf.reshape(z, (1, 1, 25, 8))
return z
- _compare_tvm_with_tflite(model, ifm_shapes, "ethos-u55-256")
+ infra.compare_tvm_with_tflite(model, ifm_shapes, "ethos-u55-256")
def test_multi_output_identity_has_same_output():
@@ -341,7 +340,7 @@ def test_multi_output_identity_has_same_output():
y = tf.concat(outputs, axis=0)
return y
- _compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")
+ infra.compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")
def test_multiple_transform_ops_same_output():
@@ -356,4 +355,4 @@ def test_multiple_transform_ops_same_output():
x = tf.reshape(x, (12,))
return x
- _compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")
+ infra.compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py
b/tests/python/contrib/test_ethosu/test_legalize.py
index 32cf2c1..455a5ac 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -223,20 +223,6 @@ def test_split_sections_legalize():
tvm.ir.assert_structural_equal(mod_axis2, expected_axis2)
-def infer_type_function_pass(func):
- mod = tvm.IRModule()
- mod["test"] = func
- mod = relay.transform.InferType()(mod)
- return mod["test"]
-
-
-def get_shape_expr(in_expr, out_expr):
- main_f = relay.Function([in_expr], out_expr)
- main_f = infer_type_function_pass(main_f)
- shape = [int(i) for i in main_f.body.checked_type.shape]
- return shape
-
-
INVERSE_LAYOUT_TRANSFORM_OHWI_MAP = {
"HWIO": [1, 2, 3, 0],
"HWOI": [1, 2, 0, 3],
diff --git a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
index db2a1d5..87e6257 100644
--- a/tests/python/contrib/test_ethosu/test_lut_optimizer.py
+++ b/tests/python/contrib/test_ethosu/test_lut_optimizer.py
@@ -30,7 +30,6 @@ from tvm.relay.backend.contrib.ethosu.codegen import
LUTsOptimizer
from tvm.relay.backend.contrib.ethosu.codegen import relay_to_tir
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
-from .test_codegen import _get_tflite_graph
from . import infra
@@ -121,7 +120,7 @@ def test_lut_optimizer_runs_in_compilation_pipeline():
op = tf.nn.depthwise_conv2d(op, weight2, (1, 1, 1, 1), "VALID")
return tf.nn.tanh(op)
- mod, _ = _get_tflite_graph(get_graph, [ifm_shape])
+ mod, _ = infra.get_tflite_graph(get_graph, [ifm_shape])
mod = partition_for_ethosu(mod)
mod = relay_to_tir(mod)