manupa-arm commented on a change in pull request #9565:
URL: https://github.com/apache/tvm/pull/9565#discussion_r783273486
##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -645,32 +632,120 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// TODO(giuseros): we should allocate this once outside the PrimFunc
// so we don't pay the price of allocation for every inference
if (!allocated[sid]) {
- body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size},
tir::const_true(), body);
+ PointerType ptype =
Downcast<PointerType>(sids_table_[sid]->type_annotation);
+ DataType element_type =
Downcast<PrimType>(ptype->element_type)->dtype;
+ body = tir::Allocate(sids_table_[sid], element_type, {size},
tir::const_true(), body);
}
allocated[sid] = true;
}
}
- // Define the attributes
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
- body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
-
// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name,
runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));
+ dict_attrs.Set(tvm::attr::kTarget, target_host_);
tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
tir::Stmt final_body = tir::SeqStmt({device_activations, body,
device_deactivations});
// Make the PrimFunc
- return tir::PrimFunc(main_signature_, final_body, VoidType(),
Map<tir::Var, tir::Buffer>(),
+ return tir::PrimFunc(main_signature_, final_body, VoidType(),
main_buffer_map_,
DictAttrs(dict_attrs));
}
+ /*!
+ * \brief Access I/O vars using the buffer vars and
+ * not the actual vars.
+ */
+ tir::Var GetBufferVarForIO(int index) { return
main_buffer_map_[main_signature_[index]]->data; }
+
+ /*!
+ * \brief Create tir::Var for input/output while updating
+ * the buffer_maps.
+ */
+ void CreateIOVar(const Expr& expr, std::string name) {
+ if (expr->IsInstance<TupleNode>()) {
+ Tuple tuple = Downcast<Tuple>(expr);
+ for (unsigned i = 0; i < tuple->fields.size(); i++) {
+ CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_");
+ }
+ } else {
+ tir::Var var = tir::Var(name, DataType::Handle());
+ main_signature_.push_back(var);
+ auto tensor_type = expr->checked_type().as<TensorTypeNode>();
+ DataType elem_type = tensor_type->dtype;
+ tir::Var buffer_var =
+ tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type),
"global"));
+ tir::Buffer buffer = tir::Buffer(buffer_var, elem_type,
tensor_type->shape, {}, 0,
+ name + "_buffer", 16, 1,
tir::BufferType::kDefault);
+ main_buffer_map_.Set(var, buffer);
+ }
+ }
+
+ /*!
+ * \brief This function is a wrapper to run memory planning
+ * followed by recording the latest workspaces required.
+ */
+ IRModule PlanMemoryLoweredModule(const IRModule& mod) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+ bool enable_usmp = pass_ctx->GetConfig<Bool>("tir.usmp.enable",
Bool(false)).value();
Review comment:
made into constants.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]