areusch commented on a change in pull request #9565:
URL: https://github.com/apache/tvm/pull/9565#discussion_r778195539
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -859,20 +922,28 @@ class AOTExecutorCodegen : public MixedModeVisitor {
ret.external_mods = external_modules.value();
- if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
- VLOG(1) << "merging main into existing module for host target";
- ret.lowered_funcs[target_host_]->Update(mod_run);
- } else {
- VLOG(1) << "adding main into new module for host target";
- ret.lowered_funcs.Set(target_host_, mod_run);
+ Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
+ std::vector<tir::Var> pool_vars;
+ tir::PrimFunc tir_main_func =
+
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
+ Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
+
tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
+ int main_workspace_size = 0;
Review comment:
nit: suffix this with `size_bytes` (e.g. `main_workspace_size_bytes`) so the unit is explicit
##########
File path: src/target/source/source_module.cc
##########
@@ -197,45 +202,161 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
<< "}\n";
}
+ String GenerateDLTensorStructWrapper(String reference_arg) {
+ code_ << "DLTensor " << reference_arg << "_dlt = {\n";
Review comment:
since this is generated code, maybe we should opt to expand _dlt?
##########
File path: src/target/source/source_module.cc
##########
@@ -197,45 +202,161 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
<< "}\n";
}
+ String GenerateDLTensorStructWrapper(String reference_arg) {
+ code_ << "DLTensor " << reference_arg << "_dlt = {\n";
+ code_ << ".data = &" << reference_arg << "\n";
+ code_ << "};\n";
+ code_ << "TVMValue " << reference_arg << "_tvmv = {\n";
+ code_ << ".v_handle = &" << reference_arg << "_dlt\n";
+ code_ << "};\n";
+ return reference_arg + "_tvmv";
Review comment:
slight preference for a suffix like `_value` or similar, since `tvmv` is pretty
close to `tvm`
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -645,32 +632,120 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// TODO(giuseros): we should allocate this once outside the PrimFunc
// so we don't pay the price of allocation for every inference
if (!allocated[sid]) {
- body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size},
tir::const_true(), body);
+ PointerType ptype =
Downcast<PointerType>(sids_table_[sid]->type_annotation);
+ DataType element_type =
Downcast<PrimType>(ptype->element_type)->dtype;
+ body = tir::Allocate(sids_table_[sid], element_type, {size},
tir::const_true(), body);
}
allocated[sid] = true;
}
}
- // Define the attributes
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
-
// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name,
runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));
+ dict_attrs.Set(tvm::attr::kTarget, target_host_);
tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
tir::Stmt final_body = tir::SeqStmt({device_activations, body,
device_deactivations});
// Make the PrimFunc
- return tir::PrimFunc(main_signature_, final_body, VoidType(),
Map<tir::Var, tir::Buffer>(),
+ return tir::PrimFunc(main_signature_, final_body, VoidType(),
main_buffer_map_,
DictAttrs(dict_attrs));
}
+ /*!
+ * brief Access IO vars using the buffer vars and
+ * not the actual var.
+ */
+ tir::Var GetBufferVarForIO(int index) { return
main_buffer_map_[main_signature_[index]]->data; }
+
+ /*!
+ * brief Create tir::Var for input/output while updating
+ * the buffer_maps.
+ */
+ void CreateIOVar(const Expr& expr, std::string name) {
+ if (expr->IsInstance<TupleNode>()) {
+ Tuple tuple = Downcast<Tuple>(expr);
+ for (unsigned i = 0; i < tuple->fields.size(); i++) {
+ CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_");
+ }
+ } else {
+ tir::Var var = tir::Var(name, DataType::Handle());
+ main_signature_.push_back(var);
+ auto tensor_type = expr->checked_type().as<TensorTypeNode>();
+ DataType elem_type = tensor_type->dtype;
+ tir::Var buffer_var =
+ tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type),
"global"));
+ tir::Buffer buffer = tir::Buffer(buffer_var, elem_type,
tensor_type->shape, {}, 0,
+ name + "_buffer", 16, 1,
tir::BufferType::kDefault);
+ main_buffer_map_.Set(var, buffer);
+ }
+ }
+
+ /*!
+ * brief This function is a wrapper to run memory planning
+ * followed by recording the latest workspaces required.
+ */
+ IRModule PlanMemoryLoweredModule(const IRModule& mod) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ bool enable_usmp = pass_ctx->GetConfig<Bool>("tir.usmp.enable",
Bool(false)).value();
+
+ IRModule lowered_mod = mod->ShallowCopy();
+ Executor executor_config =
mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
+ Integer workspace_byte_alignment =
+
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
+ if (enable_usmp) {
+ lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
+ // Update workspace size based on the pool allocations.
+ Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
+
lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
+ int main_workspace_size = 0;
+ if (allocated_pool_infos) {
+ for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info :
+ allocated_pool_infos.value()) {
+ main_workspace_size += allocated_pool_info->allocated_size->value;
+ }
+ }
+ for (const auto& kv : function_metadata_) {
+ if (lowered_mod->ContainGlobalVar(kv.first) &&
+ lowered_mod->Lookup(kv.first)->IsInstance<tir::PrimFuncNode>()) {
+ tir::PrimFunc pfunc =
Downcast<tir::PrimFunc>(lowered_mod->Lookup(kv.first));
+ Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
Review comment:
should this somehow use VirtualDevice (either now or in a follow-on)? cc
@mbs-octoml
##########
File path: src/target/source/source_module.cc
##########
@@ -197,45 +202,161 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
<< "}\n";
}
+ String GenerateDLTensorStructWrapper(String reference_arg) {
+ code_ << "DLTensor " << reference_arg << "_dlt = {\n";
+ code_ << ".data = &" << reference_arg << "\n";
+ code_ << "};\n";
+ code_ << "TVMValue " << reference_arg << "_tvmv = {\n";
+ code_ << ".v_handle = &" << reference_arg << "_dlt\n";
+ code_ << "};\n";
+ return reference_arg + "_tvmv";
+ }
+
+ void GenerateInternalWorkspaceBuffers() {
+ if (metadata_->pool_inputs.defined()) {
+ for (const auto& kv : metadata_->pool_inputs.value()) {
+ tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second;
+ if (allocated_pool_info->pool_info->is_internal) {
+ code_ << "__attribute__((section(\".bss.tvm\"), ";
+ code_ << "aligned(" << 16 << ")))\n";
+ code_ << "static uint8_t " <<
allocated_pool_info->pool_info->pool_name << "["
+ << allocated_pool_info->allocated_size->value << "];\n";
+ }
+ }
+ }
+ }
+
+ bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) {
+ if (metadata_->pool_inputs.defined()) {
+ Map<tir::Var, tir::usmp::AllocatedPoolInfo> allocated_pool_infos =
+ metadata_->pool_inputs.value();
+ if (allocated_pool_infos.find(pool_var) != allocated_pool_infos.end()) {
+ tir::usmp::AllocatedPoolInfo allocate_pool_info =
allocated_pool_infos[pool_var];
+ if (allocate_pool_info->pool_info->is_internal) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name,
const std::string& run_func) {
code_ << "TVM_DLL int32_t " << run_func << "(";
- unsigned int total_args = (metadata_->inputs.size() +
metadata_->num_outputs);
- for (unsigned int i = 0; i < total_args; ++i) {
- code_ << "void* arg" << i;
- if (i + 1 != total_args) {
- code_ << ",";
+
+ {
+ std::stringstream call_args_ss;
+ for (const tir::Var& input_var : metadata_->inputs) {
+ if (input_var->type_annotation.defined()) {
+ codegen_c_.PrintType(input_var->type_annotation, call_args_ss);
+ } else {
+ codegen_c_.PrintType(input_var.dtype(), call_args_ss);
+ }
+ call_args_ss << " " << input_var->name_hint << ",";
+ }
+ for (unsigned int i = 0; i < metadata_->num_outputs; ++i) {
+ call_args_ss << "void* output" << i << ",";
+ }
+ for (const tir::Var& pool_var : metadata_->pools) {
+ if (pool_var->type_annotation.defined()) {
+ codegen_c_.PrintType(pool_var->type_annotation, call_args_ss);
+ } else {
+ codegen_c_.PrintType(pool_var.dtype(), call_args_ss);
+ }
+ call_args_ss << " " << pool_var->name_hint << ",";
}
+ std::string call_args_str = call_args_ss.str();
+ call_args_str.pop_back();
+ code_ << call_args_str;
}
+
code_ << ");\n";
code_ << "int32_t " << entrypoint_name;
code_ << "(void* args, void* type_code, int num_args, void* out_value,
void* "
"out_type_code, void* resource_handle) {\n";
code_ << "return " << run_func << "(";
- for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
- code_ << "((DLTensor*)(((TVMValue*)args)[" << i <<
"].v_handle))[0].data,";
+
+ {
+ std::stringstream call_args_ss;
+ for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
+ call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i <<
"].v_handle))[0].data,";
Review comment:
any reason to use `[0].` instead of `->`?
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -645,32 +632,120 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// TODO(giuseros): we should allocate this once outside the PrimFunc
// so we don't pay the price of allocation for every inference
if (!allocated[sid]) {
- body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size},
tir::const_true(), body);
+ PointerType ptype =
Downcast<PointerType>(sids_table_[sid]->type_annotation);
+ DataType element_type =
Downcast<PrimType>(ptype->element_type)->dtype;
+ body = tir::Allocate(sids_table_[sid], element_type, {size},
tir::const_true(), body);
}
allocated[sid] = true;
}
}
- // Define the attributes
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
-
// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name,
runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));
+ dict_attrs.Set(tvm::attr::kTarget, target_host_);
tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
tir::Stmt final_body = tir::SeqStmt({device_activations, body,
device_deactivations});
// Make the PrimFunc
- return tir::PrimFunc(main_signature_, final_body, VoidType(),
Map<tir::Var, tir::Buffer>(),
+ return tir::PrimFunc(main_signature_, final_body, VoidType(),
main_buffer_map_,
DictAttrs(dict_attrs));
}
+ /*!
+ * brief Access IO vars using the buffer vars and
+ * not the actual var.
+ */
+ tir::Var GetBufferVarForIO(int index) { return
main_buffer_map_[main_signature_[index]]->data; }
+
+ /*!
+ * brief Create tir::Var for input/output while updating
+ * the buffer_maps.
+ */
+ void CreateIOVar(const Expr& expr, std::string name) {
+ if (expr->IsInstance<TupleNode>()) {
+ Tuple tuple = Downcast<Tuple>(expr);
+ for (unsigned i = 0; i < tuple->fields.size(); i++) {
+ CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_");
+ }
+ } else {
+ tir::Var var = tir::Var(name, DataType::Handle());
+ main_signature_.push_back(var);
+ auto tensor_type = expr->checked_type().as<TensorTypeNode>();
+ DataType elem_type = tensor_type->dtype;
+ tir::Var buffer_var =
+ tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type),
"global"));
+ tir::Buffer buffer = tir::Buffer(buffer_var, elem_type,
tensor_type->shape, {}, 0,
+ name + "_buffer", 16, 1,
tir::BufferType::kDefault);
+ main_buffer_map_.Set(var, buffer);
+ }
+ }
+
+ /*!
+ * brief This function is a wrapper to run memory planning
+ * followed by recording the latest workspaces required.
+ */
+ IRModule PlanMemoryLoweredModule(const IRModule& mod) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ bool enable_usmp = pass_ctx->GetConfig<Bool>("tir.usmp.enable",
Bool(false)).value();
Review comment:
should `tir.usmp.enable` be a constant somewhere?
##########
File path: tests/python/relay/aot/test_c_device_api.py
##########
@@ -231,26 +216,29 @@ def
test_without_device_api_unpacked_api(non_device_api_main_func):
"""Test a graph without the Device API with the unpacked internal calls"""
main_func = non_device_api_main_func(interface_api="c",
use_unpacked_api=True)
-
+ print(str(main_func.body))
Review comment:
nit: delete
##########
File path: src/target/source/source_module.cc
##########
@@ -197,45 +202,161 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
<< "}\n";
}
+ String GenerateDLTensorStructWrapper(String reference_arg) {
+ code_ << "DLTensor " << reference_arg << "_dlt = {\n";
+ code_ << ".data = &" << reference_arg << "\n";
+ code_ << "};\n";
+ code_ << "TVMValue " << reference_arg << "_tvmv = {\n";
+ code_ << ".v_handle = &" << reference_arg << "_dlt\n";
+ code_ << "};\n";
+ return reference_arg + "_tvmv";
+ }
+
+ void GenerateInternalWorkspaceBuffers() {
+ if (metadata_->pool_inputs.defined()) {
+ for (const auto& kv : metadata_->pool_inputs.value()) {
+ tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second;
+ if (allocated_pool_info->pool_info->is_internal) {
+ code_ << "__attribute__((section(\".bss.tvm\"), ";
+ code_ << "aligned(" << 16 << ")))\n";
+ code_ << "static uint8_t " <<
allocated_pool_info->pool_info->pool_name << "["
+ << allocated_pool_info->allocated_size->value << "];\n";
+ }
+ }
+ }
+ }
+
+ bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) {
+ if (metadata_->pool_inputs.defined()) {
+ Map<tir::Var, tir::usmp::AllocatedPoolInfo> allocated_pool_infos =
+ metadata_->pool_inputs.value();
+ if (allocated_pool_infos.find(pool_var) != allocated_pool_infos.end()) {
+ tir::usmp::AllocatedPoolInfo allocate_pool_info =
allocated_pool_infos[pool_var];
+ if (allocate_pool_info->pool_info->is_internal) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
void GenerateEntrypointForUnpackedAPI(const std::string& entrypoint_name,
const std::string& run_func) {
code_ << "TVM_DLL int32_t " << run_func << "(";
- unsigned int total_args = (metadata_->inputs.size() +
metadata_->num_outputs);
- for (unsigned int i = 0; i < total_args; ++i) {
- code_ << "void* arg" << i;
- if (i + 1 != total_args) {
- code_ << ",";
+
+ {
+ std::stringstream call_args_ss;
+ for (const tir::Var& input_var : metadata_->inputs) {
+ if (input_var->type_annotation.defined()) {
+ codegen_c_.PrintType(input_var->type_annotation, call_args_ss);
+ } else {
+ codegen_c_.PrintType(input_var.dtype(), call_args_ss);
+ }
+ call_args_ss << " " << input_var->name_hint << ",";
+ }
+ for (unsigned int i = 0; i < metadata_->num_outputs; ++i) {
+ call_args_ss << "void* output" << i << ",";
+ }
+ for (const tir::Var& pool_var : metadata_->pools) {
+ if (pool_var->type_annotation.defined()) {
+ codegen_c_.PrintType(pool_var->type_annotation, call_args_ss);
+ } else {
+ codegen_c_.PrintType(pool_var.dtype(), call_args_ss);
+ }
+ call_args_ss << " " << pool_var->name_hint << ",";
}
+ std::string call_args_str = call_args_ss.str();
+ call_args_str.pop_back();
+ code_ << call_args_str;
}
+
code_ << ");\n";
code_ << "int32_t " << entrypoint_name;
code_ << "(void* args, void* type_code, int num_args, void* out_value,
void* "
"out_type_code, void* resource_handle) {\n";
code_ << "return " << run_func << "(";
- for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
- code_ << "((DLTensor*)(((TVMValue*)args)[" << i <<
"].v_handle))[0].data,";
+
+ {
+ std::stringstream call_args_ss;
+ for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
+ call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i <<
"].v_handle))[0].data,";
+ }
+ for (unsigned int i = 0; i < metadata_->num_outputs; ++i) {
+ int j = metadata_->inputs.size() + i;
+ call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j <<
"].v_handle))[0].data,";
+ }
+ for (const tir::Var& pool_var : metadata_->pools) {
+ if (IsInternalWorkspaceBuffer(pool_var)) {
Review comment:
what if the workspace buffers are defined by the user?
##########
File path: src/tir/usmp/utils.cc
##########
@@ -168,6 +174,20 @@ Array<BufferInfo> CreateArrayBufferInfo(const
Map<BufferInfo, Stmt>& buffer_info
return ret;
}
+void PrintConflicts(const Array<BufferInfo>& bi_arr) {
+ for (const auto& bi : bi_arr) {
Review comment:
want to either log each conflict separately, or wrap this function in a
log-level check? (or, is this function actually called anywhere?)
##########
File path: src/target/source/source_module.cc
##########
@@ -197,45 +202,161 @@ class CSourceCrtMetadataModuleNode : public
runtime::ModuleNode {
<< "}\n";
}
+ String GenerateDLTensorStructWrapper(String reference_arg) {
+ code_ << "DLTensor " << reference_arg << "_dlt = {\n";
+ code_ << ".data = &" << reference_arg << "\n";
+ code_ << "};\n";
+ code_ << "TVMValue " << reference_arg << "_tvmv = {\n";
+ code_ << ".v_handle = &" << reference_arg << "_dlt\n";
+ code_ << "};\n";
+ return reference_arg + "_tvmv";
+ }
+
+ void GenerateInternalWorkspaceBuffers() {
+ if (metadata_->pool_inputs.defined()) {
+ for (const auto& kv : metadata_->pool_inputs.value()) {
+ tir::usmp::AllocatedPoolInfo allocated_pool_info = kv.second;
+ if (allocated_pool_info->pool_info->is_internal) {
+ code_ << "__attribute__((section(\".bss.tvm\"), ";
Review comment:
should we zero these? it might add significantly to startup cycles
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -645,32 +632,120 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// TODO(giuseros): we should allocate this once outside the PrimFunc
// so we don't pay the price of allocation for every inference
if (!allocated[sid]) {
- body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size},
tir::const_true(), body);
+ PointerType ptype =
Downcast<PointerType>(sids_table_[sid]->type_annotation);
+ DataType element_type =
Downcast<PrimType>(ptype->element_type)->dtype;
+ body = tir::Allocate(sids_table_[sid], element_type, {size},
tir::const_true(), body);
}
allocated[sid] = true;
}
}
- // Define the attributes
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
-
// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name,
runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));
+ dict_attrs.Set(tvm::attr::kTarget, target_host_);
tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
tir::Stmt final_body = tir::SeqStmt({device_activations, body,
device_deactivations});
// Make the PrimFunc
- return tir::PrimFunc(main_signature_, final_body, VoidType(),
Map<tir::Var, tir::Buffer>(),
+ return tir::PrimFunc(main_signature_, final_body, VoidType(),
main_buffer_map_,
DictAttrs(dict_attrs));
}
+ /*!
+ * brief Access IO vars using the buffer vars and
+ * not the actual var.
+ */
+ tir::Var GetBufferVarForIO(int index) { return
main_buffer_map_[main_signature_[index]]->data; }
+
+ /*!
+ * brief Create tir::Var for input/output while updating
+ * the buffer_maps.
+ */
+ void CreateIOVar(const Expr& expr, std::string name) {
+ if (expr->IsInstance<TupleNode>()) {
+ Tuple tuple = Downcast<Tuple>(expr);
+ for (unsigned i = 0; i < tuple->fields.size(); i++) {
+ CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_");
+ }
+ } else {
+ tir::Var var = tir::Var(name, DataType::Handle());
+ main_signature_.push_back(var);
+ auto tensor_type = expr->checked_type().as<TensorTypeNode>();
+ DataType elem_type = tensor_type->dtype;
+ tir::Var buffer_var =
+ tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type),
"global"));
+ tir::Buffer buffer = tir::Buffer(buffer_var, elem_type,
tensor_type->shape, {}, 0,
+ name + "_buffer", 16, 1,
tir::BufferType::kDefault);
+ main_buffer_map_.Set(var, buffer);
+ }
+ }
+
+ /*!
+ * brief This function is a wrapper to run memory planning
+ * followed by recording the latest workspaces required.
+ */
+ IRModule PlanMemoryLoweredModule(const IRModule& mod) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ bool enable_usmp = pass_ctx->GetConfig<Bool>("tir.usmp.enable",
Bool(false)).value();
+
+ IRModule lowered_mod = mod->ShallowCopy();
+ Executor executor_config =
mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
+ Integer workspace_byte_alignment =
+
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
+ if (enable_usmp) {
+ lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
+ // Update workspace size based on the pool allocations.
+ Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
+
lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
+ int main_workspace_size = 0;
Review comment:
just curious: when we expand to allowing AOT to do heterogeneous
execution and device copies, would we then assemble a set of workspace sizes?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]