This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d1f19c4  Add LowerTEPass, and convert calls to LowerTE to application 
of LowerTEPass (#8802)
d1f19c4 is described below

commit d1f19c470c16a1ca87c67fd93f30dd59e16bbec1
Author: Lily Orth-Smith <lilyorthsm...@gmail.com>
AuthorDate: Mon Aug 23 20:12:07 2021 -0700

    Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass 
(#8802)
    
    * Initial commit
    
    Initial stab at IRModule -> LoweredModule conversion func, notes
    
    Add external_mods and main_func_info to conversion funcs
    
    Test lowered module to ir module
    
    fix problem with conversion funcs + print stmts
    
    Add LowerTE pass
    
    Add LowerTEPass
    
    Add LowerTEPass to graph_executor_codegen.cc
    
    Use LowerTEPass instead of LowerTe in graph_executor_codegen.cc
    
    Code cleanup
    
    Add docs, more cleanup
    
    Formatting
    
    * Fix bad rebase
    
    * Address 1st round of comments
    
    * Use tir kTarget instead of relay one
    
    * Change target string to Target obj
    
    * removing target string causing issues
    
    * Fix typos
    
    * Revert target str -> target obj changes
    
    * Don't use Update : IRModule because it is broken
    
    * Fix check
    
    * flaky test?
    
    * lint
---
 include/tvm/relay/function.h                |   1 -
 src/relay/backend/aot_executor_codegen.cc   |   8 +-
 src/relay/backend/graph_executor_codegen.cc |   7 +-
 src/relay/backend/te_compiler.cc            | 117 ++++++++++++++++++++++++++--
 src/relay/backend/te_compiler.h             |  49 +++++++++++-
 5 files changed, 167 insertions(+), 15 deletions(-)

diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index fccd1f9..9170bc5 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -144,7 +144,6 @@ constexpr const char* kComposite = "Composite";
 constexpr const char* kInline = "Inline";
 /*! \brief Indicate the function was created by the Pattern Partitioning Pass. 
*/
 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
-
 /*! \brief Mark the function as only composed of reshape operations. */
 constexpr const char* kReshapeOnly = "relay.reshape_only";
 }  // namespace attr
diff --git a/src/relay/backend/aot_executor_codegen.cc 
b/src/relay/backend/aot_executor_codegen.cc
index 54a10ad..942bc0d 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -586,8 +586,9 @@ class AOTExecutorCodegen : public ExprVisitor {
     // to instead explicitly lowering the incoming IRModule, and then
     // performing the preexisting AOT executor code generation phase.
     IRModule mod = IRModule::FromExpr(func);
-    auto lowered_module = tec::LowerTE(
-        mod, targets_, device_context_map, memory_plan, mod_name, 
[this](Function func) {
+
+    IRModule new_mod =
+        LowerTEPass(targets_, device_context_map, memory_plan, mod_name, 
[this](Function func) {
           // We need to maintain the constant map for external
           // functions so we pass this processing function which
           // allows us to process each function as we lower it.
@@ -599,8 +600,9 @@ class AOTExecutorCodegen : public ExprVisitor {
           // execute as a further pass, instead writing data to the
           // lowering process directly.
           tec::UpdateFunctionMetadata(func, this->function_metadata_);
-        });
+        })(mod);
 
+    tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
     function_metadata_.Set(runtime::symbol::tvm_module_main, 
lowered_module.main_func_info);
     auto lowered_main = lowered_module.main_module->Lookup("main");
     auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
diff --git a/src/relay/backend/graph_executor_codegen.cc 
b/src/relay/backend/graph_executor_codegen.cc
index cc54a52..486a6dc 100644
--- a/src/relay/backend/graph_executor_codegen.cc
+++ b/src/relay/backend/graph_executor_codegen.cc
@@ -221,8 +221,8 @@ class GraphExecutorCodegen : public 
backend::MemoizedExprTranslator<std::vector<
       device_context_map.insert({expr, dev});
     }
 
-    auto lowered_module = tec::LowerTE(
-        mod, targets_, device_context_map, memory_plan_, mod_name_, 
[this](Function func) {
+    IRModule new_mod =
+        LowerTEPass(targets_, device_context_map, memory_plan_, mod_name_, 
[this](Function func) {
           // We need to maintain the constant map for external
           // functions so we pass this processing function which
           // allows us to process each function as we lower it.
@@ -234,8 +234,9 @@ class GraphExecutorCodegen : public 
backend::MemoizedExprTranslator<std::vector<
           // execute as a further pass, instead writing data to the
           // lowering process directly.
           tec::UpdateFunctionMetadata(func, this->function_metadata_);
-        });
+        })(mod);
 
+    tec::LoweredModule lowered_module = tec::IRModuleToLoweredModule(new_mod);
     function_metadata_.Set(runtime::symbol::tvm_module_main, 
lowered_module.main_func_info);
     auto main_module = lowered_module.main_module;
     main_module = relay::transform::InferType()(main_module);
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index 93fcf73..71ac752 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -20,6 +20,7 @@
 #include "te_compiler.h"
 
 #include <tvm/driver/driver_api.h>
+#include <tvm/ir/attrs.h>
 #include <tvm/ir/function.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
@@ -749,8 +750,6 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const 
IRModule& mod, TargetMap tar
                                relay_primfuncs);
 }
 
-// TODO(@electriclilies): Is the function passed in here relay_func??
-// Also should this be inlined?
 /*!
  * \brief A function to create the function metadata for an input function (ie 
calculate buffer
  * input/output sizes)
@@ -830,9 +829,6 @@ void UpdateFunctionMetadata(Function relay_func,
   function_metadata.Set(prim_fn_var.value()->name_hint, fi);
 }
 
-// TODO(mbs): Make this an IRModule->IRModule pass by folding LoweredModule 
back into IRModule.
-// Currently we rely on accumulating bindings inside the local TECompiler 
which we then
-// host into the LoweredModule result.
 LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap 
device_context_map,
                       backend::StaticMemoryPlan memory_plan, const String& 
module_name,
                       std::function<void(Function)> process_fn) {
@@ -875,6 +871,117 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  IRModule unified_module;
+
+  // Copy the main module and its typedefs
+  for (const auto& kv : mod.main_module->functions) {
+    unified_module->Add(kv.first, kv.second);
+  }
+  for (const auto& kv : mod.main_module->type_definitions) {
+    unified_module->AddTypeDef(kv.first, kv.second);
+  }
+
+  // Annotate the per-target functions with their target and add them to the 
unified module
+  for (const auto& kv : mod.per_target_module) {
+    const String target = kv.first;
+    const IRModule target_module = kv.second;
+
+    // Right now, per-target functions are TIR functions, which don't have 
type definitions, so
+    // there should be no type defs in the per_target_modules
+    size_t ty_def_size = target_module->type_definitions.size();
+    ICHECK(ty_def_size == 0)
+        << "Expected there to be no type definitions in the 
per_target_modules, but found "
+        << ty_def_size;
+
+    for (const auto& kv : target_module->functions) {
+      const GlobalVar& var = kv.first;
+      const BaseFunc& func = kv.second;
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        tir::PrimFunc primFunc =
+            WithAttr(Downcast<tir::PrimFunc>(std::move(func)), 
tvm::attr::kTarget, target);
+        unified_module->Add(var, primFunc);
+      } else if (func->IsInstance<relay::FunctionNode>()) {
+        relay::Function relayFunc =
+            WithAttr(Downcast<relay::Function>(std::move(func)), 
tvm::attr::kTarget, target);
+        unified_module->Add(var, relayFunc);
+      } else {
+        LOG(FATAL)
+            << "We expected to only have PrimFuncs or RelayFuncs in the target 
modules, but found "
+            << func->GetTypeKey();
+      }
+    }
+  }
+
+  IRModule ret_mod = WithAttr(unified_module, "external_mods", 
mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  IRModule main_mod;
+  // Copy just the TypeDefs from the IRModule to the LoweredModule's main 
module
+  // This is the only time we need to do this since there are no TypeDefs in 
TIR
+  for (const auto& kv : mod->type_definitions) {
+    main_mod->AddTypeDef(kv.first, kv.second);
+  }
+
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod->Add(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Extract target
+      Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
+      ICHECK(target) << "Target should be set at this point";
+
+      // Put the function in per_target_modules
+      if (!per_target_modules.count(target.value())) {
+        // Initialize the IRModule for this target and add the function
+        IRModule target_module;
+        target_module->Add(var, func);
+        per_target_modules.Set(target.value(), target_module);
+      } else {
+        // The IRModule for this target is initialized, so just add the 
function.
+        IRModule target_module = per_target_modules.at(target.value());
+        target_module->Add(var, func);
+      }
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or 
PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+
+  // Put the LoweredModule together
+  LoweredModule lowered_module;
+  lowered_module.main_module = main_mod;
+  lowered_module.per_target_module = per_target_modules;
+
+  // Extract external modules and main func info, add to lowered module if 
they exist
+  auto external_mods = 
mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& 
module_name,
+                 std::function<void(Function)> process_fn) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule module,
+                                                                            
PassContext ctx) {
+    return LoweredModuleToIRModule(
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, 
process_fn));
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}
 }  // namespace tec
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h
index 8376b99..e9cfb0d 100644
--- a/src/relay/backend/te_compiler.h
+++ b/src/relay/backend/te_compiler.h
@@ -166,30 +166,73 @@ void UpdateFunctionMetadata(Function relay_func,
 /*!
  * \brief Obtain the Target from the device type.
  * If homogenous compilation, this will return the only target.
- * If heteregenous compilation, this will select associated using the targets_ 
Map.
+ * If heterogeneous compilation, this will select the associated target using 
the
+ * targets_ Map.
  *
  * \param dev_type
  * \return Target
  */
 Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
 
+/*! \brief Utility to convert a LoweredModule to an IRModule.
+ *
+ * This function takes all the target specific modules in LoweredModule and
+ * annotates their functions with the correct target, and puts all those 
functions
+ * in one IRModule.
+ * The purpose of this utility is to allow us to slowly remove LoweredModule 
from the codebase.
+ *
+ * \param mod The LoweredModule to convert.
+ * \return The IRModule form of the input LoweredModule.
+ */
+IRModule LoweredModuleToIRModule(LoweredModule mod);
+
+/*! \brief Utility to convert an IRModule to a LoweredModule.
+ *
+ * This function takes all the functions in the IRModule and moves them into 
target-specific
+ * IRModules stored inside a LoweredModule.
+ * The purpose of this utility is to allow us to slowly remove LoweredModule 
from the codebase.
+ * \param mod The IRModule to convert.
+ * \return The LoweredModule form of the input IRModule.
+ */
+LoweredModule IRModuleToLoweredModule(IRModule mod);
+
 /*! \brief Lower an IRModule's primitive functions to TIR.
  *
  * This is the "back half" of the Relay compiler which lowers "primitive 
functions"
  * to TE expressions, schedules them, and then to TIR.
  *
- * \param compiler The TE-to-TIR compliler (which caches lowered functions)
  * \param module The IRModule.
  * \param targets The mapping for devices to targets.
  * \param device_map An analysis result mapping each sub-expression to a 
device.
+ * \param memory_plan The memory plan used during lowering
+ * \param module_name The name of this module
+ * \param process_fn Callback allowing one-level up code generators to process
+ * each function that we lower
  * \return The lowered module, see above.
  */
-// TODO(@electriclilies): Not sure if this default initialization is correct...
 LoweredModule LowerTE(
     const IRModule& module, TargetMap targets, DeviceMap device_map,
     backend::StaticMemoryPlan memory_plan, const String& module_name,
     ProcessFn process_fn = [](Function f) {});
 
+/*! \brief Pass to lower an IRModule's primitive functions to TIR.
+ *
+ * This is the "back half" of the Relay compiler which lowers "primitive 
functions"
+ * to TE expressions, schedules them, and then to TIR. This Pass calls 
LowerTE, and
+ * uses the LoweredModuleToIRModule utility to convert LowerTE's output
+ * LoweredModule into an IRModule before returning it.
+ *
+ * \param targets The mapping for devices to targets.
+ * \param device_context_map An analysis result mapping each sub-expression to 
a device.
+ * \param memory_plan The memory plan used during lowering
+ * \param module_name The name of this module
+ * \param process_fn Callback allowing one-level up code generators to process
+ * each function that we lower
+ * \returns The pass which lowers primitive functions to TIR
+ */
+transform::Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                            backend::StaticMemoryPlan memory_plan, const 
String& module_name,
+                            std::function<void(Function)> process_fn);
 }  // namespace tec
 }  // namespace relay
 }  // namespace tvm

Reply via email to