This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d1f19c4 Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass (#8802)
d1f19c4 is described below
commit d1f19c470c16a1ca87c67fd93f30dd59e16bbec1
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Mon Aug 23 20:12:07 2021 -0700
Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass (#8802)
* Initial commit
Initial stab at IRModule -> LoweredModule conversion func, notes
Add external_mods and main_func_info to conversion funcs
Test lowered module to IR module
Fix problem with conversion funcs + print stmts
Add LowerTE pass
Add LowerTEPass
Add LowerTEPass to graph_executor_codegen.cc
Use LowerTEPass instead of LowerTE in graph_executor_codegen.cc
Code cleanup
Add docs, more cleanup
Formatting
* Fix bad rebase
* Address 1st round of comments
* Use tir kTarget instead of relay one
* Change target string to Target obj
* removing target string causing issues
* Fix typos
* Revert target str -> target obj changes
* Don't use IRModule::Update because it is broken
* Fix check
* flaky test?
* lint
---
include/tvm/relay/function.h | 1 -
src/relay/backend/aot_executor_codegen.cc | 8 +-
src/relay/backend/graph_executor_codegen.cc | 7 +-
src/relay/backend/te_compiler.cc | 117 ++++++++++++++++++++++++++--
src/relay/backend/te_compiler.h | 49 +++++++++++-
5 files changed, 167 insertions(+), 15 deletions(-)
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index fccd1f9..9170bc5 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite";
constexpr const char* kInline = "Inline";
/*! \brief Indicate the function was created by the Pattern Partitioning Pass.
*/
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
-
/*! \brief Mark the function as only composed of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";
} // namespace attr
diff --git a/src/relay/backend/aot_executor_codegen.cc
b/src/relay/backend/aot_executor_codegen.cc
index 54a10ad..942bc0d 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -586,8 +586,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// to instead explicitly lowering the incoming IRModule, and then
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);
- auto lowered_module = tec::LowerTE(
- mod, targets_, device_context_map, memory_plan, mod_name,
[this](Function func) {
+
+ IRModule new_mod =
+ LowerTEPass(targets_, device_context_map, memory_plan, mod_name,
[this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -599,8 +600,9 @@ class AOTExecutorCodegen : public ExprVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
- });
+ })(mod);
+ tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main,
lowered_module.main_func_info);
auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
diff --git a/src/relay/backend/graph_executor_codegen.cc
b/src/relay/backend/graph_executor_codegen.cc
index cc54a52..486a6dc 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -221,8 +221,8 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
device_context_map.insert({expr, dev});
}
- auto lowered_module = tec::LowerTE(
- mod, targets_, device_context_map, memory_plan_, mod_name_,
[this](Function func) {
+ IRModule new_mod =
+ LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_,
[this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
@@ -234,8 +234,9 @@ class GraphExecutorCodegen : public
backend::MemoizedExprTranslator<std::vector<
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
- });
+ })(mod);
+ tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
function_metadata_.Set(runtime::symbol::tvm_module_main,
lowered_module.main_func_info);
auto main_module = lowered_module.main_module;
main_module = relay::transform::InferType()(main_module);
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 93fcf73..71ac752 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -20,6 +20,7 @@
#include "te_compiler.h"
#include <tvm/driver/driver_api.h>
+#include <tvm/ir/attrs.h>
#include <tvm/ir/function.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
@@ -749,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const
IRModule& mod, TargetMap tar
relay_primfuncs);
}
-// TODO(@electriclilies): Is the function passed in here relay_func??
-// Also should this be inlined?
/*!
* \brief A function to create the function metadata for an input function (ie
calculate buffer
* input/output sizes)
@@ -830,9 +829,6 @@ void UpdateFunctionMetadata(Function relay_func,
function_metadata.Set(prim_fn_var.value()->name_hint, fi);
}
-// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule
back into IRModule.
-// Currently we rely on accumulating bindings inside the local TECompiler
which we then
-// host into the LoweredModule result.
LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap
device_context_map,
backend::StaticMemoryPlan memory_plan, const String&
module_name,
std::function<void(Function)> process_fn) {
@@ -875,6 +871,117 @@ LoweredModule LowerTE(const IRModule& module, TargetMap
targets, DeviceMap devic
return lowered_module;
}
+// Flatten a LoweredModule into a single IRModule: the main module's functions and type
+// definitions are copied unchanged, every per-target function is tagged with its target via
+// the tvm::attr::kTarget attribute, and the external modules / main function info are
+// attached as the "external_mods" / "main_func_info" module attributes.
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  IRModule unified_module;
+
+  // Copy the main module and its typedefs
+  for (const auto& main_kv : mod.main_module->functions) {
+    unified_module->Add(main_kv.first, main_kv.second);
+  }
+  for (const auto& type_kv : mod.main_module->type_definitions) {
+    unified_module->AddTypeDef(type_kv.first, type_kv.second);
+  }
+
+  // Annotate the per-target functions with their target and add them to the unified module
+  for (const auto& target_kv : mod.per_target_module) {
+    const String target = target_kv.first;
+    const IRModule target_module = target_kv.second;
+
+    // Right now, per-target functions are TIR functions, which don't have type definitions, so
+    // there should be no type defs in the per_target_modules
+    size_t ty_def_size = target_module->type_definitions.size();
+    ICHECK(ty_def_size == 0)
+        << "Expected there to be no type definitions in the per_target_modules, but found "
+        << ty_def_size;
+
+    for (const auto& func_kv : target_module->functions) {
+      const GlobalVar& var = func_kv.first;
+      const BaseFunc& func = func_kv.second;
+      // `func` is a const reference, so it is passed to Downcast as-is (there is nothing to
+      // move from).
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        tir::PrimFunc prim_func =
+            WithAttr(Downcast<tir::PrimFunc>(func), tvm::attr::kTarget, target);
+        unified_module->Add(var, prim_func);
+      } else if (func->IsInstance<relay::FunctionNode>()) {
+        relay::Function relay_func =
+            WithAttr(Downcast<relay::Function>(func), tvm::attr::kTarget, target);
+        unified_module->Add(var, relay_func);
+      } else {
+        LOG(FATAL)
+            << "We expected to only have PrimFuncs or RelayFuncs in the target modules, but found "
+            << func->GetTypeKey();
+      }
+    }
+  }
+
+  IRModule ret_mod = WithAttr(unified_module, "external_mods", mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+// Inverse of LoweredModuleToIRModule: split a unified IRModule back into a LoweredModule.
+// Relay functions and all type definitions go to the main module, PrimFuncs are grouped into
+// per-target modules keyed by their tvm::attr::kTarget attribute, and the "external_mods" /
+// "main_func_info" module attributes (when present) are copied into the LoweredModule.
+//
+// NOTE: every Relay function is placed in the main module, even one carrying a kTarget
+// attribute. LoweredModuleToIRModule can emit such functions from per-target modules, so the
+// round trip is not exact for per-target Relay functions.
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  IRModule main_mod;
+  // Copy just the TypeDefs from the IRModule to the LoweredModule's main module.
+  // This is the only time we need to do this since there are no TypeDefs in TIR.
+  for (const auto& kv : mod->type_definitions) {
+    main_mod->AddTypeDef(kv.first, kv.second);
+  }
+
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod->Add(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Extract target
+      Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
+      ICHECK(target) << "Target should be set at this point";
+      const String target_str = target.value();
+
+      // Create the IRModule for this target the first time the target is seen.
+      if (!per_target_modules.count(target_str)) {
+        per_target_modules.Set(target_str, IRModule());
+      }
+      // IRModule is a reference type, so adding through this handle updates the module that
+      // is stored in per_target_modules.
+      IRModule target_module = per_target_modules.at(target_str);
+      target_module->Add(var, func);
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+
+  // Put the LoweredModule together
+  LoweredModule lowered_module;
+  lowered_module.main_module = main_mod;
+  lowered_module.per_target_module = per_target_modules;
+
+  // Extract external modules and main func info, add to lowered module if they exist
+  auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+// Wrap LowerTE as a module-level Pass named "LowerTE" (opt level 1) whose output is the
+// unified IRModule produced by LoweredModuleToIRModule.
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+                 std::function<void(Function)> process_fn) {
+  // Every argument is captured by value because the returned Pass may run after this call
+  // has returned.
+  auto lower_and_unify = [targets, device_context_map, memory_plan, module_name, process_fn](
+                             IRModule module, PassContext /*pass_ctx*/) -> IRModule {
+    LoweredModule lowered =
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+    return LoweredModuleToIRModule(lowered);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = lower_and_unify;
+  constexpr int kOptLevel = 1;
+  return tvm::transform::CreateModulePass(pass_func, kOptLevel, "LowerTE", {});
+}
} // namespace tec
} // namespace relay
} // namespace tvm
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 8376b99..e9cfb0d 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -166,30 +166,73 @@ void UpdateFunctionMetadata(Function relay_func,
/*!
* \brief Obtain the Target from the device type.
* If homogeneous compilation, this will return the only target.
- * If heteregenous compilation, this will select associated using the targets_ Map.
+ * If heterogeneous compilation, this will select the associated target using the
+ * targets_ Map.
*
* \param dev_type The device type to look up a target for.
* \param targets The mapping for devices to targets.
* \return Target
*/
Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
+/*! \brief Utility to convert a LoweredModule to an IRModule.
+ *
+ * This function takes all the target-specific modules in the LoweredModule, annotates their
+ * functions with the corresponding target (the tvm::attr::kTarget attribute), and puts those
+ * functions, together with the main module's functions and type definitions, into one
+ * IRModule. The LoweredModule's external modules and main function info are attached as the
+ * "external_mods" and "main_func_info" module attributes.
+ * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase.
+ *
+ * \param mod The LoweredModule to convert.
+ * \return The IRModule form of the input LoweredModule.
+ */
+IRModule LoweredModuleToIRModule(LoweredModule mod);
+
+/*! \brief Utility to convert an IRModule to a LoweredModule.
+ *
+ * This function places Relay functions and all type definitions in the LoweredModule's main
+ * module, and groups PrimFuncs into target-specific IRModules keyed by their
+ * tvm::attr::kTarget attribute, which every PrimFunc must carry. Any other kind of function
+ * is a fatal error. The "external_mods" and "main_func_info" module attributes, if present,
+ * are copied into the LoweredModule.
+ * Note that all Relay functions land in the main module, so Relay functions that
+ * LoweredModuleToIRModule took from a per-target module are not returned to that module.
+ * The purpose of this utility is to allow us to slowly remove LoweredModule from the codebase.
+ *
+ * \param mod The IRModule to convert.
+ * \return The LoweredModule form of the input IRModule.
+ */
+LoweredModule IRModuleToLoweredModule(IRModule mod);
+
/*! \brief Lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
* to TE expressions, schedules them, and then to TIR.
*
- * \param compiler The TE-to-TIR compliler (which caches lowered functions)
* \param module The IRModule.
* \param targets The mapping for devices to targets.
* \param device_map An analysis result mapping each sub-expression to a device.
+ * \param memory_plan The memory plan used during lowering
+ * \param module_name The name of this module
+ * \param process_fn Callback that lets the calling (one-level-up) code generator process
+ * each function as it is lowered; defaults to a no-op.
* \return The lowered module, see above.
*/
-// TODO(@electriclilies): Not sure if this default initialization is correct...
LoweredModule LowerTE(
const IRModule& module, TargetMap targets, DeviceMap device_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
ProcessFn process_fn = [](Function f) {});
+/*! \brief Pass to lower an IRModule's primitive functions to TIR.
+ *
+ * This is the "back half" of the Relay compiler which lowers "primitive functions"
+ * to TE expressions, schedules them, and then to TIR. This Pass calls LowerTE, and
+ * uses the LoweredModuleToIRModule utility to convert LowerTE's output LoweredModule
+ * into an IRModule before returning it. The module pass is named "LowerTE" and is
+ * created with opt level 1.
+ *
+ * \param targets The mapping for devices to targets.
+ * \param device_context_map An analysis result mapping each sub-expression to a device.
+ * \param memory_plan The memory plan used during lowering
+ * \param module_name The name of this module
+ * \param process_fn Callback that lets the calling (one-level-up) code generator process
+ * each function as it is lowered. Unlike LowerTE, this parameter has no default.
+ * \return The pass which lowers primitive functions to TIR
+ */
+transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+ backend::StaticMemoryPlan memory_plan, const
String& module_name,
+ std::function<void(Function)> process_fn);
} // namespace tec
} // namespace relay
} // namespace tvm