masahi commented on a change in pull request #10283:
URL: https://github.com/apache/tvm/pull/10283#discussion_r817085788
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -857,11 +905,30 @@ class AOTExecutorCodegen : public MixedModeVisitor {
ICHECK(target_host_.defined()) << "require a target_host to be given for
AOT codegen";
VLOG(1) << "target host: " << target_host_->ToDebugString();
+ Runtime runtime_config =
mod->GetAttr<Runtime>(tvm::attr::kRuntime).value();
Executor executor_config =
mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
String interface_api =
executor_config->GetAttr<String>("interface-api").value_or("packed");
Integer workspace_byte_alignment =
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
use_unpacked_api_ =
executor_config->GetAttr<Bool>("unpacked-api").value_or(Bool(false));
+ use_call_cpacked_ = !use_unpacked_api_;
+
+ // Validate choice of use_unpacked_api_ and use_call_cpacked_
+ if (runtime_config->name == kTvmRuntimeCrt) {
+ CHECK(interface_api == "packed" || static_cast<bool>(use_unpacked_api_)
== true)
+ << "Either need interface_api == \"packed\" (got: " << interface_api
+ << ") or unpacked-api == true (got: " << use_unpacked_api_
+ << ") when targeting c runtime";
+ } else if (runtime_config->name == kTvmRuntimeCpp) {
+ CHECK(static_cast<bool>(use_unpacked_api_) == false &&
Review comment:
Should use `ICHECK` here, and also at L918.
##########
File path: src/runtime/aot_executor/aot_executor.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Defines an implementation of Module-based Model Runtime Interface
that works with
+ * Ahead-of-Time compilation.
+ * \file aot_executor.h
+ */
+#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
+#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
+
+#include <tvm/runtime/metadata.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+class TVM_DLL AotExecutor : public ModuleNode {
+ public:
+ /*!
+ * \brief Implements member function lookup for this Module for the frontend.
+ * \param name The name of the function.
+ * \param sptr_to_self The pointer to the module node.
+ * \return The corresponding member function.
+ */
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) override;
+
+ /*!
+ * \return The type key of the executor.
+ */
+ const char* type_key() const final { return "AotExecutor"; }
+
+ void Run();
+
+ /*!
+ * \brief Initialize the AOT executor with metadata, runtime::Module, and
device.
+ * \param module The module containing the compiled functions for the host
+ * processor.
+ * \param devs The device of the host and devices where graph nodes will be
Review comment:
There is no "graph" in the AOT executor.
Also, only the first element of `devs` is ever used in `aot_executor.cc`.
##########
File path: src/target/source/source_module.cc
##########
@@ -539,6 +801,32 @@ runtime::Module CreateCSourceCrtMetadataModule(const
Array<runtime::Module>& mod
return std::move(csrc_metadata_module);
}
+runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata
metadata) {
+ MetadataSerializer serializer;
+ serializer.CodegenMetadata(metadata);
+ std::stringstream lookup_func;
+ lookup_func << "#ifdef __cplusplus\n"
+ << "extern \"C\"\n"
+ << "#endif\n";
+
+ lookup_func << "TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int*
arg_tcodes, int "
+ "num_args, TVMValue* ret_values, int* ret_tcodes, void*
resource_handle) {"
+ << std::endl;
+ lookup_func << " ret_values[0].v_handle = (void*) &" <<
MetadataSerializer::kGlobalSymbol
+ << ";" << std::endl;
+ lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl;
+ lookup_func << " return 0;" << std::endl;
+ lookup_func << "};" << std::endl;
+
+ auto mod = MetadataModuleCreate(metadata);
+ std::vector<String> func_names{"get_c_metadata"};
+ // definer.GetOutput() +
Review comment:
Remove this commented-out code?
##########
File path: src/target/source/codegen_source_base.h
##########
@@ -157,6 +158,19 @@ runtime::Module CSourceModuleCreate(const String& code,
const String& fmt,
const Array<String>& func_names,
const Array<String>& const_vars = {});
+/*!
+ * \brief Wrap the submodules in a metadata module.
+ * \param params The variable to constant mapping that is collected by the host
+ * module.
+ * \param target_module The main TIR-lowered internal runtime module
+ * \param modules All the external modules that needs to be imported inside
the metadata module(s).
+ * \param target The target that all the modules are compiled for
Review comment:
`metadata` is not documented.
##########
File path: src/tir/transforms/legalize_packed_calls.cc
##########
@@ -75,6 +75,12 @@ class PackedCallLegalizer : public StmtExprMutator {
new_stmts.push_back(tir::Evaluate(
tvm::tir::Call(DataType::Handle(),
tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData,
call->args[i]})));
+ new_stmts.push_back(tir::Evaluate(
+ tvm::tir::Call(DataType::Handle(),
tvm::tir::builtin::tvm_struct_set(),
+ {sid_array, 0, tir::builtin::kArrDeviceType,
kDLCPU})));
+ new_stmts.push_back(tir::Evaluate(
+ tvm::tir::Call(DataType::Handle(),
tvm::tir::builtin::tvm_struct_set(),
+ {sid_array, 0, tir::builtin::kArrDeviceId,
0})));
Review comment:
What are these changes for?
##########
File path: src/target/source/source_module.cc
##########
@@ -539,6 +792,32 @@ runtime::Module CreateCSourceCrtMetadataModule(const
Array<runtime::Module>& mod
return std::move(csrc_metadata_module);
}
+runtime::Module CreateCSourceCppMetadataModule(runtime::metadata::Metadata
metadata) {
+ MetadataSerializer serializer;
+ serializer.CodegenMetadata(metadata);
+ std::stringstream lookup_func;
+ lookup_func << "#ifdef __cplusplus\n"
+ << "extern \"C\"\n"
+ << "#endif\n";
+
+ lookup_func << "TVM_DLL int32_t get_c_metadata(TVMValue* arg_values, int*
arg_tcodes, int "
+ "num_args, TVMValue* ret_values, int* ret_tcodes, void*
resource_handle) {"
+ << std::endl;
+ lookup_func << " ret_values[0].v_handle = (void*) &" <<
MetadataSerializer::kGlobalSymbol
+ << ";" << std::endl;
+ lookup_func << " ret_tcodes[0] = kTVMOpaqueHandle;" << std::endl;
+ lookup_func << " return 0;" << std::endl;
+ lookup_func << "};" << std::endl;
+
+ auto mod = MetadataModuleCreate(metadata);
+ std::vector<String> func_names{"get_c_metadata"};
Review comment:
Rather than hard-coding `get_c_metadata` everywhere, we should introduce a
new symbol at
https://github.com/apache/tvm/blob/d8f639a84fd9a11f03a949eb751a18018a9c1033/include/tvm/runtime/module.h#L219
(similar to `tvm_module_main`).
##########
File path: src/runtime/metadata.cc
##########
@@ -52,5 +62,65 @@ TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data)
TVM_REGISTER_OBJECT_TYPE(TensorInfoNode);
} // namespace metadata
+
+class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
+ public:
+ explicit MetadataModuleNode(runtime::metadata::Metadata metadata) {
+ // CHECK((metadata.defined() && code.size() > 0) || (!metadata.defined()
&& code.size() == 0))
+ // << "metadata and code must both be either defined (when passed from
compiler) or undefined
+ // "
+ // << "(when passed from runtime)";
+ metadata_ = metadata;
+ // code_ = code;
+ }
+
+ const char* type_key() const { return "metadata_module"; }
+
+ static Module LoadFromBinary() {
+ return
Module(make_object<MetadataModuleNode>(runtime::metadata::Metadata()));
+ }
+
+ void SaveToBinary(dmlc::Stream* stream) final {}
+
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) {
+ if (name == "get_metadata") {
+ return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
+ if (!metadata_.defined()) {
+ TVMFunctionHandle f_handle;
+ int32_t ret_code = TVMBackendGetFuncFromEnv(this, "get_c_metadata",
&f_handle);
+ CHECK_EQ(ret_code, 0) << "Unable to locate get_c_metadata
PackedFunc";
+
+ TVMValue ret_value;
+ int ret_type_code;
+ ret_code = TVMFuncCall(f_handle, nullptr, nullptr, 0, &ret_value,
&ret_type_code);
+ CHECK_EQ(ret_code, 0) << "Invoking get_c_metadata: TVMFuncCall
returned " << ret_code;
+
+ CHECK_EQ(ret_type_code, kTVMOpaqueHandle)
+ << "Expected kOpaqueHandle returned; got " << ret_type_code;
+ CHECK(ret_value.v_handle != nullptr) << "get_c_metadata returned
nullptr";
+
Review comment:
Use `ICHECK` or `ICHECK_EQ` in this function.
##########
File path: src/target/metadata.h
##########
@@ -86,19 +97,22 @@ class VisitableMetadataNode : public
::tvm::runtime::metadata::MetadataNode {
class InMemoryMetadataNode : public
::tvm::target::metadata::VisitableMetadataNode {
public:
InMemoryMetadataNode()
- : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs
*/,
+ : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs
*/, {} /* pools */,
"" /* mod_name */) {}
InMemoryMetadataNode(int64_t version,
const
::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs,
const
::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs,
+ const
::std::vector<::tvm::runtime::metadata::TensorInfo>& pools,
const ::tvm::runtime::String mod_name)
: VisitableMetadataNode{&storage_},
inputs_{new struct TVMTensorInfo[inputs.size()]()},
inputs_objs_{inputs},
outputs_{new struct TVMTensorInfo[outputs.size()]()},
outputs_objs_{outputs},
+ pools_{new struct TVMTensorInfo[pools.size()]()},
Review comment:
Can use `make_unique` here (as @kparzysz-quic recommended in other
places).
##########
File path: src/target/source/source_module.cc
##########
@@ -518,6 +522,264 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
}
};
+// Joins `parts` into a single identifier, separating consecutive elements
+// with "_", e.g. {"inputs", "0", "name"} -> "inputs_0_name".
+// Returns "" when `parts` is empty.
+// NOTE(review): elements are not escaped, so parts that themselves contain
+// "_" can collide (e.g. {"a_b"} and {"a", "b"} both yield "a_b") — confirm
+// that callers never produce such keys.
+static std::string address_from_parts(const std::vector<std::string>& parts) {
+ std::stringstream ss;
+ for (unsigned int i = 0; i < parts.size(); ++i) {
+ if (i > 0) {
+ // Separator goes between elements only, never before the first one.
+ ss << "_";
+ }
+ ss << parts[i];
+ }
+ return ss.str();
+}
+
+// AttrVisitor that walks the ObjectRef-typed attributes of a metadata object
+// and appends one (address, object) QueueItem to `*queue` for every attribute
+// value that is a runtime::metadata::MetadataBase. The address is the
+// address_from_parts() join of the attribute keys (and, inside arrays, the
+// element indices) on the current visit path.
+//
+// Ordering: nested metadata is queued before the object that contains it,
+// because VisitAttrs() recurses before queue_->push_back() is called.
+//
+// `queue` is stored as a raw pointer and never freed here; the caller owns
+// the vector and must keep it alive while this visitor is in use.
+class MetadataQueuer : public AttrVisitor {
+ public:
+ // (address, metadata object) pair recorded for each queued object.
+ using QueueItem = std::tuple<std::string, runtime::metadata::MetadataBase>;
+ explicit MetadataQueuer(std::vector<QueueItem>* queue) : queue_{queue} {}
+
+ // Non-ObjectRef attributes cannot hold nested metadata; they are ignored.
+ void Visit(const char* key, double* value) final {}
+ void Visit(const char* key, int64_t* value) final {}
+ void Visit(const char* key, uint64_t* value) final {}
+ void Visit(const char* key, int* value) final {}
+ void Visit(const char* key, bool* value) final {}
+ void Visit(const char* key, std::string* value) final {}
+ void Visit(const char* key, DataType* value) final {}
+ void Visit(const char* key, runtime::NDArray* value) final {}
+ void Visit(const char* key, void** value) final {}
+
+ // Pushes `key` onto the current path. If `*value` is metadata, recurses
+ // into it and then queues it under the current path; other ObjectRefs are
+ // skipped. `key` is popped again before returning.
+ void Visit(const char* key, ObjectRef* value) final {
+ address_parts_.push_back(key);
+ if (value->as<runtime::metadata::MetadataBaseNode>() != nullptr) {
+ auto metadata = Downcast<runtime::metadata::MetadataBase>(*value);
+ const runtime::metadata::MetadataArrayNode* arr =
+ value->as<runtime::metadata::MetadataArrayNode>();
+ // NOTE(review): this and the per-element std::cout below look like
+ // leftover debug tracing that writes to stdout on every visit; consider
+ // removing them or switching to VLOG.
+ std::cout << "Is array? " << arr << std::endl;
+ if (arr != nullptr) {
+ // Array elements themselves are not queued: only their metadata-typed
+ // attributes are, with the element index appended to the path. The
+ // array as a whole is queued after the loop. Non-metadata elements
+ // are skipped.
+ for (unsigned int i = 0; i < arr->array.size(); i++) {
+ ObjectRef o = arr->array[i];
+ std::cout << "queue-visiting array element " << i << ": " <<
o->type_index() << " ("
+ << o.operator->() << ")" << std::endl;
+ if (o.as<runtime::metadata::MetadataBaseNode>() != nullptr) {
+ std::stringstream ss;
+ ss << i;
+ address_parts_.push_back(ss.str());
+ // NOTE(review): this local shadows the outer `metadata` above.
+ runtime::metadata::MetadataBase metadata =
Downcast<runtime::metadata::MetadataBase>(o);
+ ReflectionVTable::Global()->VisitAttrs(metadata.operator->(),
this);
+ address_parts_.pop_back();
+ }
+ }
+ } else {
+ ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
+ }
+
+ // Queue this object only after its children (post-order).
+ queue_->push_back(std::make_tuple(address_from_parts(address_parts_),
+
Downcast<runtime::metadata::MetadataBase>(*value)));
+ }
+ // Restore the path to what the caller had before this attribute.
+ address_parts_.pop_back();
+ }
+
+ private:
+ // Output queue; not owned.
+ std::vector<QueueItem>* queue_;
+ // Attribute keys / array indices from the root to the current attribute.
+ std::vector<std::string> address_parts_;
+};
+
+class MetadataSerializer : public AttrVisitor {
+ public:
+ static constexpr const char* kGlobalSymbol = "kTvmgenMetadata";
+ using MetadataTypeIndex = ::tvm::runtime::metadata::MetadataTypeIndex;
+
+ MetadataSerializer() : is_first_item_{true} {}
+
+ void WriteComma() {
+ if (is_first_item_) {
+ is_first_item_ = false;
+ } else {
+ code_ << ", " << std::endl;
+ }
+ }
+
+ void WriteKey(const char* key) {
+ if (key != nullptr) {
+ code_ << " /* " << key << "*/";
+ }
+ }
+
+ void Visit(const char* key, double* value) final {
+ WriteComma();
+ code_.setf(std::ios::hex | std::ios::showbase | std::ios::fixed |
std::ios::scientific,
+ std::ios::basefield | std::ios::showbase |
std::ios::floatfield);
+ code_ << *value;
+ WriteKey(key);
+ }
+
+ void Visit(const char* key, int64_t* value) final {
+ WriteComma();
+ code_ << *value << "L";
+ WriteKey(key);
+ }
+
+ void Visit(const char* key, uint64_t* value) final {
+ WriteComma();
+ code_ << *value << "UL";
+ WriteKey(key);
+ }
+ void Visit(const char* key, int* value) final {
+ WriteComma();
+ code_ << *value;
+ WriteKey(key);
+ }
+ void Visit(const char* key, bool* value) final {
+ WriteComma();
+ code_ << *value;
+ WriteKey(key);
+ }
+ void Visit(const char* key, std::string* value) final {
+ WriteComma();
+ code_ << "\"" << *value << "\"";
+ WriteKey(key);
+ }
+ void Visit(const char* key, void** value) final {
+ WriteComma();
+ code_ << *value;
+ WriteKey(key);
+ }
+ void Visit(const char* key, DataType* value) final {
+ WriteComma();
+ code_ << "{" << value->code() << ", " << value->bits() << ", " <<
value->lanes() << "}";
+ WriteKey(key);
+ }
+
+ void Visit(const char* key, runtime::NDArray* value) final {
+ // TODO(areusch): probably we could consolidate --link-params here, tho...
+ ICHECK(false) << "do not support serializing NDArray as metadata";
+ }
+
+ void VisitArray(const runtime::metadata::MetadataArrayNode* array) {
+ std::cout << "visit array " << array << ": " << array->type_index << " "
<< array->struct_name
+ << "," << array->array.size() << std::endl;
+ auto old_is_first_item = is_first_item_;
+ is_first_item_ = true;
+ for (unsigned int i = 0; i < array->array.size(); ++i) { // ObjectRef o :
*(array->array)) {
+ ObjectRef o = array->array[i];
+ std::cout << "visiting array element " << i << ": " << o->type_index()
<< " ("
+ << o.operator->() << ")" << std::endl;
+ if (o->IsInstance<IntImmNode>()) {
+ int64_t i = Downcast<Integer>(o);
+ Visit(nullptr, &i);
+ continue;
+ }
+
+ if (o->IsInstance<StringObj>()) {
+ std::string s = Downcast<String>(o);
+ Visit(nullptr, &s);
+ continue;
+ }
+
+ runtime::metadata::MetadataBase metadata =
Downcast<runtime::metadata::MetadataBase>(o);
+ std::stringstream i_str;
+ i_str << i;
+ address_.push_back(i_str.str());
+ Visit(nullptr, &metadata);
+ address_.pop_back();
+ // ReflectionVTable::Global()->VisitAttrs(metadata.operator->(),
this);
Review comment:
Remove this commented-out line?
##########
File path: src/runtime/aot_executor/aot_executor.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Defines an implementation of Module-based Model Runtime Interface
that works with
+ * Ahead-of-Time compilation.
+ * \file aot_executor.h
+ */
+#ifndef TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
+#define TVM_RUNTIME_AOT_EXECUTOR_AOT_EXECUTOR_H_
+
+#include <tvm/runtime/metadata.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+class TVM_DLL AotExecutor : public ModuleNode {
+ public:
+ /*!
+ * \brief Implements member function lookup for this Module for the frontend.
+ * \param name The name of the function.
+ * \param sptr_to_self The pointer to the module node.
+ * \return The corresponding member function.
+ */
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>&
sptr_to_self) override;
+
+ /*!
+ * \return The type key of the executor.
+ */
+ const char* type_key() const final { return "AotExecutor"; }
+
+ void Run();
+
+ /*!
+ * \brief Initialize the AOT executor with metadata, runtime::Module, and
device.
+ * \param module The module containing the compiled functions for the host
+ * processor.
+ * \param devs The device of the host and devices where graph nodes will be
+ * executed on.
+ * \param lookup_linked_param_func If given, a PackedFunc invoked to lookup
linked parameters
Review comment:
`lookup_linked_param_func` is documented but does not appear in the parameter list.
##########
File path: src/runtime/metadata.cc
##########
@@ -52,5 +62,65 @@ TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data)
TVM_REGISTER_OBJECT_TYPE(TensorInfoNode);
} // namespace metadata
+
+class MetadataModuleNode : public ::tvm::runtime::ModuleNode {
+ public:
+ explicit MetadataModuleNode(runtime::metadata::Metadata metadata) {
+ // CHECK((metadata.defined() && code.size() > 0) || (!metadata.defined()
&& code.size() == 0))
+ // << "metadata and code must both be either defined (when passed from
compiler) or undefined
+ // "
+ // << "(when passed from runtime)";
+ metadata_ = metadata;
+ // code_ = code;
Review comment:
Remove the commented-out code.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]