Mousius commented on a change in pull request #8697:
URL: https://github.com/apache/tvm/pull/8697#discussion_r686144631
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -684,35 +589,68 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief mapping sid -> tir::Var */
std::unordered_map<int, te::Var> sids_table_;
/*! \brief lowered funcs */
- std::unordered_map<std::string, IRModule> lowered_funcs_;
- /*! \brief lowered funcs */
Map<String, FunctionInfo> function_metadata_;
- /*! \brief compile engine */
- CompileEngine compile_engine_;
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more
then one output */
std::vector<int> return_sid_;
- /*! \brief the module name we use to mangle the function names */
- String mod_name_;
public:
- AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target
target_host)
+ AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets,
Target target_host)
: mod_(mod),
targets_(targets),
target_host_(target_host),
-
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
- compile_engine_(CompileEngine::Global()) {}
+
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false)))
{}
LoweredOutput Codegen(relay::Function func, String mod_name) {
auto aot_allocator = AOTOnDemandAllocator();
aot_allocator.Run(func);
- // Retrieve the storage map
- storage_device_map_ = aot_allocator.GetStorageMap();
- mod_name_ = mod_name;
+ // Pre-lowering storage map and memory plan
+ StorageMap initial_storage_map = aot_allocator.GetStorageMap();
+ StaticMemoryPlan memory_plan(initial_storage_map);
+
+ // Build a map from each operation to device.
+ tec::DeviceMap device_context_map;
+ for (const auto& it : memory_plan->expr_to_storage_info) {
+ auto expr = it.first;
+ auto storage_info = it.second;
+ auto device_types = storage_info->device_types;
+ // CHECK_EQ(device_types.size(), 1);
+ tvm::Device dev;
+ dev.device_id = 0;
+ dev.device_type = device_types[0];
+ device_context_map.insert({expr, dev});
+ }
+
+ // This first phase moves from implicit use of compile engine,
+ // to instead explicitly lowering the incoming IRModule, and then
+ // performing the preexisting AOT executor code generation phase.
+ IRModule mod = IRModule::FromExpr(func);
+ auto lowered_module = tec::LowerTE(
+ mod, targets_, device_context_map, memory_plan, mod_name,
[this](Function func) {
+ // We need to maintain the constant map for external
+ // functions so we pass this processing function which
+ // allows us to process each function as we lower it.
+ if (func->GetAttr<String>(attr::kCompiler).defined()) {
+ UpdateConstants(func, &params_);
+ }
+
+ // TODO(@areusch, @jroesch): We should refactor this to
+ // execute as a further pass, instead writing data to the
+ // lowering process directly.
+ tec::UpdateFunctionMetadata(func, this->function_metadata_);
+ });
- for (auto input : func->params) {
+ auto lowered_main = lowered_module.main_module->Lookup("main");
+ auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
+
+ // Post-lowering storage map for writing main func
+ auto new_allocator = AOTOnDemandAllocator();
+ new_allocator.Run(lowered_main_func);
+ storage_device_map_ = new_allocator.GetStorageMap();
Review comment:
Yep, that's correct; I've updated the comment to clarify that. What do
you think now?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]