areusch commented on code in PR #11044: URL: https://github.com/apache/tvm/pull/11044#discussion_r852425539
########## include/tvm/runtime/crt/aot_executor.h: ########## @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file aot_executor.h + * \brief AoT Executor + */ +#ifndef TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include <dlpack/dlpack.h> +#include <tvm/runtime/crt/internal/common/ndarray.h> +#include <tvm/runtime/metadata.h> + +typedef struct TVMMetadata TVMMetadata; + +typedef struct TVMAotExecutor { + /*! \brief The top-level metadata structure */ Review Comment: nit: maybe note that this one comes from the compiled artifact ########## include/tvm/runtime/crt/aot_executor.h: ########## @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file aot_executor.h + * \brief AoT Executor + */ +#ifndef TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include <dlpack/dlpack.h> +#include <tvm/runtime/crt/internal/common/ndarray.h> +#include <tvm/runtime/metadata.h> + +typedef struct TVMMetadata TVMMetadata; + +typedef struct TVMAotExecutor { + /*! \brief The top-level metadata structure */ + TVMMetadata* metadata; + /*! \brief The code module that contains both host and device code */ Review Comment: for the time being host == device, so maybe just ```suggestion /*! \brief The code module that contains the compiled model */ ``` ########## include/tvm/runtime/crt/aot_executor.h: ########## @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/*! + * \file aot_executor.h + * \brief AoT Executor + */ +#ifndef TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include <dlpack/dlpack.h> +#include <tvm/runtime/crt/internal/common/ndarray.h> +#include <tvm/runtime/metadata.h> + +typedef struct TVMMetadata TVMMetadata; + +typedef struct TVMAotExecutor { + /*! \brief The top-level metadata structure */ + TVMMetadata* metadata; + /*! \brief The code module that contains both host and device code */ + TVMModuleHandle module_handle; + /*! \brief The device type */ + DLDevice device; + /*! \brief List of allocated arguments, input(s), output(s), and pool(s)*/ + TVMNDArray* args; + int64_t num_args; +} TVMAotExecutor; + +/*! + * \brief Allocate a new AotExecutor with TVMPlatformMemoryAllocate and initialize it. + * + * \param module_handle TVM Module that exposes the functions to call. + * \param devices runtime execution device. + * \param executor Pointer which receives a pointer to the newly-created instance. + * \return 0 if successful. + */ +int TVMAotExecutor_Create(TVMModuleHandle module_handle, const DLDevice* devices, + TVMAotExecutor** executor); + +int TVMAotExecutor_Release(TVMAotExecutor* executor, const DLDevice device); Review Comment: do you mind adding docstrings to these? ########## src/runtime/crt/aot_executor/aot_executor.c: ########## @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE + +/*! + * \file aot_executor.c + * \brief implement AoT executor in C + */ + +#include <string.h> +#include <tvm/runtime/c_runtime_api.h> +#include <tvm/runtime/crt/aot_executor.h> +#include <tvm/runtime/crt/logging.h> +#include <tvm/runtime/crt/module.h> +#include <tvm/runtime/crt/packed_func.h> +#include <tvm/runtime/crt/page_allocator.h> + +static void DumpMetadata(TVMMetadata* md) { + LOG_DEBUG("%s:\n", __FUNCTION__); + LOG_DEBUG("\tmod_name=%s\n", md->mod_name); + LOG_DEBUG("\tversion=%ld\n", md->version); + LOG_DEBUG("\tnum_inputs=%ld\n", md->num_inputs); + LOG_DEBUG("\tnum_outputs=%ld\n", md->num_outputs); + LOG_DEBUG("\tnum_pools=%ld\n", md->num_pools); + + int i; + + for (i = 0; i < md->num_inputs; ++i) { + LOG_DEBUG("\tinput[%d]: %s\n", i, md->inputs[i].name); + } + + for (i = 0; i < md->num_outputs; ++i) { + LOG_DEBUG("\toutput[%d]: %s\n", i, md->outputs[i].name); + } + + for (i = 0; i < md->num_pools; ++i) { + LOG_DEBUG("\tpools[%d]: %s\n", i, md->pools[i].name); + } +} + +int TVMAotExecutor_GetNumInputs(TVMAotExecutor* executor) { return executor->metadata->num_inputs; } + +int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor) { + return executor->metadata->num_outputs; +} + +int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) { + int i; + int rv = -1; + + TVMMetadata* md = executor->metadata; + for (i = 0; i < md->num_inputs; ++i) { + if (!strcmp(md->inputs[i].name, name)) { + rv = i; + break; + } + } + CHECK_GE(rv, 0, "cannot find '%s' among input.", name); + return rv; +} + +int 
TVMAotExecutor_Run(TVMAotExecutor* executor) { + const char* tvm_main_suffix = "___tvm_main__"; + char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; + const size_t max_strlen = TVM_CRT_MAX_STRLEN_FUNCTION_NAME; + + { + size_t len = strnlen(executor->metadata->mod_name, max_strlen); + len += strnlen(tvm_main_suffix, max_strlen); + + CHECK_LT(len, max_strlen, "tvm_main name too long %ld\n", len); + } + + // create main function name string, e.g. "tvmgen_default___tvm_main__" + snprintf(tvm_main_name, sizeof(tvm_main_name), "%s%s", + executor->metadata->mod_name, tvm_main_suffix); + + TVMPackedFunc tvm_main; + TVMArgs temp_args; + + CHECK_LE(executor->num_args, TVM_CRT_MAX_ARGS, "too many args %ld\n", executor->num_args); + + int i; + for (i = 0; i < executor->num_args; ++i) { + temp_args.values[i].v_handle = &executor->args[i].dl_tensor; + temp_args.tcodes[i] = kTVMDLTensorHandle; + } + temp_args.values_count = executor->num_args; + + int status = + TVMPackedFunc_InitModuleFunc(&tvm_main, executor->module_handle, tvm_main_name, &temp_args); + + if (status != 0) { + return status; + } + + CHECK_EQ(tvm_main.Call(&tvm_main), 0, "call to %s failed", tvm_main_name); + + return 0; +} + +int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, + const DLDevice* device) { + executor->module_handle = module_handle; + executor->device = *device; + + // get a pointer to the PackedFunc get_c_metadata() which gives us access to the top-level + // metadata structure + TVMPackedFunc get_c_metadata; + TVMArgs temp_args; + temp_args.values_count = 0; + + int status = TVMPackedFunc_InitModuleFunc(&get_c_metadata, executor->module_handle, + "get_c_metadata", &temp_args); + if (status != 0) { + return status; + } + + CHECK_EQ(get_c_metadata.Call(&get_c_metadata), 0, "get_c_metadata"); + + // save the returned pointer to the top-level metadata + executor->metadata = (TVMMetadata*)get_c_metadata.ret_value.values[0].v_handle; + + TVMMetadata* md = 
executor->metadata; + + DumpMetadata(md); + + executor->num_args = md->num_inputs + md->num_outputs + md->num_pools; + + tvm_crt_error_t err = TVMPlatformMemoryAllocate(executor->num_args * sizeof(*executor->args), + executor->device, (void**)(&executor->args)); + if (err != kTvmErrorNoError) { + return -1; Review Comment: i've been trying to pass `err` back as much as possible. i believe the convention is just 0 on success and non-zero on error. what do you think? ########## tests/python/unittest/test_crt.py: ########## @@ -149,20 +151,96 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): factory = tvm.relay.build(relay_mod, target=TARGET, runtime=runtime) - with _make_session(temp_dir, factory) as sess: - graph_mod = tvm.micro.create_local_graph_executor( - factory.get_graph_json(), sess.get_system_lib(), sess.device - ) + def do_test(graph_mod): + A_data = tvm.nd.array(np.array([2, 3], dtype="uint8"), device=sess.device) assert (A_data.numpy() == np.array([2, 3])).all() B_data = tvm.nd.array(np.array([4, 7], dtype="uint8"), device=sess.device) assert (B_data.numpy() == np.array([4, 7])).all() + assert graph_mod.get_input_index("a") == 0 + assert graph_mod.get_input_index("b") == 1 + graph_mod.run(a=A_data, b=B_data) out = graph_mod.get_output(0) assert (out.numpy() == np.array([6, 10])).all() + with _make_session(temp_dir, factory) as sess: + + graph_mod_local = tvm.micro.create_local_graph_executor( + factory.get_graph_json(), + sess.get_system_lib(), + sess.device) + + do_test(graph_mod_local) + + graph_mod = tvm.contrib.graph_executor.create( + factory.get_graph_json(), + sess.get_system_lib(), + sess.device) + + do_test(graph_mod) + + + +@tvm.testing.requires_micro +def test_aot_executor(): + """Test use of the AOT executor with microTVM.""" + + ws_root = pathlib.Path(os.path.dirname(__file__) + "/micro-workspace") + if 
ws_root.exists(): +
shutil.rmtree(ws_root) + with tvm.contrib.utils.TempDirectory.set_keep_for_debug(): + temp_dir = tvm.contrib.utils.tempdir(ws_root.resolve()) + relay_mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { + %0 = %a + %b; + %0 + }""" + ) + + runtime = Runtime("crt", {"system-lib": True}) + executor = Executor("aot") + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + factory = tvm.relay.build(relay_mod, target=TARGET, runtime=runtime, executor=executor) + + def do_test(): + aot_executor = tvm.runtime.executor.aot_executor.AotModule( + sess._rpc.get_function("tvm.aot_executor.create")(sess.get_system_lib(), sess.device)) + + assert aot_executor.get_input_index("a") == 0 + assert aot_executor.get_input_index("b") == 1 + + assert aot_executor.get_num_inputs() == 2 + assert aot_executor.get_num_outputs() == 1 + + A_np = np.array([[2, 3]], dtype="uint8") + B_np = np.array([[4, 7]], dtype="uint8") + + A_data = aot_executor.get_input("a").copyfrom(A_np) + B_data = aot_executor.get_input("b").copyfrom(B_np) + + print("A_data: " + str(A_data)) + print("B_data: " + str(B_data)) + + aot_executor.run() + + out = aot_executor.get_output(0) + print("out: " + str(out)) Review Comment: nit: rm prints or use logging ########## src/target/source/source_module.cc: ########## @@ -771,11 +771,41 @@ class MetadataSerializer : public AttrVisitor { std::vector<bool> is_defining_struct_; }; +namespace { +runtime::Module CreateAotMetadataModule(runtime::metadata::Metadata aot_metadata) { + MetadataSerializer serializer; + serializer.CodegenMetadata(aot_metadata); + std::stringstream lookup_func; + lookup_func << "#ifdef __cplusplus\n" + << "extern \"C\"\n" + << "#endif\n"; + + lookup_func << "TVM_DLL int32_t " << ::tvm::runtime::symbol::tvm_get_c_metadata Review Comment: a note for a future PR--we should technically be prefixing the C symbol name with the mod_name, which allows us to 
compile multiple models into one program. but this would require us to change the logic in CSourceModuleCreate, so let's defer that. ########## src/runtime/crt/aot_executor/aot_executor.c: ########## @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE + +/*! 
+ * \file aot_executor.c + * \brief implement AoT executor in C + */ + +#include <string.h> +#include <tvm/runtime/c_runtime_api.h> +#include <tvm/runtime/crt/aot_executor.h> +#include <tvm/runtime/crt/logging.h> +#include <tvm/runtime/crt/module.h> +#include <tvm/runtime/crt/packed_func.h> +#include <tvm/runtime/crt/page_allocator.h> + +static void DumpMetadata(TVMMetadata* md) { + LOG_DEBUG("%s:\n", __FUNCTION__); + LOG_DEBUG("\tmod_name=%s\n", md->mod_name); + LOG_DEBUG("\tversion=%ld\n", md->version); + LOG_DEBUG("\tnum_inputs=%ld\n", md->num_inputs); + LOG_DEBUG("\tnum_outputs=%ld\n", md->num_outputs); + LOG_DEBUG("\tnum_pools=%ld\n", md->num_pools); + + int i; + + for (i = 0; i < md->num_inputs; ++i) { + LOG_DEBUG("\tinput[%d]: %s\n", i, md->inputs[i].name); + } + + for (i = 0; i < md->num_outputs; ++i) { + LOG_DEBUG("\toutput[%d]: %s\n", i, md->outputs[i].name); + } + + for (i = 0; i < md->num_pools; ++i) { + LOG_DEBUG("\tpools[%d]: %s\n", i, md->pools[i].name); + } +} + +int TVMAotExecutor_GetNumInputs(TVMAotExecutor* executor) { return executor->metadata->num_inputs; } + +int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor) { + return executor->metadata->num_outputs; +} + +int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) { + int i; + int rv = -1; + + TVMMetadata* md = executor->metadata; + for (i = 0; i < md->num_inputs; ++i) { + if (!strcmp(md->inputs[i].name, name)) { + rv = i; + break; + } + } + CHECK_GE(rv, 0, "cannot find '%s' among input.", name); + return rv; +} + +int TVMAotExecutor_Run(TVMAotExecutor* executor) { + const char* tvm_main_suffix = "___tvm_main__"; + char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; + const size_t max_strlen = TVM_CRT_MAX_STRLEN_FUNCTION_NAME; + + { + size_t len = strnlen(executor->metadata->mod_name, max_strlen); + len += strnlen(tvm_main_suffix, max_strlen); + + CHECK_LT(len, max_strlen, "tvm_main name too long %ld\n", len); + } + + // create main function name string, e.g. 
"tvmgen_default___tvm_main__" + snprintf(tvm_main_name, sizeof(tvm_main_name), "%s%s", + executor->metadata->mod_name, tvm_main_suffix); + + TVMPackedFunc tvm_main; + TVMArgs temp_args; + + CHECK_LE(executor->num_args, TVM_CRT_MAX_ARGS, "too many args %ld\n", executor->num_args); + + int i; + for (i = 0; i < executor->num_args; ++i) { + temp_args.values[i].v_handle = &executor->args[i].dl_tensor; + temp_args.tcodes[i] = kTVMDLTensorHandle; + } + temp_args.values_count = executor->num_args; + + int status = + TVMPackedFunc_InitModuleFunc(&tvm_main, executor->module_handle, tvm_main_name, &temp_args); + + if (status != 0) { + return status; + } + + CHECK_EQ(tvm_main.Call(&tvm_main), 0, "call to %s failed", tvm_main_name); + + return 0; +} + +int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, + const DLDevice* device) { Review Comment: ```suggestion const DLDevice* devices) { ``` ########## include/tvm/runtime/crt/aot_executor.h: ########## @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file aot_executor.h + * \brief AoT Executor + */ +#ifndef TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ +#define TVM_RUNTIME_CRT_AOT_EXECUTOR_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include <dlpack/dlpack.h> +#include <tvm/runtime/crt/internal/common/ndarray.h> +#include <tvm/runtime/metadata.h> + +typedef struct TVMMetadata TVMMetadata; + +typedef struct TVMAotExecutor { + /*! \brief The top-level metadata structure */ + TVMMetadata* metadata; + /*! \brief The code module that contains both host and device code */ + TVMModuleHandle module_handle; + /*! \brief The device type */ + DLDevice device; + /*! \brief List of allocated arguments, input(s), output(s), and pool(s)*/ + TVMNDArray* args; + int64_t num_args; +} TVMAotExecutor; + +/*! + * \brief Allocate a new AotExecutor with TVMPlatformMemoryAllocate and initialize it. + * + * \param module_handle TVM Module that exposes the functions to call. + * \param devices runtime execution device. Review Comment: i think we expect this to be kDLCPU 0, right? want to add that to the comment? ########## src/runtime/crt/aot_executor/aot_executor.c: ########## @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE + +/*!
+ * \file aot_executor.c + * \brief implement AoT executor in C + */ + +#include <string.h> +#include <tvm/runtime/c_runtime_api.h> +#include <tvm/runtime/crt/aot_executor.h> +#include <tvm/runtime/crt/logging.h> +#include <tvm/runtime/crt/module.h> +#include <tvm/runtime/crt/packed_func.h> +#include <tvm/runtime/crt/page_allocator.h> + +static void DumpMetadata(TVMMetadata* md) { + LOG_DEBUG("%s:\n", __FUNCTION__); + LOG_DEBUG("\tmod_name=%s\n", md->mod_name); + LOG_DEBUG("\tversion=%ld\n", md->version); + LOG_DEBUG("\tnum_inputs=%ld\n", md->num_inputs); + LOG_DEBUG("\tnum_outputs=%ld\n", md->num_outputs); + LOG_DEBUG("\tnum_pools=%ld\n", md->num_pools); + + int i; + + for (i = 0; i < md->num_inputs; ++i) { + LOG_DEBUG("\tinput[%d]: %s\n", i, md->inputs[i].name); + } + + for (i = 0; i < md->num_outputs; ++i) { + LOG_DEBUG("\toutput[%d]: %s\n", i, md->outputs[i].name); + } + + for (i = 0; i < md->num_pools; ++i) { + LOG_DEBUG("\tpools[%d]: %s\n", i, md->pools[i].name); + } +} + +int TVMAotExecutor_GetNumInputs(TVMAotExecutor* executor) { return executor->metadata->num_inputs; } + +int TVMAotExecutor_GetNumOutputs(TVMAotExecutor* executor) { + return executor->metadata->num_outputs; +} + +int TVMAotExecutor_GetInputIndex(TVMAotExecutor* executor, const char* name) { + int i; + int rv = -1; + + TVMMetadata* md = executor->metadata; + for (i = 0; i < md->num_inputs; ++i) { + if (!strcmp(md->inputs[i].name, name)) { + rv = i; + break; + } + } + CHECK_GE(rv, 0, "cannot find '%s' among input.", name); + return rv; +} + +int TVMAotExecutor_Run(TVMAotExecutor* executor) { + const char* tvm_main_suffix = "___tvm_main__"; + char tvm_main_name[TVM_CRT_MAX_STRLEN_FUNCTION_NAME]; + const size_t max_strlen = TVM_CRT_MAX_STRLEN_FUNCTION_NAME; + + { + size_t len = strnlen(executor->metadata->mod_name, max_strlen); + len += strnlen(tvm_main_suffix, max_strlen); + + CHECK_LT(len, max_strlen, "tvm_main name too long %ld\n", len); + } + + // create main function name string, e.g. 
"tvmgen_default___tvm_main__" + snprintf(tvm_main_name, sizeof(tvm_main_name), "%s%s", + executor->metadata->mod_name, tvm_main_suffix); + + TVMPackedFunc tvm_main; + TVMArgs temp_args; + + CHECK_LE(executor->num_args, TVM_CRT_MAX_ARGS, "too many args %ld\n", executor->num_args); + + int i; + for (i = 0; i < executor->num_args; ++i) { + temp_args.values[i].v_handle = &executor->args[i].dl_tensor; + temp_args.tcodes[i] = kTVMDLTensorHandle; + } + temp_args.values_count = executor->num_args; + + int status = + TVMPackedFunc_InitModuleFunc(&tvm_main, executor->module_handle, tvm_main_name, &temp_args); + + if (status != 0) { + return status; + } + + CHECK_EQ(tvm_main.Call(&tvm_main), 0, "call to %s failed", tvm_main_name); + + return 0; +} + +int TVMAotExecutor_Init(TVMAotExecutor* executor, TVMModuleHandle module_handle, + const DLDevice* device) { + executor->module_handle = module_handle; + executor->device = *device; Review Comment: ```suggestion executor->device = devices[0]; ``` ########## src/runtime/crt/aot_executor_module/aot_executor_module.c: ########## @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE + +/*! 
+ * \file aot_executor_module.c + * \brief wrap aot_executor into a TVMModule for use with RPC. + */ + +#include <stdio.h> +#include <tvm/runtime/crt/aot_executor.h> +#include <tvm/runtime/crt/aot_executor_module.h> +#include <tvm/runtime/crt/func_registry.h> +#include <tvm/runtime/crt/module.h> + +typedef struct { + TVMModule mod; + TVMAotExecutor* executor; +} AotExecutorModule; + +static AotExecutorModule aot_executor; + +int32_t TVMAotExecutorModule_Create(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, + int* ret_tcodes, void* resource_handle) { + if (aot_executor.executor != NULL) { + return kTvmErrorExecutorModuleAlreadyCreated; + } + + if (nargs != 2) { + return kTvmErrorFunctionCallNumArguments; + } + + if (tcodes[0] != kTVMModuleHandle || tcodes[1] != kDLDevice) { + return kTvmErrorFunctionCallWrongArgType; + } + + DLDevice dev = args[1].v_device; + + if (dev.device_type != kDLCPU) { + return kTvmErrorExecutorModuleBadContext; + } + + TVMAotExecutor_Create(args[0].v_handle, &dev, &aot_executor.executor); + + TVMModuleHandle out_mod; + int ret_value = TVMModCreateFromCModule(&aot_executor.mod, &out_mod); Review Comment: to avoid confusion with `ret_values`, suggest naming this `err` or something other than `ret_value` ########## src/target/metadata_module.cc: ########## @@ -119,6 +78,52 @@ static runtime::metadata::Metadata ConvertMetaData( return runtime::metadata::Metadata(std::move(n)); } +static runtime::Module CreateCrtMetadataModule( Review Comment: could you move this back up to where it was, above ConvertMetaData, so we can see if there are any diffs here? ########## src/runtime/crt/aot_executor_module/aot_executor_module.c: ########## @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// LINT_C_FILE + +/*! + * \file aot_executor_module.c + * \brief wrap aot_executor into a TVMModule for use with RPC. + */ + +#include <stdio.h> +#include <tvm/runtime/crt/aot_executor.h> +#include <tvm/runtime/crt/aot_executor_module.h> +#include <tvm/runtime/crt/func_registry.h> +#include <tvm/runtime/crt/module.h> + +typedef struct { + TVMModule mod; + TVMAotExecutor* executor; +} AotExecutorModule; + +static AotExecutorModule aot_executor; + +int32_t TVMAotExecutorModule_Create(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, + int* ret_tcodes, void* resource_handle) { + if (aot_executor.executor != NULL) { + return kTvmErrorExecutorModuleAlreadyCreated; + } + + if (nargs != 2) { + return kTvmErrorFunctionCallNumArguments; + } + + if (tcodes[0] != kTVMModuleHandle || tcodes[1] != kDLDevice) { + return kTvmErrorFunctionCallWrongArgType; + } + + DLDevice dev = args[1].v_device; + + if (dev.device_type != kDLCPU) { + return kTvmErrorExecutorModuleBadContext; + } + + TVMAotExecutor_Create(args[0].v_handle, &dev, &aot_executor.executor); + + TVMModuleHandle out_mod; + int ret_value = TVMModCreateFromCModule(&aot_executor.mod, &out_mod); + if (ret_value != 0) { + ret_tcodes[0] = kTVMNullptr; + TVMAotExecutor_Release(aot_executor.executor, dev); + return ret_value; + } + + ret_values[0].v_handle = out_mod; + ret_tcodes[0] = kTVMModuleHandle; + return kTvmErrorNoError; 
+} + +int32_t TVMAotExecutorModule_NotImplemented(TVMValue* args, int* tcodes, int nargs, + TVMValue* ret_values, int* ret_tcodes, + void* resource_handle) { + return kTvmErrorFunctionCallNotImplemented; +} + +int32_t TVMAotExecutorModule_GetInput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, + int* ret_tcodes, void* resource_handle) { + int index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str); + + if (index < 0) { + return kTvmErrorExecutorModuleNoSuchInput; + } + + ret_values[0].v_handle = (void*)&aot_executor.executor->args[index].dl_tensor; + ret_tcodes[0] = kTVMNDArrayHandle; + + return 0; +} + +int32_t TVMAotExecutorModule_GetOutput(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, + int* ret_tcodes, void* resource_handle) { + if (nargs != 1) { + return kTvmErrorFunctionCallNumArguments; + } + + if (args[0].v_int64 > TVMAotExecutor_GetNumOutputs(aot_executor.executor)) { + return kTvmErrorFunctionCallInvalidArg; + } + + // index past the input entries + int64_t idx = args[0].v_int64 + TVMAotExecutor_GetNumInputs(aot_executor.executor); + + ret_values[0].v_handle = (void*)&aot_executor.executor->args[idx].dl_tensor; + ret_tcodes[0] = kTVMNDArrayHandle; + + return 0; +} + +int32_t TVMAotExecutorModule_GetInputIndex(TVMValue* args, int* tcodes, int nargs, + TVMValue* ret_values, int* ret_tcodes, + void* resource_handle) { + if (nargs != 1) { + return kTvmErrorFunctionCallNumArguments; + } + + int index = TVMAotExecutor_GetInputIndex(aot_executor.executor, args[0].v_str); + + if (index < 0) { + return kTvmErrorExecutorModuleNoSuchInput; + } + + ret_values[0].v_int64 = index; + ret_tcodes[0] = kTVMArgInt; + return 0; +} + +int32_t TVMAotExecutorModule_GetNumInputs(TVMValue* args, int* tcodes, int nargs, + TVMValue* ret_values, int* ret_tcodes, + void* resource_handle) { + if (nargs != 0) { + return kTvmErrorFunctionCallNumArguments; + } + + ret_values[0].v_int64 = TVMAotExecutor_GetNumInputs(aot_executor.executor); 
+ ret_tcodes[0] = kTVMArgInt; + return 0; +} + +int32_t TVMAotExecutorModule_GetNumOutputs(TVMValue* args, int* tcodes, int nargs, + TVMValue* ret_values, int* ret_tcodes, + void* resource_handle) { + if (nargs != 0) { + return kTvmErrorFunctionCallNumArguments; + } + + ret_values[0].v_int64 = TVMAotExecutor_GetNumOutputs(aot_executor.executor); + ret_tcodes[0] = kTVMArgInt; + return 0; +} + +int32_t TVMAotExecutorModule_Run(TVMValue* args, int* tcodes, int nargs, TVMValue* ret_values, + int* ret_tcodes, void* resource_handle) { + if (nargs != 0) { + return kTvmErrorFunctionCallNumArguments; + } + + return TVMAotExecutor_Run(aot_executor.executor); +} + +static const TVMBackendPackedCFunc aot_executor_registry_funcs[] = { + &TVMAotExecutorModule_GetInput, // get_input + &TVMAotExecutorModule_GetInputIndex, // get_input_index + &TVMAotExecutorModule_NotImplemented, // get_input_info (do not implement) + &TVMAotExecutorModule_GetNumInputs, // get_num_inputs + &TVMAotExecutorModule_GetNumOutputs, // get_num_outputs + &TVMAotExecutorModule_GetOutput, // get_output + &TVMAotExecutorModule_NotImplemented, // load_params (do not implement) + &TVMAotExecutorModule_Run, // run + &TVMAotExecutorModule_NotImplemented, // set_input + &TVMAotExecutorModule_NotImplemented, // share_params (do not implement) +}; + +static const TVMFuncRegistry aot_executor_registry = { + "\x08get_input\0" Review Comment: i think this should be ```suggestion "\x0aget_input\0" ``` since there are 10 functions now -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
