This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 67714c6 [1/3][AOT][DeviceAPI] Connecting devices structure to
relevant operators (#9395)
67714c6 is described below
commit 67714c64f2c9215809872dbf5e5267c4df45f3e1
Author: Christopher Sidebottom <[email protected]>
AuthorDate: Fri Nov 12 20:16:11 2021 +0000
[1/3][AOT][DeviceAPI] Connecting devices structure to relevant operators
(#9395)
* [AOT][DeviceAPI] Connecting devices structure to relevant operators
This patch adds support for passing the device context via the unpacked API
in AOT, generating an additional struct if necessary:
```c
/*!
* \brief Device context pointers for TVM module "default"
*/
struct tvmgen_default_devices {
void* npu;
};
```
This struct is then added as an argument to the entry function:
```c
/*!
* \brief entrypoint function for TVM module "default"
* \param inputs Input tensors for the module
* \param outputs Output tensors for the module
* \param devices Device context pointers for the module
*/
int32_t tvmgen_default_run(
struct tvmgen_default_inputs* inputs,
struct tvmgen_default_outputs* outputs,
struct tvmgen_default_devices* devices
);
```
I've temporarily added the collection of external code generators to the TE
compiler pending proper annotation of the eventual functions.
Co-authored-by: Grant Watson <[email protected]>
* Correct "use_device_api" attribute name on Target
Co-authored-by: Grant Watson <[email protected]>
---
apps/microtvm/ethosu/Makefile | 4 +-
apps/microtvm/ethosu/include/tvm_ethosu_runtime.h | 30 ++++++
apps/microtvm/ethosu/src/demo.c | 7 +-
apps/microtvm/ethosu/src/tvm_ethosu_runtime.c | 34 +++++++
python/tvm/driver/tvmc/composite_target.py | 2 +-
python/tvm/micro/model_library_format.py | 7 +-
.../tvm/relay/backend/contrib/ethosu/_ffi_api.py | 2 +-
python/tvm/relay/backend/contrib/ethosu/codegen.py | 6 +-
.../tvm/relay/backend/contrib/ethosu/legalize.py | 2 +-
python/tvm/relay/backend/contrib/ethosu/util.py | 2 +-
.../tvm/relay/backend/contrib/ethosu/vela_api.py | 2 +-
python/tvm/relay/backend/executor_factory.py | 11 +-
python/tvm/relay/build_module.py | 16 ++-
python/tvm/relay/op/contrib/ethosu.py | 30 +++---
src/relay/backend/aot_executor_codegen.cc | 73 ++++++++++++--
src/relay/backend/build_module.cc | 6 ++
src/relay/backend/contrib/ethosu/compiler_attrs.cc | 6 +-
src/relay/backend/contrib/ethosu/preprocess.cc | 4 +-
src/relay/backend/contrib/ethosu/source_module.cc | 31 +++---
src/relay/backend/graph_executor_codegen.cc | 3 +
src/relay/backend/te_compiler.cc | 15 ++-
src/relay/backend/te_compiler.h | 7 ++
src/relay/transforms/partition_graph.cc | 1 +
.../contrib/ethosu/bare_metal/tvm_ethosu_runtime.c | 34 +++++++
.../contrib/ethosu/bare_metal/tvm_ethosu_runtime.h | 30 ++++++
src/runtime/meta_data.h | 6 +-
src/target/source/interface_c.cc | 39 +++++--
src/target/source/source_module.cc | 28 +++++-
src/target/target_kind.cc | 2 +
tests/cpp/target/source/interface_c_test.cc | 112 +++++++++++++++++++--
tests/micro/zephyr/test_zephyr_aot.py | 4 +-
tests/micro/zephyr/test_zephyr_armv7m.py | 2 +-
tests/python/contrib/test_ethosu/infra.py | 7 +-
.../contrib/test_ethosu/test_attr_passing.py | 6 +-
tests/python/contrib/test_ethosu/test_codegen.py | 18 ++--
tests/python/contrib/test_ethosu/test_legalize.py | 44 ++++----
.../python/contrib/test_ethosu/test_preprocess.py | 14 +--
tests/python/relay/aot/aot_test_utils.py | 61 ++++++++---
tests/python/relay/aot/corstone300.mk | 8 +-
39 files changed, 564 insertions(+), 152 deletions(-)
diff --git a/apps/microtvm/ethosu/Makefile b/apps/microtvm/ethosu/Makefile
index 3707999..d624571 100644
--- a/apps/microtvm/ethosu/Makefile
+++ b/apps/microtvm/ethosu/Makefile
@@ -35,7 +35,7 @@ RANLIB = arm-none-eabi-ranlib
PKG_CFLAGS = ${PKG_COMPILE_OPTS} \
-I${STANDALONE_CRT_PATH}/include \
-I${STANDALONE_CRT_PATH}/src/runtime/crt/include \
- -Iinclude \
+ -I${PWD}/include \
-I${CORSTONE_300_PATH} \
-I${ETHOSU_PATH}/core_driver/include \
-I${CMSIS_PATH}/Device/ARM/${ARM_CPU}/Include/ \
@@ -95,7 +95,7 @@ ${BUILD_DIR}/ethosu_core_driver/libethosu_core_driver.a:
$(QUIET)cd $(abspath $(BUILD_DIR)/ethosu_core_driver) && $(MAKE)
# Build demo application
-$(BUILD_DIR)/demo: src/demo.c $(BUILD_DIR)/stack_allocator.o
$(BUILD_DIR)/crt_backend_api.o ${BUILD_DIR}/libcodegen.a
${BUILD_DIR}/libcmsis_startup.a
${BUILD_DIR}/ethosu_core_driver/libethosu_core_driver.a ${BUILD_DIR}/libuart.a
+$(BUILD_DIR)/demo: src/demo.c src/tvm_ethosu_runtime.c
$(BUILD_DIR)/stack_allocator.o $(BUILD_DIR)/crt_backend_api.o
${BUILD_DIR}/libcodegen.a ${BUILD_DIR}/libcmsis_startup.a
${BUILD_DIR}/ethosu_core_driver/libethosu_core_driver.a ${BUILD_DIR}/libuart.a
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) $(PKG_CFLAGS) -o $@ $^ $(PKG_LDFLAGS)
diff --git a/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
new file mode 100644
index 0000000..06188ba
--- /dev/null
+++ b/apps/microtvm/ethosu/include/tvm_ethosu_runtime.h
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
+
+#include <ethosu_driver.h>
+#include <stddef.h>
+#include <stdint.h>
+
+int32_t TVMEthosULaunch(struct ethosu_driver* resource_handle, void* cms_data,
size_t cms_data_size,
+ uint64_t* base_addrs, size_t* base_addrs_size, int
num_tensors);
+
+#endif // TVM_RUNTIME_CONTRIB_ETHOSU_ETHOSU_RUNTIME_H_
diff --git a/apps/microtvm/ethosu/src/demo.c b/apps/microtvm/ethosu/src/demo.c
index 5ad4353..1ee13db 100644
--- a/apps/microtvm/ethosu/src/demo.c
+++ b/apps/microtvm/ethosu/src/demo.c
@@ -45,7 +45,12 @@ int main(int argc, char** argv) {
struct tvmgen_default_inputs inputs = {
.input = input,
};
- tvmgen_default_run(&inputs, &outputs);
+ struct ethosu_driver* driver = ethosu_reserve_driver();
+ struct tvmgen_default_devices devices = {
+ .ethos_u = driver,
+ };
+ tvmgen_default_run(&inputs, &outputs, &devices);
+ ethosu_release_driver(driver);
// Calculate index of max value
uint8_t max_value = 0;
diff --git a/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
new file mode 100644
index 0000000..6b7399b
--- /dev/null
+++ b/apps/microtvm/ethosu/src/tvm_ethosu_runtime.c
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "tvm_ethosu_runtime.h"
+
+#include <ethosu_driver.h>
+
+int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t
cms_data_size,
+ uint64_t* base_addrs, size_t* base_addrs_size, int
num_tensors) {
+ int32_t result =
+ ethosu_invoke(driver, cms_data, cms_data_size, base_addrs,
base_addrs_size, num_tensors);
+
+ // Map errors in invoke to TVM errors
+ if (result != 0) {
+ return -1;
+ }
+ return 0;
+}
diff --git a/python/tvm/driver/tvmc/composite_target.py
b/python/tvm/driver/tvmc/composite_target.py
index 0c04d2b..848af1e 100644
--- a/python/tvm/driver/tvmc/composite_target.py
+++ b/python/tvm/driver/tvmc/composite_target.py
@@ -65,7 +65,7 @@ REGISTERED_CODEGEN = {
"pass_pipeline": partition_for_ethosn78,
},
"ethos-u": {
- "config_key": "relay.ext.ethosu.options",
+ "config_key": "relay.ext.ethos-u.options",
"pass_pipeline": partition_for_ethosu,
},
"bnns": {
diff --git a/python/tvm/micro/model_library_format.py
b/python/tvm/micro/model_library_format.py
index f031ace..038cd0d 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -44,13 +44,13 @@ class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given
Module tree."""
-def generate_c_interface_header(module_name, inputs, outputs, include_path):
+def generate_c_interface_header(module_name, inputs, outputs, devices,
include_path):
"""Generate C Interface header to be included in MLF"""
mangled_name = to_c_variable_style(prefix_generated_name(module_name))
metadata_header = os.path.join(include_path, f"{mangled_name}.h")
interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
- interface_c_module = interface_c_create(module_name, inputs, outputs)
+ interface_c_module = interface_c_create(module_name, inputs, outputs,
devices)
with open(metadata_header, "w") as header_file:
header_file.write(interface_c_module.get_source())
@@ -318,7 +318,8 @@ def _export_graph_model_library_format(
include_path = codegen_dir / "host" / "include"
include_path.mkdir()
inputs, outputs = _get_inputs_and_outputs_from_module(mod)
- generate_c_interface_header(mod.libmod_name, inputs, outputs,
include_path)
+ devices = mod.get_devices()
+ generate_c_interface_header(mod.libmod_name, inputs, outputs, devices,
include_path)
parameters_dir = tempdir / "parameters"
parameters_dir.mkdir()
diff --git a/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py
b/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py
index ccf1039..22eb982 100644
--- a/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py
+++ b/python/tvm/relay/backend/contrib/ethosu/_ffi_api.py
@@ -17,4 +17,4 @@
"""FFI APIs for relay transformation passes."""
import tvm._ffi # type: ignore
-tvm._ffi._init_api("relay.ext.ethosu", __name__)
+tvm._ffi._init_api("relay.ext.ethos-u", __name__)
diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py
b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 827146f..5fe51b4 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -24,7 +24,7 @@ from tvm.relay.backend.contrib.ethosu import
tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu import util
-@tvm._ffi.register_func("relay.ext.ethosu")
+@tvm._ffi.register_func("relay.ext.ethos-u")
def ethosu_compiler(external_function):
"""The entry-point to a compile a external relay function of
NPU compatible operators to generated command stream.
@@ -38,11 +38,11 @@ def ethosu_compiler(external_function):
input_size = util.calculate_size_bytes(external_function.params[0])
output_size = util.calculate_size_bytes(external_function.body)
cmms, encoded_constants, scratch_size = _compile(external_function)
- ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethosu.create")
+ ethosu_runtime = tvm._ffi.get_global_func("runtime.module.ethos-u.create")
return ethosu_runtime(func_name, cmms, encoded_constants, scratch_size,
input_size, output_size)
-@tvm._ffi.register_func("relay.ext.ethosu.constant_updater")
+@tvm._ffi.register_func("relay.ext.ethos-u.constant_updater")
def constant_updater(expr, symbol): # pylint: disable=unused-argument
"""
The constant updater process happen after lowering in the core compiler.
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py
b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index d0d04ce..7a63351 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -128,7 +128,7 @@ class Conv2DRewriter(DFPatternCallback):
def __init__(self):
super().__init__(require_type=True)
- self.pattern = (wildcard().has_attr({"Composite":
"ethosu.qnn_conv2d"}))(wildcard())
+ self.pattern = (wildcard().has_attr({"Composite":
"ethos-u.qnn_conv2d"}))(wildcard())
def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map:
tvm.ir.container.Map
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py
b/python/tvm/relay/backend/contrib/ethosu/util.py
index 8afb6eb..370821a 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -154,7 +154,7 @@ def round_up(a: int, b: int) -> int:
def get_accelerator_config():
"""Get the variant of the accelerator to compile for"""
- compiler_attrs =
tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")()
+ compiler_attrs =
tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
return compiler_attrs.accelerator_config
diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py
b/python/tvm/relay/backend/contrib/ethosu/vela_api.py
index 69095e4..345d459 100644
--- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py
+++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py
@@ -381,7 +381,7 @@ def get_accelerator_config() -> vapi.NpuAccelerator:
"ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64,
"ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32,
}
- compiler_attrs =
tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")()
+ compiler_attrs =
tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
accel_config_str = compiler_attrs.accelerator_config
assert accel_config_str in npu_accel_str_map.keys(), f"{accel_config_str}
is not supported"
return npu_accel_str_map[accel_config_str]
diff --git a/python/tvm/relay/backend/executor_factory.py
b/python/tvm/relay/backend/executor_factory.py
index 7b147b4..db33c1b 100644
--- a/python/tvm/relay/backend/executor_factory.py
+++ b/python/tvm/relay/backend/executor_factory.py
@@ -85,9 +85,11 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
The parameters of module
function_metadata : Map of String to FunctionInfo
This holds a map function names to their information
+ devices : List[str]
+ List of devices used in the module
"""
- def __init__(self, ir_mod, target, libmod, libmod_name, params,
function_metadata):
+ def __init__(self, ir_mod, target, libmod, libmod_name, params,
function_metadata, devices):
self.ir_mod = ir_mod
self.target = target
self.lib = libmod
@@ -95,6 +97,10 @@ class AOTExecutorFactoryModule(ExecutorFactoryModule):
self.params = params
self.iter_cnt = 0
self.function_metadata = function_metadata
+ self.devices = devices
+
+ def get_devices(self):
+ return self.devices
def get_params(self):
return self.params
@@ -152,6 +158,9 @@ class GraphExecutorFactoryModule(ExecutorFactoryModule):
def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
return self.module.export_library(file_name, fcompile, addons,
**kwargs)
+ def get_devices(self):
+ return []
+
def get_params(self):
return self.params
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index e708635..38c9a40 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -101,6 +101,7 @@ class BuildModule(object):
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]
self._get_function_metadata = self.mod["get_function_metadata"]
+ self._get_devices = self.mod["get_devices"]
def build(
self, mod, target=None, target_host=None, params=None,
executor="graph", mod_name=None
@@ -236,6 +237,10 @@ class BuildModule(object):
each PrimFunc"""
return self._get_function_metadata()
+ def get_devices(self):
+ """Returns a list of devices configured in this module"""
+ return self._get_devices()
+
def get_params(self):
"""Return the updated weights."""
params = self._get_params_func()
@@ -370,14 +375,21 @@ def build(ir_mod, target=None, target_host=None,
params=None, mod_name="default"
mod=ir_mod, target=target, params=params, executor=executor,
mod_name=mod_name
)
func_metadata = bld_mod.get_function_metadata()
+ devices = bld_mod.get_devices()
if executor == "aot":
executor_factory = _executor_factory.AOTExecutorFactoryModule(
- ir_mod, target, runtime_mod, mod_name, params, func_metadata
+ ir_mod, target, runtime_mod, mod_name, params, func_metadata,
devices
)
elif executor == "graph":
executor_factory = _executor_factory.GraphExecutorFactoryModule(
- ir_mod, target, executor_config, runtime_mod, mod_name,
params, func_metadata
+ ir_mod,
+ target,
+ executor_config,
+ runtime_mod,
+ mod_name,
+ params,
+ func_metadata,
)
else:
assert False, "Executor " + executor + " not supported"
diff --git a/python/tvm/relay/op/contrib/ethosu.py
b/python/tvm/relay/op/contrib/ethosu.py
index 25538ca..a255f93 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -186,7 +186,7 @@ class QnnConv2DParams:
and extract quantization information of all the associated tensors.
"""
- composite_name = "ethosu.qnn_conv2d"
+ composite_name = "ethos-u.qnn_conv2d"
# The NPU only supports padding upto the numbers as follows
padding_bounds = [31, 31, 32, 32]
activation_map = {"clip": "CLIP"}
@@ -275,7 +275,7 @@ class QnnDepthwiseConv2DParams(QnnConv2DParams):
and extract the parameter information.
"""
- composite_name = "ethosu.depthwise_conv2d"
+ composite_name = "ethos-u.depthwise_conv2d"
# The hardware only supports padding upto the numbers as follows
padding_bounds = [31, 31, 32, 32]
@@ -343,11 +343,11 @@ def qnn_depthwise_conv2d_pattern() ->
tvm.relay.dataflow_pattern.DFPattern:
class MaxPool2DParams:
"""
- This class will parse a call to a ethosu.maxpool2d composite function
+ This class will parse a call to a ethos-u.maxpool2d composite function
and extract the parameter information.
"""
- composite_name = "ethosu.maxpool2d"
+ composite_name = "ethos-u.maxpool2d"
# The hardware only supports padding upto the numbers as follows
padding_bounds = [127, 127, 128, 128]
@@ -399,11 +399,11 @@ def qnn_maxpool2d_pattern() ->
tvm.relay.dataflow_pattern.DFPattern:
class AvgPool2DParams:
"""
- This class will parse a call to a ethosu.avgpool2d composite function
+ This class will parse a call to a ethos-u.avgpool2d composite function
and extract the parameter information.
"""
- composite_name = "ethosu.avgpool2d"
+ composite_name = "ethos-u.avgpool2d"
# The hardware only supports padding upto the numbers as follows
padding_bounds = [127, 127, 128, 128]
@@ -547,7 +547,7 @@ class AddParams(BinaryElementwiseParams):
and extract the parameter information.
"""
- composite_name = "ethosu.add"
+ composite_name = "ethos-u.add"
def __init__(self, func_body: Call):
BinaryElementwiseParams.__init__(self, func_body, "ADD", True)
@@ -589,7 +589,7 @@ class SubParams(BinaryElementwiseParams):
and extract the parameter information.
"""
- composite_name = "ethosu.sub"
+ composite_name = "ethos-u.sub"
def __init__(self, func_body: Call):
BinaryElementwiseParams.__init__(self, func_body, "SUB", True)
@@ -631,7 +631,7 @@ class MulParams(BinaryElementwiseParams):
and extract the parameter information.
"""
- composite_name = "ethosu.mul"
+ composite_name = "ethos-u.mul"
def __init__(self, func_body: Call):
BinaryElementwiseParams.__init__(self, func_body, "MUL", True)
@@ -673,7 +673,7 @@ class MinParams(BinaryElementwiseParams):
and extract the parameter information.
"""
- composite_name = "ethosu.min"
+ composite_name = "ethos-u.min"
def __init__(self, func_body: Call):
BinaryElementwiseParams.__init__(self, func_body, "MIN", False)
@@ -708,7 +708,7 @@ class MaxParams(BinaryElementwiseParams):
and extract the parameter information.
"""
- composite_name = "ethosu.max"
+ composite_name = "ethos-u.max"
def __init__(self, func_body: Call):
BinaryElementwiseParams.__init__(self, func_body, "MAX", False)
@@ -743,7 +743,7 @@ class ShlParams(BinaryElementwiseParams):
and extract the parameter information.
"""
- composite_name = "ethosu.shl"
+ composite_name = "ethos-u.shl"
def __init__(self, func_body: Call):
BinaryElementwiseParams.__init__(self, func_body, "SHL", False)
@@ -768,7 +768,7 @@ def shl_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return pattern
-@register_pattern_table("ethosu")
+@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern,
Callable]]:
return [
(
@@ -848,10 +848,10 @@ def partition_for_ethosu(
if params:
mod["main"] = bind_params_by_name(mod["main"], params)
- pattern = relay.op.contrib.get_pattern_table("ethosu")
+ pattern = relay.op.contrib.get_pattern_table("ethos-u")
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern)(mod)
- mod = relay.transform.AnnotateTarget("ethosu")(mod)
+ mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.PartitionGraph()(mod)
diff --git a/src/relay/backend/aot_executor_codegen.cc
b/src/relay/backend/aot_executor_codegen.cc
index 62bb715..c240ec8 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -18,8 +18,8 @@
*/
/*!
- * \file relay/backend/graph_codegen.cc
- * \brief Graph runtime codegen
+ * \file src/relay/backend/aot_executor_codegen.cc
+ * \brief AOT executor codegen
*/
#include <tvm/ir/module.h>
@@ -43,6 +43,7 @@
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../transforms/device_aware_visitors.h"
+#include "./name_transforms.h"
#include "./te_compiler.h"
#include "./utils.h"
@@ -316,7 +317,6 @@ class AOTExecutorCodegen : public MixedModeVisitor {
*/
void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
std::string func_name = call_lowered_props.lowered_func->name_hint;
-
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;
@@ -346,15 +346,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
calling_pattern = tvm::tir::builtin::call_extern();
}
- create_func_call_stmts.push_back(
- tir::Evaluate(tvm::tir::Call(DataType::Int(32), calling_pattern,
args)));
+ GlobalVar global_var = call_lowered_props.lowered_func;
+ bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
+ if (has_c_device_api_context) {
+ args.push_back(device_contexts_[global_var]);
+ }
+
+ tir::Evaluate func_call(tvm::tir::Call(DataType::Int(32), calling_pattern,
args));
+ create_func_call_stmts.push_back(func_call);
tir::Stmt body = tir::SeqStmt(create_func_call_stmts);
stmts_.push_back(body);
}
/*!
- * brief Copy a variable to the output. This function is mainly used in edge
cases
+ * \brief Copy a variable to the output. This function is mainly used in
edge cases
* when we want to return an input or a parameter.
* TODO(giuseros): we should try to avoid unnecessary copy to the output,
e.g., in a
* copy-on-write fashion.
@@ -387,6 +393,39 @@ class AOTExecutorCodegen : public MixedModeVisitor {
stmts_.push_back(tir::LetStmt(tmp1, tostore, copy));
}
+ /*
+ * \brief Collects device context variables for passing to operators
+ */
+ void CollectDeviceVariables(const Map<GlobalVar, String>& device_contexts) {
+ Map<TargetKind, tir::Var> target_contexts;
+ TargetKindAttrMap<Bool> target_attr_map =
tvm::TargetKind::GetAttrMap<Bool>("use_device_api");
+
+ for (const auto& it : device_contexts) {
+ const GlobalVar& global_var = it.first;
+ const std::string device_context_name = it.second;
+
+ Optional<TargetKind> target_kind =
tvm::TargetKind::Get(device_context_name);
+ if (!target_kind || !target_attr_map.count(target_kind.value())) {
+ return;
+ }
+ if (target_attr_map[target_kind.value()]) {
+ std::string context_name = SanitizeName(device_context_name);
+ tir::Var device_context_var("device_context_" + context_name,
DataType::Handle());
+
+ auto pair = target_contexts.find(target_kind.value());
+ if (pair != target_contexts.end()) {
+ device_context_var = (*pair).second;
+ } else {
+ main_signature_.push_back(device_context_var);
+ devices_.push_back(context_name);
+ target_contexts.Set(target_kind.value(), device_context_var);
+ }
+
+ device_contexts_.Set(global_var, device_context_var);
+ }
+ }
+ }
+
/*!
* Utility function to string together different arguments
*/
@@ -558,6 +597,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
runtime::Module* mod_;
/*! \brief list of input expressions (i.e., variable passed by the user) */
std::vector<Var> input_vars_;
+ /*! \brief list of device contexts used */
+ std::vector<String> devices_;
+ /*! \brief map of GlobalVars to C Device API contexts */
+ Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief input and output variables belonging to the main function
signature */
Array<tir::Var> main_signature_;
/*! \brief target device */
@@ -671,6 +714,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
main_signature_.push_back(tir::Var("output", DataType::Handle()));
}
+ CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar,
String>>("device_contexts").value());
VisitExpr(lowered_main_func->body);
// Create the runner function. Please note that the function is not legal
yet
@@ -734,11 +778,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::vector<String> input_var_names(input_vars_.size());
std::transform(input_vars_.begin(), input_vars_.end(),
input_var_names.begin(),
[](Var input_var) -> String { return
input_var->name_hint(); });
- ret.metadata =
- runtime::Metadata(input_var_names, return_sid_.size(),
runtime::kTvmExecutorAot, mod_name);
+
+ ret.metadata = runtime::Metadata(input_var_names, devices_,
return_sid_.size(),
+ runtime::kTvmExecutorAot, mod_name);
return ret;
}
-};
+
+ /*!
+ * \brief Get list of devices found
+ * \return List of devices
+ */
+ Array<String> ListDevices() { return devices_; }
+}; // namespace backend
class AOTExecutorCodegenModule : public runtime::ModuleNode {
public:
@@ -781,6 +832,10 @@ class AOTExecutorCodegenModule : public
runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.function_metadata;
});
+ } else if (name == "get_devices") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = this->codegen_->ListDevices();
+ });
} else if (name == "get_metadata") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
output_.metadata; });
diff --git a/src/relay/backend/build_module.cc
b/src/relay/backend/build_module.cc
index e074ddc..24706fb 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -96,6 +96,8 @@ struct ExecutorCodegen {
return CallFunc<Map<Target, IRModule>>("get_irmodule", nullptr);
}
+ Array<String> ListDevices() { return CallFunc<Array<String>>("get_devices");
}
+
runtime::Metadata GetMetadata() { return
CallFunc<runtime::Metadata>("get_metadata"); }
virtual ~ExecutorCodegen() {}
@@ -195,6 +197,10 @@ class RelayBuildModule : public runtime::ModuleNode {
this->SetParam(kv.first, kv.second->data);
}
});
+ } else if (name == "get_devices") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = this->executor_codegen_->ListDevices();
+ });
} else if (name == "get_irmodule") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->executor_codegen_->GetIRModule();
diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc
b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
index 6a87d11..5795db2 100644
--- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc
+++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
@@ -55,17 +55,17 @@ class EthosUCompilerConfig : public Attrs {
};
TVM_REGISTER_NODE_TYPE(EthosUCompilerConfigNode);
-TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.ethosu.options",
EthosUCompilerConfig);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.ethos-u.options",
EthosUCompilerConfig);
auto GetCompilerAttrs() {
auto ctx = transform::PassContext::Current();
- auto cfg = ctx->GetConfig<EthosUCompilerConfig>("relay.ext.ethosu.options");
+ auto cfg = ctx->GetConfig<EthosUCompilerConfig>("relay.ext.ethos-u.options");
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<EthosUCompilerConfig>();
}
return cfg;
}
-TVM_REGISTER_GLOBAL("relay.ext.ethosu.get_compiler_attrs").set_body_typed(GetCompilerAttrs);
+TVM_REGISTER_GLOBAL("relay.ext.ethos-u.get_compiler_attrs").set_body_typed(GetCompilerAttrs);
} // namespace ethosu
} // namespace contrib
diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc
b/src/relay/backend/contrib/ethosu/preprocess.cc
index ac52844..571a56a 100644
--- a/src/relay/backend/contrib/ethosu/preprocess.cc
+++ b/src/relay/backend/contrib/ethosu/preprocess.cc
@@ -198,7 +198,7 @@ class ExternalFuncIOHandler : public ExprRewriter {
return post;
}
if (auto compiler = func->GetAttr<String>(attr::kCompiler)) {
- if (compiler == "ethosu") {
+ if (compiler == "ethos-u") {
auto ext_input = std::move(post_call->args[0]);
auto arg_dtype =
Downcast<TensorType>(post_call->args[0]->checked_type())->dtype;
if (post_call->args.size() > 1) {
@@ -261,7 +261,7 @@ Pass PreprocessExternalFuncIO() {
return Sequential({preprocess_pass, InferType()});
}
-TVM_REGISTER_GLOBAL("relay.ext.ethosu.PreprocessExternalFuncIO")
+TVM_REGISTER_GLOBAL("relay.ext.ethos-u.PreprocessExternalFuncIO")
.set_body_typed(transform::PreprocessExternalFuncIO);
} // namespace transform
diff --git a/src/relay/backend/contrib/ethosu/source_module.cc
b/src/relay/backend/contrib/ethosu/source_module.cc
index e3f48bc..18a6951 100644
--- a/src/relay/backend/contrib/ethosu/source_module.cc
+++ b/src/relay/backend/contrib/ethosu/source_module.cc
@@ -183,7 +183,7 @@ class EthosUModuleNode : public ModuleNode {
*/
void PrintRuntimeFunctionHeader(std::stringstream& ss, std::string
func_name) {
ss << "TVM_DLL int32_t ";
- ss << func_name << "(void* input, void* output) {\n";
+ ss << func_name << "(void* input, void* output, void* resource_handle)
{\n";
}
/*!
@@ -219,7 +219,7 @@ class EthosUModuleNode : public ModuleNode {
ss << "#include <stdio.h>\n";
ss << "#include <stdlib.h>\n";
ss << "#include <tvm/runtime/crt/module.h>\n";
- ss << "#include <ethosu_driver.h>\n";
+ ss << "#include <tvm_ethosu_runtime.h>\n";
ss << "\n";
size_t weights_size = (weights_bias_hex.size() / 2);
ss << "static const size_t weights_size = " <<
std::to_string(weights_size) << ";\n";
@@ -243,7 +243,7 @@ class EthosUModuleNode : public ModuleNode {
PrintExternCPrefix(ss);
ss << "static int32_t " << func_no_dashes + "_(int8_t* in0, "
- << "size_t in0_size, int8_t* out0, size_t out0_size) {\n";
+ << "size_t in0_size, int8_t* out0, size_t out0_size, void*
resource_handle) {\n";
ss << " int num_tensors = 5;\n";
ss << " void* cms_data = (void*)(cms_data_data);\n";
ss << " int64_t device_type = kDLCPU;\n";
@@ -263,30 +263,25 @@ class EthosUModuleNode : public ModuleNode {
ss << SetBaseAddress(3, "in0");
ss << SetBaseAddress(4, "out0");
ss << "\n";
- ss << " struct ethosu_driver *drv = ethosu_reserve_driver();\n";
- ss << " int32_t result = ethosu_invoke(drv, cms_data, cms_data_size,
base_addrs, "
- "base_addrs_size, "
- "num_tensors);\n";
- ss << " ethosu_release_driver(drv);\n";
+ ss << " int32_t result = TVMEthosULaunch(resource_handle, cms_data,
cms_data_size, "
+ "base_addrs, base_addrs_size, num_tensors);\n";
if (scratch_size > 0) {
ss << " TVMBackendFreeWorkspace(device_type, device_id, scratch);\n";
}
- ss << " if (result != 0) {\n";
- ss << " return -1;\n";
- ss << " } else {\n";
- ss << " return 0;\n";
- ss << " }\n";
+ ss << " return result;\n";
ss << "}\n";
ss << "\n";
PrintExternCPostfix(ss);
ss << "\n";
PrintExternCPrefix(ss);
ss << "// Wrapper function is provided to allow for easier debugging\n";
- ss << "inline static int32_t " + func_no_dashes + "_wrapper_(void* input,
void* output) {\n";
+ ss << "inline static int32_t " + func_no_dashes +
+ "_wrapper_(void* input, void* output, void* resource_handle)
{\n";
ss << " size_t input_data_size = " << input_size << ";\n";
ss << " size_t output_data_size = " << output_size << ";\n";
ss << " return " + func_no_dashes +
- "_((int8_t*)input, input_data_size, (int8_t*)output,
output_data_size);\n";
+ "_((int8_t*)input, input_data_size, (int8_t*)output,
output_data_size, " +
+ "resource_handle);\n";
ss << "}\n";
PrintExternCPostfix(ss);
ss << "\n";
@@ -294,7 +289,7 @@ class EthosUModuleNode : public ModuleNode {
PrintRuntimeFunctionHeader(ss, func_name);
EnterScope();
PrintIndents(ss);
- ss << "return " << func_no_dashes << "_wrapper_(input, output);\n";
+ ss << "return " << func_no_dashes << "_wrapper_(input, output,
resource_handle);\n";
ExitScope();
ss << "}\n";
PrintExternCPostfix(ss);
@@ -317,14 +312,14 @@ inline EthosUModuleNode* EthosUModule::operator->() {
return static_cast<EthosUModuleNode*>(get_mutable());
}
-TVM_REGISTER_GLOBAL("runtime.module.ethosu.create")
+TVM_REGISTER_GLOBAL("runtime.module.ethos-u.create")
.set_body_typed([](String func_name, String cmms_hex, String
weights_bias_hex,
Integer scratch_size, Integer input_size, Integer
output_size) {
return EthosUModuleNode::Create(func_name, cmms_hex, weights_bias_hex,
scratch_size,
input_size, output_size);
});
-TVM_REGISTER_GLOBAL("runtime.module.ethosu.getcs").set_body_typed([](EthosUModule
mod) {
+TVM_REGISTER_GLOBAL("runtime.module.ethos-u.getcs").set_body_typed([](EthosUModule
mod) {
return mod->GetCS();
});
diff --git a/src/relay/backend/graph_executor_codegen.cc
b/src/relay/backend/graph_executor_codegen.cc
index 1bab2c9..1456f7e 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -685,6 +685,9 @@ class GraphExecutorCodegenModule : public
runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.external_mods;
});
+ } else if (name == "get_devices") {
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
Array<String>(); });
} else if (name == "get_metadata") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv =
this->output_.metadata; });
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 99b7d64..69e6be6 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -178,6 +178,8 @@ class TECompilerImpl : public TECompilerNode {
return ret;
}
+ Map<GlobalVar, String> GetDeviceContexts() { return device_contexts_; }
+
void Clear() final { cache_.clear(); }
// List all items in the cache.
@@ -227,6 +229,9 @@ class TECompilerImpl : public TECompilerNode {
ir_module->Add(global_var, key->source_func);
value->cached_func = CachedFunc(target, global_var, {}, {},
te::Schedule{nullptr},
tir::PrimFunc{nullptr}, {}, ir_module);
+ // Collect these here as it's removed in LowerExternalFunctions()
+ std::string codegen_name =
key->source_func->GetAttr<String>(attr::kCompiler).value();
+ device_contexts_.Set(global_var, codegen_name);
return value;
}
@@ -313,6 +318,8 @@ class TECompilerImpl : public TECompilerNode {
std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
/*! \brief the cache key of the function that is being lowered currently*/
CCacheKey cur_ccache_key_;
+ /*! \brief Map of GlobalVar to C Device API context names */
+ Map<GlobalVar, String> device_contexts_;
};
TECompiler::TECompiler() {
@@ -920,8 +927,12 @@ IRModule LowerTE(const IRModule& module, const String&
module_name,
// Copy the lowered functions into the return module
updated_module->Update(compiler->GetLoweredFunctions());
- // Annotate the module with the external modules and function info
- updated_module = WithAttr(updated_module, "external_mods",
compiler->LowerExternalFunctions());
+ // Annotate the module with C Device API context mapping, the external
modules and function info
+ // this is until we have Target's annotated for the C Device API
+ // TODO(Mousius) - Remove "device_contexts" as soon as we have the graph
annotated properly with
+ // Target's
+ updated_module = WithAttrs(updated_module, {{"external_mods",
compiler->LowerExternalFunctions()},
+ {"device_contexts",
compiler->GetDeviceContexts()}});
return updated_module;
}
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index da7333d..b5d5b50 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -109,6 +109,13 @@ class TECompilerNode : public Object {
*/
virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;
+ /*!
+ * \brief Get C Device API context mapping
+ * \return Map of GlobalVar to associated C Device API context name (either
Target or kCompiler
+ * annotated)
+ */
+ virtual Map<GlobalVar, String> GetDeviceContexts() = 0;
+
virtual std::unordered_map<std::string, int> GetOpWeights() = 0;
/*! \brief clear the cache. */
diff --git a/src/relay/transforms/partition_graph.cc
b/src/relay/transforms/partition_graph.cc
index 6e52cbf..99799fd 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -49,6 +49,7 @@
namespace tvm {
namespace relay {
+
namespace partitioning {
/*! \brief This struct maintains the required metadata for a region to
generate a corresponding
diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
new file mode 100644
index 0000000..6b7399b
--- /dev/null
+++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "tvm_ethosu_runtime.h"
+
+#include <ethosu_driver.h>
+
+int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t
cms_data_size,
+ uint64_t* base_addrs, size_t* base_addrs_size, int
num_tensors) {
+ int32_t result =
+ ethosu_invoke(driver, cms_data, cms_data_size, base_addrs,
base_addrs_size, num_tensors);
+
+ // Map errors in invoke to TVM errors
+ if (result != 0) {
+ return -1;
+ }
+ return 0;
+}
diff --git a/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
new file mode 100644
index 0000000..d62afc4
--- /dev/null
+++ b/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.h
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
+
+#include <ethosu_driver.h>
+#include <stddef.h>
+#include <stdint.h>
+
+int32_t TVMEthosULaunch(struct ethosu_driver* driver, void* cms_data, size_t
cms_data_size,
+ uint64_t* base_addrs, size_t* base_addrs_size, int
num_tensors);
+
+#endif // TVM_RUNTIME_CONTRIB_ETHOSU_BARE_METAL_TVM_ETHOSU_RUNTIME_H_
diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h
index 66d9a44..fd612b0 100644
--- a/src/runtime/meta_data.h
+++ b/src/runtime/meta_data.h
@@ -58,6 +58,8 @@ class MetadataNode : public Object {
Array<String> inputs;
/*! \brief number of outputs of the main function */
int num_outputs = 1;
+ /*! \brief device contexts information for the main function */
+ Array<String> devices;
/*! \brief the executor to be used to run the model */
String executor = kTvmExecutorGraph;
@@ -73,9 +75,11 @@ class MetadataNode : public Object {
*/
class Metadata : public ObjectRef {
public:
- TVM_DLL Metadata(Array<String> inputs, int num_outputs, String executor,
String mod_name) {
+ TVM_DLL Metadata(Array<String> inputs, Array<String> devices, int
num_outputs, String executor,
+ String mod_name) {
auto n = make_object<MetadataNode>();
n->inputs = inputs;
+ n->devices = devices;
n->num_outputs = num_outputs;
n->executor = executor;
n->mod_name = mod_name;
diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc
index 4089ccc..fd11164 100644
--- a/src/target/source/interface_c.cc
+++ b/src/target/source/interface_c.cc
@@ -40,8 +40,9 @@ using namespace tvm::relay::backend;
class InterfaceCNode : public runtime::ModuleNode {
public:
- InterfaceCNode(std::string module_name, Array<String> inputs, Array<String>
outputs)
- : module_name_(module_name), inputs_(inputs), outputs_(outputs) {}
+ InterfaceCNode(std::string module_name, Array<String> inputs, Array<String>
outputs,
+ Array<String> devices)
+ : module_name_(module_name), inputs_(inputs), outputs_(outputs),
devices_(devices) {}
const char* type_key() const { return "h"; }
std::string GetSource(const std::string& format) final {
@@ -52,6 +53,12 @@ class InterfaceCNode : public runtime::ModuleNode {
EmitStruct(code, "inputs", inputs_);
EmitBrief(code, "Output tensor pointers");
EmitStruct(code, "outputs", outputs_);
+
+ if (!devices_.empty()) {
+ EmitBrief(code, "Device context pointers");
+ EmitStruct(code, "devices", devices_);
+ }
+
EmitRunFunction(code);
EmitLowerHeaderGuard(code);
@@ -108,26 +115,40 @@ class InterfaceCNode : public runtime::ModuleNode {
std::string run_function =
ToCVariableStyle(PrefixGeneratedName({module_name_, "run"}));
std::string inputs_struct =
ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"}));
std::string outputs_struct =
ToCVariableStyle(PrefixGeneratedName({module_name_, "outputs"}));
+ std::string devices_struct =
ToCVariableStyle(PrefixGeneratedName({module_name_, "devices"}));
code_stream << "/*!\n"
<< " * \\brief entrypoint function for TVM module \"" <<
module_name_ << "\"\n"
<< " * \\param inputs Input tensors for the module \n"
- << " * \\param outputs Output tensors for the module \n"
- << " */\n"
+ << " * \\param outputs Output tensors for the module \n";
+
+ if (!devices_.empty()) {
+ code_stream << " * \\param devices Device context pointers for the
module \n";
+ }
+
+ code_stream << " */\n"
<< "int32_t " << run_function << "(\n"
- << " struct " << inputs_struct << "* inputs,\n"
- << " struct " << outputs_struct << "* outputs\n"
- << ");\n";
+ << " struct " << inputs_struct << "* inputs,\n";
+
+ if (!devices_.empty()) {
+ code_stream << " struct " << outputs_struct << "* outputs,\n";
+ code_stream << " struct " << devices_struct << "* devices\n";
+ } else {
+ code_stream << " struct " << outputs_struct << "* outputs\n";
+ }
+
+ code_stream << ");\n";
}
std::string module_name_;
Array<String> inputs_;
Array<String> outputs_;
+ Array<String> devices_;
};
runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
- Array<String> outputs) {
- auto n = make_object<InterfaceCNode>(module_name, inputs, outputs);
+ Array<String> outputs, Array<String> devices)
{
+ auto n = make_object<InterfaceCNode>(module_name, inputs, outputs, devices);
return runtime::Module(n);
}
diff --git a/src/target/source/source_module.cc
b/src/target/source/source_module.cc
index 9b93b07..21f82c3 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -240,7 +240,8 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
const std::string& mod_name) {
code_ << "#include <" << mod_name << ".h>\n";
code_ << "TVM_DLL int32_t " << run_func << "(";
- unsigned int total_args = (metadata_->inputs.size() +
metadata_->num_outputs);
+ unsigned int total_args =
+ (metadata_->inputs.size() + metadata_->devices.size() +
metadata_->num_outputs);
for (unsigned int i = 0; i < total_args; ++i) {
code_ << "void* arg" << i;
if (i + 1 != total_args) {
@@ -249,10 +250,16 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
}
code_ << ");\n";
code_ << "int32_t " << entrypoint_name << "(";
- code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "*
inputs,"
- << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "*
outputs"
- << ") {";
- code_ << "return " << run_func << "(";
+ code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "*
inputs,";
+ if (!metadata_->devices.empty()) {
+ code_ << "struct " << runtime::get_name_mangled(mod_name, "outputs") <<
"* outputs,";
+ code_ << "struct " << runtime::get_name_mangled(mod_name, "devices") <<
"* devices";
+ } else {
+ code_ << "struct " << runtime::get_name_mangled(mod_name, "outputs") <<
"* outputs";
+ }
+
+ code_ << ") {"
+ << "return " << run_func << "(";
for (const auto& input : metadata_->inputs) {
std::string sanitised_input = input;
std::replace_if(sanitised_input.begin(), sanitised_input.end(),
isNotAlnum, '_');
@@ -268,6 +275,17 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
}
}
}
+
+ if (!metadata_->devices.empty()) {
+ code_ << ",";
+ for (const String& device : metadata_->devices) {
+ code_ << "devices->" << device;
+ if (device != metadata_->devices.back()) {
+ code_ << ",";
+ }
+ }
+ }
+
code_ << ");\n";
code_ << "}\n";
}
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 4403af2..9f7bc56 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -373,6 +373,8 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
TVM_REGISTER_TARGET_KIND("composite",
kDLCPU).add_attr_option<Array<Target>>("devices");
+TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU).set_attr<Bool>("use_device_api",
Bool(true));
+
/********** Registry **********/
TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
diff --git a/tests/cpp/target/source/interface_c_test.cc
b/tests/cpp/target/source/interface_c_test.cc
index c53af43..7bfea17 100644
--- a/tests/cpp/target/source/interface_c_test.cc
+++ b/tests/cpp/target/source/interface_c_test.cc
@@ -29,7 +29,7 @@ namespace tvm {
namespace codegen {
runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
- Array<String> outputs);
+ Array<String> outputs, Array<String> devices);
namespace {
@@ -49,7 +49,7 @@ TEST(InterfaceAPI, ContainsHeaderGuards) {
<< "#endif\n\n"
<< "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n";
- runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"});
+ runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str()));
@@ -69,7 +69,29 @@ TEST(InterfaceAPI, ContainsRunFunction) {
<< " struct tvmgen_ultimate_cat_spotter_outputs* outputs\n"
<< ");\n";
- runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"});
+ runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"}, {});
+ std::string header_source = test_module->GetSource();
+
+ ASSERT_THAT(header_source, HasSubstr(run_function.str()));
+}
+
+TEST(InterfaceAPI, ContainsRunFunctionWithDevices) {
+ std::stringstream run_function;
+
+ run_function << "/*!\n"
+ << " * \\brief entrypoint function for TVM module
\"ultimate_cat_spotter\"\n"
+ << " * \\param inputs Input tensors for the module \n"
+ << " * \\param outputs Output tensors for the module \n"
+ << " * \\param devices Device context pointers for the module
\n"
+ << " */\n"
+ << "int32_t tvmgen_ultimate_cat_spotter_run(\n"
+ << " struct tvmgen_ultimate_cat_spotter_inputs* inputs,\n"
+ << " struct tvmgen_ultimate_cat_spotter_outputs* outputs,\n"
+ << " struct tvmgen_ultimate_cat_spotter_devices* devices\n"
+ << ");\n";
+
+ runtime::Module test_module =
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
{"device"});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(run_function.str()));
@@ -85,7 +107,7 @@ TEST(InterfaceAPI, ContainsInputStructSingle) {
<< " void* input;\n"
<< "};\n\n";
- runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"});
+ runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
@@ -100,7 +122,7 @@ TEST(InterfaceAPI, ContainsInputStructMany) {
<< "};\n\n";
runtime::Module test_module =
- InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"},
{"output"});
+ InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"},
{"output"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
@@ -115,7 +137,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) {
<< "};\n\n";
runtime::Module test_module =
- InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"},
{"output"});
+ InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"},
{"output"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
@@ -123,7 +145,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) {
TEST(InterfaceAPI, ContainsInputStructClash) {
runtime::Module test_module =
- InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"},
{"output"});
+ InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"},
{"output"}, {});
ASSERT_THROW(test_module->GetSource(), InternalError);
}
@@ -137,7 +159,7 @@ TEST(InterfaceAPI, ContainsOutputStructSingle) {
<< " void* output;\n"
<< "};\n\n";
- runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"});
+ runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
@@ -152,7 +174,7 @@ TEST(InterfaceAPI, ContainsOutputStructMany) {
<< "};\n\n";
runtime::Module test_module =
- InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1",
"output2"});
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1",
"output2"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
@@ -167,7 +189,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) {
<< "};\n\n";
runtime::Module test_module =
- InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1",
"output-2"});
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1",
"output-2"}, {});
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
@@ -175,7 +197,75 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) {
TEST(InterfaceAPI, ContainsOutputStructClash) {
runtime::Module test_module =
- InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+",
"output-"});
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+",
"output-"}, {});
+ ASSERT_THROW(test_module->GetSource(), InternalError);
+}
+
+TEST(InterfaceAPI, NoDeviceAPIStructIfNoDevices) {
+ std::stringstream device_struct;
+
+ device_struct << "/*!\n"
+ << " * \\brief Device context pointers for TVM module
\"ultimate_cat_spotter\" \n"
+ << " */\n"
+ << "struct tvmgen_ultimate_cat_spotter_devices {\n"
+ << "};\n\n";
+
+ runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter",
{"input"}, {"output"}, {});
+ std::string header_source = test_module->GetSource();
+
+ ASSERT_THAT(header_source, Not(HasSubstr(device_struct.str())));
+}
+
+TEST(InterfaceAPI, ContainsDeviceStructSingle) {
+ std::stringstream device_struct;
+
+ device_struct << "/*!\n"
+ << " * \\brief Device context pointers for TVM module
\"ultimate_cat_spotter\" \n"
+ << " */\n"
+ << "struct tvmgen_ultimate_cat_spotter_devices {\n"
+ << " void* device;\n"
+ << "};\n\n";
+
+ runtime::Module test_module =
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
{"device"});
+ std::string header_source = test_module->GetSource();
+
+ ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
+}
+
+TEST(InterfaceAPI, ContainsDeviceStructMany) {
+ std::stringstream device_struct;
+
+ device_struct << "struct tvmgen_ultimate_cat_spotter_devices {\n"
+ << " void* device1;\n"
+ << " void* device2;\n"
+ << "};\n\n";
+
+ runtime::Module test_module =
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
{"device1", "device2"});
+ std::string header_source = test_module->GetSource();
+
+ ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
+}
+
+TEST(InterfaceAPI, ContainsDeviceStructSanitised) {
+ std::stringstream device_struct;
+
+ device_struct << "struct tvmgen_ultimate_cat_spotter_devices {\n"
+ << " void* device_1;\n"
+ << " void* device_2;\n"
+ << "};\n\n";
+
+ runtime::Module test_module =
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
{"device+1", "device+2"});
+ std::string header_source = test_module->GetSource();
+
+ ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
+}
+
+TEST(InterfaceAPI, ContainsDeviceStructClash) {
+ runtime::Module test_module =
+ InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
{"device+", "device-"});
ASSERT_THROW(test_module->GetSource(), InternalError);
}
diff --git a/tests/micro/zephyr/test_zephyr_aot.py
b/tests/micro/zephyr/test_zephyr_aot.py
index 7cd32f4..4324570 100644
--- a/tests/micro/zephyr/test_zephyr_aot.py
+++ b/tests/micro/zephyr/test_zephyr_aot.py
@@ -89,7 +89,7 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug):
model_files_path = os.path.join(tar_temp_dir, "include")
os.mkdir(model_files_path)
header_path = generate_c_interface_header(
- lowered.libmod_name, ["input_1"], ["output"],
model_files_path
+ lowered.libmod_name, ["input_1"], ["output"], [],
model_files_path
)
tf.add(header_path, arcname=os.path.relpath(header_path,
tar_temp_dir))
@@ -150,7 +150,7 @@ def test_qemu_make_fail(temp_dir, board, west_cmd,
tvm_debug):
model_files_path = os.path.join(tar_temp_dir, "include")
os.mkdir(model_files_path)
header_path = generate_c_interface_header(
- lowered.libmod_name, ["input_1"], ["output"],
model_files_path
+ lowered.libmod_name, ["input_1"], ["output"], [],
model_files_path
)
tf.add(header_path, arcname=os.path.relpath(header_path,
tar_temp_dir))
test_utils.create_header_file(
diff --git a/tests/micro/zephyr/test_zephyr_armv7m.py
b/tests/micro/zephyr/test_zephyr_armv7m.py
index 2366bad..9364b54 100644
--- a/tests/micro/zephyr/test_zephyr_armv7m.py
+++ b/tests/micro/zephyr/test_zephyr_armv7m.py
@@ -112,7 +112,7 @@ def _generate_project(temp_dir, board, west_cmd, lowered,
build_config, sample,
test_utils.loadCMSIS(model_files_path)
tf.add(model_files_path,
arcname=os.path.relpath(model_files_path, tar_temp_dir))
header_path = generate_c_interface_header(
- lowered.libmod_name, ["input_1"], ["output"],
model_files_path
+ lowered.libmod_name, ["input_1"], ["output"], [],
model_files_path
)
tf.add(header_path, arcname=os.path.relpath(header_path,
tar_temp_dir))
diff --git a/tests/python/contrib/test_ethosu/infra.py
b/tests/python/contrib/test_ethosu/infra.py
index 17d3fad..d37d915 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -198,11 +198,16 @@ def _create_test_runner(accel):
prologue="""
uart_init();
EthosuInit();
+
+ struct ethosu_driver* ethos_u = ethosu_reserve_driver();
+ """,
+ epilogue="""
+ ethosu_release_driver(ethos_u);
""",
includes=["uart.h", "ethosu_55.h", "ethosu_mod.h", "hard_fault.h"],
parameters={"ETHOSU_TEST_ROOT": test_root, "NPU_VARIANT": ethosu_macs},
pass_config={
- "relay.ext.ethosu.options": {
+ "relay.ext.ethos-u.options": {
"accelerator_config": accel,
}
},
diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py
b/tests/python/contrib/test_ethosu/test_attr_passing.py
index 6b99a5c..5aab39a 100644
--- a/tests/python/contrib/test_ethosu/test_attr_passing.py
+++ b/tests/python/contrib/test_ethosu/test_attr_passing.py
@@ -26,9 +26,9 @@ def test_compiler_attr():
config = {
"accelerator_config": "ethos-u55-32",
}
- with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.ethosu.options": config}):
+ with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.ethos-u.options": config}):
with tvm.target.Target("c -device=micro_dev"):
- compiler_attrs =
tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")()
+ compiler_attrs =
tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
accel_config_str = compiler_attrs.accelerator_config
assert accel_config_str == config["accelerator_config"]
@@ -39,7 +39,7 @@ def test_compiler_attr_default():
}
with tvm.transform.PassContext(opt_level=3):
with tvm.target.Target("c -device=micro_dev"):
- compiler_attrs =
tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")()
+ compiler_attrs =
tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")()
accel_config_str = compiler_attrs.accelerator_config
assert accel_config_str == default_config["accelerator_config"]
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py
b/tests/python/contrib/test_ethosu/test_codegen.py
index a5686c8..e29bfa2 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -159,7 +159,7 @@ def test_ethosu_conv2d(accel_type):
ethosu_module = imported_modules[0]
# Verify generated C source
- get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
infra.print_payload(cmms)
@@ -246,7 +246,7 @@ def test_tflite_depthwise_conv2d(
ethosu_module = imported_modules[0]
# Verify generated C source
- get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
@@ -333,7 +333,7 @@ def test_ethosu_pooling(
ethosu_module = imported_modules[0]
# Verify generated C source
- get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
@@ -427,7 +427,7 @@ def test_ethosu_binary_elementwise(
ethosu_module = imported_modules[0]
# Verify generated C source
- get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
@@ -484,7 +484,7 @@ def test_ethosu_left_shift_binary_elemwise(
ethosu_module = imported_modules[0]
# Verify generated C source
- get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
@@ -519,12 +519,12 @@ def test_ethosu_right_shift_binary_elemwise(
ifm, ifm2, ifm_shape[3], ifm2_shape[3], "SHR", ofm_dtype,
reversed_operands
)
- glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0")
+ glb_ethosu = relay.GlobalVar("tvmgen_default_ethos_u_main_0")
func = (
relay.Function([ifms], shr_op)
.with_attr("Inline", 1)
- .with_attr("Compiler", "ethosu")
- .with_attr("global_symbol", "tvmgen_default_ethosu_main_0")
+ .with_attr("Compiler", "ethos-u")
+ .with_attr("global_symbol", "tvmgen_default_ethos_u_main_0")
.with_attr("Primitive", 1)
)
mod = tvm.IRModule()
@@ -583,7 +583,7 @@ def test_ethosu_right_shift_binary_elemwise(
ethosu_module = imported_modules[0]
# Verify generated C source
- get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs")
+ get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs")
cmms = get_cs(ethosu_module)
cmms = bytes.fromhex(cmms)
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py
b/tests/python/contrib/test_ethosu/test_legalize.py
index 2a84a23..166a965 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -40,7 +40,7 @@ def partition_ethosu_by_table(mod, pattern_table):
wouldn't attempt to offload an operator without full stack support."""
mod = relay.transform.InferType()(mod)
mod = relay.transform.MergeComposite(pattern_table)(mod)
- mod = relay.transform.AnnotateTarget("ethosu")(mod)
+ mod = relay.transform.AnnotateTarget("ethos-u")(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod = relay.transform.PartitionGraph()(mod)
@@ -59,7 +59,7 @@ def test_split_indices_legalize():
def expected_mod_axis1():
expected_ir_string = """
#[version = "0.0.5"]
- def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32])
-> (Tensor[(1, 5, 50, 3), float32],\
+ def @tvmgen_default_ethos_u_main_0(%x: Tensor[(1, 50, 50, 3),
float32]) -> (Tensor[(1, 5, 50, 3), float32],\
Tensor[(1, 15,
50, 3), float32],\
Tensor[(1, 25,
50, 3), float32],\
Tensor[(1, 5,
50, 3), float32]) {
@@ -80,7 +80,7 @@ def test_split_indices_legalize():
def expected_mod_axis2():
expected_ir_string = """
#[version = "0.0.5"]
- def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32])
-> (Tensor[(1, 50, 5, 3), float32],\
+ def @tvmgen_default_ethos_u_main_0(%x: Tensor[(1, 50, 50, 3),
float32]) -> (Tensor[(1, 50, 5, 3), float32],\
Tensor[(1, 50,
15, 3), float32],\
Tensor[(1, 50,
25, 3), float32],\
Tensor[(1, 50,
5, 3), float32]) {
@@ -99,13 +99,13 @@ def test_split_indices_legalize():
return tvm.parser.fromtext(expected_ir_string)
mod_axis1 = tvm.IRModule()
- mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1)
+ mod_axis1["tvmgen_default_ethos_u_main_0"] = create_graph(1)
mod_axis1 = legalize.LegalizeSplit()(mod_axis1)
expected_axis1 = expected_mod_axis1()
tvm.ir.assert_structural_equal(mod_axis1, expected_axis1)
mod_axis2 = tvm.IRModule()
- mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2)
+ mod_axis2["tvmgen_default_ethos_u_main_0"] = create_graph(2)
mod_axis2 = legalize.LegalizeSplit()(mod_axis2)
expected_axis2 = expected_mod_axis2()
tvm.ir.assert_structural_equal(mod_axis2, expected_axis2)
@@ -127,7 +127,7 @@ def test_split_sections_legalize():
def expected_mod_axis1():
expected_ir_string = """
#[version = "0.0.5"]
- def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 10, 50, 3), float32],\
+ def @tvmgen_default_ethos_u_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 10, 50, 3), float32],\
Tensor[(1, 10, 50, 3), float32],\
Tensor[(1, 10, 50, 3), float32],\
Tensor[(1, 10, 50, 3), float32],\
@@ -162,7 +162,7 @@ def test_split_sections_legalize():
def expected_mod_axis2():
expected_ir_string = """
#[version = "0.0.5"]
- def @tvmgen_default_ethosu_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 50, 10, 3), float32],\
+ def @tvmgen_default_ethos_u_main_0(%x: Tensor[(1, 50, 50, 3), float32]) -> (Tensor[(1, 50, 10, 3), float32],\
Tensor[(1, 50, 10, 3), float32],\
Tensor[(1, 50, 10, 3), float32],\
Tensor[(1, 50, 10, 3), float32],\
@@ -195,13 +195,13 @@ def test_split_sections_legalize():
return tvm.parser.fromtext(expected_ir_string)
mod_axis1 = tvm.IRModule()
- mod_axis1["tvmgen_default_ethosu_main_0"] = create_graph(1, 5)
+ mod_axis1["tvmgen_default_ethos_u_main_0"] = create_graph(1, 5)
mod_axis1 = legalize.LegalizeSplit()(mod_axis1)
expected_axis1 = expected_mod_axis1()
tvm.ir.assert_structural_equal(mod_axis1, expected_axis1)
mod_axis2 = tvm.IRModule()
- mod_axis2["tvmgen_default_ethosu_main_0"] = create_graph(2, 5)
+ mod_axis2["tvmgen_default_ethos_u_main_0"] = create_graph(2, 5)
mod_axis2 = legalize.LegalizeSplit()(mod_axis2)
expected_axis2 = expected_mod_axis2()
tvm.ir.assert_structural_equal(mod_axis2, expected_axis2)
@@ -314,7 +314,7 @@ def test_ethosu_conv2d_legalize():
mod, conv_params = test_case[0](*test_case[1])
mod = ethosu.partition_for_ethosu(mod)
mod = legalize.LegalizeConv2D()(mod)
- verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params)
+ verify_linear(mod["tvmgen_default_ethos_u_main_0"], conv_params)
def test_ethosu_conv2d_legalize_errors():
@@ -457,10 +457,10 @@ def test_tflite_depthwise_conv_2d_legalize(
mod["main"] = bind_params_by_name(mod["main"], params)
mod = partition_ethosu_by_table(mod, depthwise_pattern_table)
- mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
- legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"]
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
)
- verify(mod["tvmgen_default_ethosu_main_0"])
+ verify(mod["tvmgen_default_ethos_u_main_0"])
@pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
@@ -552,10 +552,10 @@ def test_tflite_pool2d_legalize(
)
mod = partition_ethosu_by_table(mod, pattern_table)
- mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
- rewriter, mod["tvmgen_default_ethosu_main_0"]
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethos_u_main_0"]
)
- verify(mod["tvmgen_default_ethosu_main_0"])
+ verify(mod["tvmgen_default_ethos_u_main_0"])
@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"])
@@ -687,10 +687,10 @@ def test_tflite_binary_elemwise_legalize(
)
mod = partition_ethosu_by_table(mod, pattern_table)
- mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
- rewriter, mod["tvmgen_default_ethosu_main_0"]
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethos_u_main_0"]
)
- verify(mod["tvmgen_default_ethosu_main_0"])
+ verify(mod["tvmgen_default_ethos_u_main_0"])
@pytest.mark.parametrize(
@@ -740,10 +740,10 @@ def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, rever
mod = create_graph()
mod = partition_ethosu_by_table(mod, pattern_table)
- mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite(
- rewriter, mod["tvmgen_default_ethosu_main_0"]
+ mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+ rewriter, mod["tvmgen_default_ethos_u_main_0"]
)
- verify(mod["tvmgen_default_ethosu_main_0"])
+ verify(mod["tvmgen_default_ethos_u_main_0"])
if __name__ == "__main__":
diff --git a/tests/python/contrib/test_ethosu/test_preprocess.py b/tests/python/contrib/test_ethosu/test_preprocess.py
index f2c7b0a..41831f2 100644
--- a/tests/python/contrib/test_ethosu/test_preprocess.py
+++ b/tests/python/contrib/test_ethosu/test_preprocess.py
@@ -57,7 +57,7 @@ def test_single_io():
mod = tvm.IRModule()
x = relay.var("x", shape=(10, 10))
- glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_symbol_f1, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
r = relay.Call(glb_symbol_f1, [x])
main = relay.Function([x], r)
mod["main"] = main
@@ -93,7 +93,7 @@ def test_2ins_single_out():
x = relay.var("x", shape=(10, 10))
w0 = relay.var("w0", shape=(10, 10))
- glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_symbol_f1, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
r = relay.Call(glb_symbol_f1, [x, w0])
main = relay.Function([x, w0], r)
mod["main"] = main
@@ -130,7 +130,7 @@ def test_2ins_single_out():
# concat
ifms = relay.concatenate((x_reshaped, w0_reshaped), 0)
- glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_symbol_f1, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
r = relay.Call(glb_symbol_f1, [ifms])
main = relay.Function([x, w0], r)
mod["main"] = main
@@ -165,7 +165,7 @@ def test_single_in_2outs():
mod = tvm.IRModule()
x = relay.var("x", shape=(10, 10))
- glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_symbol_f1, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
pq_tuple = relay.Call(glb_symbol_f1, [x])
p0 = relay.TupleGetItem(pq_tuple, 0)
q0 = relay.TupleGetItem(pq_tuple, 1)
@@ -196,7 +196,7 @@ def test_single_in_2outs():
mod = tvm.IRModule()
x = relay.var("x", shape=(10, 10))
- glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_symbol_f1, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
ofms = relay.Call(glb_symbol_f1, [x])
# splits
@@ -254,7 +254,7 @@ def test_4ins_2outs():
w1 = relay.var("w1", shape=(10, 10))
w2 = relay.var("w2", shape=(10, 10))
- glb_symbol_f1, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_symbol_f1, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
pq_tuple = relay.Call(glb_symbol_f1, [x, w0, w1, w2])
p0 = relay.TupleGetItem(pq_tuple, 0)
@@ -313,7 +313,7 @@ def test_4ins_2outs():
ifms = relay.concatenate((x_reshaped, w0_reshaped, w1_reshaped, w2_reshaped), 0)
# call
- glb_func, mod = create_external_func1(mod, "ethosu", "ethosu_0")
+ glb_func, mod = create_external_func1(mod, "ethos-u", "ethosu_0")
ofms = relay.Call(glb_func, [ifms])
# splits
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index b214472..7d8a4f0 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -116,6 +116,8 @@ class AOTTestRunner(NamedTuple):
Premade Makefile to use from the AOT test folder
prologue: str
Code to prepend to the main function
+ epilogue: str
+ Code to append to the main function
includes: List[str]
Additional includes required to run the AOT test runner
parameters: Dict[str, str]
@@ -126,6 +128,7 @@ class AOTTestRunner(NamedTuple):
makefile: str = "default"
prologue: str = ""
+ epilogue: str = ""
includes: List[str] = []
parameters: Dict[str, str] = {}
pass_config: Dict[str, Any] = {}
@@ -320,6 +323,16 @@ def emit_main_data(main_file, input_map, output_list, mod_name):
main_file.write(f'#include "{mangle_name(mod_name,"output_data")}{i}.h"\n')
+def emit_main_device_structs(main_file, devices, mod_name):
+ if devices:
+ main_file.write(
+ f"struct {mangle_name(mod_name, 'devices')} {mangle_name(mod_name, 'devices')} = {{"
+ )
+ for device in devices:
+ main_file.write(f"\t.{device} = {device},\n")
+ main_file.write("};\n")
+
+
def emit_main_data_structs(main_file, input_map, output_list, mod_name):
main_file.write(
f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{"
@@ -359,10 +372,20 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name):
main_file.write("};\n")
-def emit_main_c_interface_call(main_file, mod_name):
- main_file.write(
- f'{mangle_name(mod_name,"run")}(&{mangle_name(mod_name,"inputs")}, &{mangle_name(mod_name,"outputs")});\n'
- )
+def emit_main_c_interface_call(main_file, devices, mod_name):
+ if devices:
+ main_file.write(
+ f'{mangle_name(mod_name,"run")}('
+ f'&{mangle_name(mod_name,"inputs")}, '
+ f'&{mangle_name(mod_name,"outputs")}, '
+ f'&{mangle_name(mod_name,"devices")});\n'
+ )
+ else:
+ main_file.write(
+ f'{mangle_name(mod_name,"run")}('
+ f'&{mangle_name(mod_name,"inputs")}, '
+ f'&{mangle_name(mod_name,"outputs")});\n'
+ )
def emit_main_fake_packed_values(main_file):
@@ -446,7 +469,8 @@ def emit_main_init_memory_manager(main_file):
main_file.write("\n")
-def emit_main_epilogue(main_file):
+def emit_main_epilogue(main_file, custom_epilogue):
+ main_file.write(custom_epilogue)
main_file.write(f'printf("{AOT_SUCCESS_TOKEN}\\n");')
main_file.write("return 0;")
main_file.write("}\n")
@@ -469,10 +493,11 @@ def emit_main_micro_include(main_file, mod_name):
def create_main(
test_name,
- models,
+ compiled_models,
output_path,
custom_includes,
custom_prologue,
+ custom_epilogue,
data_linkage,
interface_api,
workspace_bytes,
@@ -484,27 +509,34 @@ def create_main(
emit_main_common_includes(main_file, custom_includes)
if interface_api == "c":
- for model in models:
+ for compiled_model in compiled_models:
+ model = compiled_model.model
emit_main_micro_include(main_file, model.name)
- for model in models:
+ for compiled_model in compiled_models:
+ model = compiled_model.model
emit_main_data(main_file, model.inputs, model.outputs, model.name)
emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage)
emit_main_init_memory_manager(main_file)
if interface_api == "c":
- for model in models:
+ for compiled_model in compiled_models:
+ model = compiled_model.model
+ devices = compiled_model.executor_factory.get_devices()
+ emit_main_device_structs(main_file, devices, model.name)
emit_main_data_structs(main_file, model.inputs, model.outputs, model.name)
- emit_main_c_interface_call(main_file, model.name)
+ emit_main_c_interface_call(main_file, devices, model.name)
else:
emit_main_fake_packed_values(main_file)
- for model in models:
+ for compiled_model in compiled_models:
+ model = compiled_model.model
emit_main_data_setup(main_file, model.inputs, model.outputs, model.name)
emit_main_packed_call(main_file, model.inputs, model.outputs, model.name)
- for model in models:
+ for compiled_model in compiled_models:
+ model = compiled_model.model
emit_main_compare(main_file, model.outputs, model.output_tolerance, model.name)
- emit_main_epilogue(main_file)
+ emit_main_epilogue(main_file, custom_epilogue)
def create_header_file(tensor_name, npy_data, output_path, data_linkage):
@@ -646,10 +678,11 @@ def run_and_check(
create_main(
"test.c",
- [compiled_model.model for compiled_model in models],
+ models,
build_path,
runner.includes,
runner.prologue,
+ runner.epilogue,
data_linkage,
interface_api,
workspace_bytes,
diff --git a/tests/python/relay/aot/corstone300.mk b/tests/python/relay/aot/corstone300.mk
index 553ed84..bf4e388 100644
--- a/tests/python/relay/aot/corstone300.mk
+++ b/tests/python/relay/aot/corstone300.mk
@@ -40,6 +40,7 @@ CC_OPTS = CC=$(CC) AR=$(AR) RANLIB=$(RANLIB)
PKG_CFLAGS = ${PKG_COMPILE_OPTS} \
${CFLAGS} \
-I$(build_dir)/../include \
+ -I${TVM_ROOT}/src/runtime/contrib/ethosu/bare_metal \
-I$(CODEGEN_ROOT)/host/include \
-I${PLATFORM_PATH} \
-I${DRIVER_PATH}/include \
@@ -70,6 +71,7 @@ CMSIS_NN_LIBS = $(wildcard ${CMSIS_PATH}/CMSIS/NN/build/Source/*/*.a)
ifdef ETHOSU_TEST_ROOT
ETHOSU_DRIVER_LIBS = $(wildcard ${DRIVER_PATH}/build/*.a)
+ETHOSU_RUNTIME=$(build_dir)/tvm_ethosu_runtime.o
ETHOSU_INCLUDE=-I$(ETHOSU_TEST_ROOT)
endif
@@ -83,6 +85,10 @@ $(build_dir)/crt_backend_api.o: $(TVM_ROOT)/src/runtime/crt/common/crt_backend_a
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^
+$(build_dir)/tvm_ethosu_runtime.o: $(TVM_ROOT)/src/runtime/contrib/ethosu/bare_metal/tvm_ethosu_runtime.c
+ $(QUIET)mkdir -p $(@D)
+ $(QUIET)$(CC) -c $(PKG_CFLAGS) -o $@ $^
+
$(build_dir)/libcodegen.a: $(CODEGEN_SRCS)
$(QUIET)cd $(abspath $(CODEGEN_ROOT)/host/src) && $(CC) -c $(PKG_CFLAGS) $(CODEGEN_SRCS)
$(QUIET)$(AR) -cr $(abspath $(build_dir)/libcodegen.a) $(CODEGEN_OBJS)
@@ -100,7 +106,7 @@ ${build_dir}/libuart.a: $(UART_SRCS)
$(QUIET)$(AR) -cr $(abspath $(build_dir)/libuart.a) $(abspath $(build_dir))/libuart/*.o
$(QUIET)$(RANLIB) $(abspath $(build_dir)/libuart.a)
-$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(CMSIS_NN_LIBS) $(ETHOSU_DRIVER_LIBS)
+$(build_dir)/aot_test_runner: $(build_dir)/test.c $(build_dir)/crt_backend_api.o $(build_dir)/stack_allocator.o ${build_dir}/libcmsis_startup.a ${build_dir}/libuart.a $(build_dir)/libcodegen.a $(CMSIS_NN_LIBS) $(ETHOSU_DRIVER_LIBS) $(ETHOSU_RUNTIME)
$(QUIET)mkdir -p $(@D)
$(QUIET)$(CC) $(PKG_CFLAGS) $(ETHOSU_INCLUDE) -o $@ -Wl,--whole-archive $^ -Wl,--no-whole-archive $(PKG_LDFLAGS)