electriclilies commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692553994



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to annotate with here?)

Review comment:
       I resolved this question: the main module functions are all Relay 
functions, so we don't need to annotate them with the target. I will add a 
comment about this.
   

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to annotate with here?)

Review comment:
       I'll change this

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -593,6 +593,14 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
   return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
 }
 
+/*!
+ * \brief Obtain the Target from the device type.
+ * If homogeneous compilation, this will return the only target.
+ * If heterogeneous compilation, this will select associated using the targets_ Map.

Review comment:
       I am making this change in the header file

##########
File path: python/tvm/relay/backend/graph_executor_codegen.py
##########
@@ -53,7 +53,7 @@ def __init__(self, mod, target):
         self._get_irmodule = self._mod["get_irmodule"]
         self._setup(mod, target)
 
-    def _setup(self, mod, target):
+    def _setup(self, mod, target: Dict[int, Target]):

Review comment:
       Yeah, I think this slipped through. I think this was maybe the ideal type.

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -48,6 +48,7 @@ namespace backend {
 using IntegerArray = Array<Integer>;
 using StorageMap =
     std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+using namespace tec;

Review comment:
       Sure, I can use tec:: instead
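
       As a rough illustration of the difference (a standalone sketch: the 
`demo` namespace, the nested `tec` stand-in, and both function names are 
hypothetical and not the actual TVM code), qualifying the call keeps its origin 
visible at the call site:

```cpp
#include <cassert>
#include <string>

namespace demo {
namespace tec {  // hypothetical stand-in, not the real TVM namespace
inline std::string Name() { return "tec::Name"; }
}  // namespace tec

// With `using namespace tec;` the call below could be written as Name(),
// which hides where the function comes from. Qualifying it with tec::
// keeps the origin visible and avoids pulling every tec name into scope.
inline std::string Describe() { return tec::Name(); }
}  // namespace demo
```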

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -593,6 +593,14 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
   return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
 }
 
+/*!
+ * \brief Obtain the Target from the device type.
+ * If homogeneous compilation, this will return the only target.
+ * If heterogeneous compilation, this will select associated using the targets_ Map.
+ *
+ * \param dev_type
+ * \return Target
+ */

Review comment:
       Huh, I'm honestly not sure how this got in here. Maybe it happened when I 
was moving stuff around. I will remove it and add @tkonolige's suggestion to the 
description in the header file.

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to annotate with here?)
+  for (const auto& kv : mod.main_module->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    ICHECK(!func->IsInstance<tir::PrimFuncNode>());
+    unified_funcs.Set(var, func);
+  }
+
+  // copy the type definitions for the main module
+  for (const auto& kv : mod.main_module->type_definitions) {
+    const GlobalTypeVar& ty_var = kv.first;
+    const TypeData& ty_data = kv.second;
+    unified_type_defs.Set(ty_var, ty_data);
+  }
+  // Move functions in per target IRModule into unified module
+  // Also move the type definitions
+  for (const auto& kv : mod.per_target_module) {
+    const String target = kv.first;
+    const IRModule target_module = kv.second;
+    // Move the per module functions, and annotate the funcs with their target
+    for (const auto& kv : target_module->functions) {
+      const GlobalVar& var = kv.first;
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this point, but got "
+          << func->GetTypeKey();
+      tir::PrimFunc primFunc = WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget,
+                                        runtime::String(target));
+      unified_funcs.Set(var, primFunc);
+    }
+
+    // Move the type definitions for the per target IRModule

Review comment:
       Sure, that seems like a better approach

##########
File path: src/relay/backend/te_compiler.h
##########
@@ -184,12 +206,15 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
  * \param device_map An analysis result mapping each sub-expression to a device.
  * \return The lowered module, see above.
  */
-// TODO(@electriclilies): Not sure if this default initialization is correct...
 LoweredModule LowerTE(
     const IRModule& module, TargetMap targets, DeviceMap device_map,
     backend::StaticMemoryPlan memory_plan, const String& module_name,
     ProcessFn process_fn = [](Function f) {});
 
+using namespace transform;
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,

Review comment:
       Will do

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to annotate with here?)
+  for (const auto& kv : mod.main_module->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    ICHECK(!func->IsInstance<tir::PrimFuncNode>());
+    unified_funcs.Set(var, func);
+  }
+
+  // copy the type definitions for the main module
+  for (const auto& kv : mod.main_module->type_definitions) {
+    const GlobalTypeVar& ty_var = kv.first;
+    const TypeData& ty_data = kv.second;
+    unified_type_defs.Set(ty_var, ty_data);
+  }
+  // Move functions in per target IRModule into unified module
+  // Also move the type definitions
+  for (const auto& kv : mod.per_target_module) {
+    const String target = kv.first;
+    const IRModule target_module = kv.second;
+    // Move the per module functions, and annotate the funcs with their target
+    for (const auto& kv : target_module->functions) {
+      const GlobalVar& var = kv.first;
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this point, but got "
+          << func->GetTypeKey();
+      tir::PrimFunc primFunc = WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget,
+                                        runtime::String(target));
+      unified_funcs.Set(var, primFunc);
+    }
+
+    // Move the type definitions for the per target IRModule
+    for (const auto& kv : target_module->type_definitions) {
+      const GlobalTypeVar& ty_var = kv.first;
+      const TypeData& ty_data = kv.second;
+      unified_type_defs.Set(ty_var, ty_data);
+    }
+  }
+
+  IRModule ret_mod =
+      WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  Map<GlobalVar, BaseFunc> main_mod_funcs;
+  Map<String, Map<GlobalVar, BaseFunc>> target_funcs;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod_funcs.Set(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Extract target
+      auto target = func->GetAttr<String>(attr::kTarget);
+      ICHECK(target) << "Target should be set at this point";
+
+      // Put the function in target_funcs
+      if (!target_funcs.count(target.value())) {
+        // Initialize the map and put it in target_funcs
+        Map<GlobalVar, BaseFunc> funcs;
+        funcs.Set(var, func);
+        target_funcs.Set(target.value(), funcs);
+
+      } else {
+        // The map is initialized, so just add the function.
+        Map<GlobalVar, BaseFunc> funcs = target_funcs.at(target.value());
+        funcs.Set(var, func);
+      }
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+  // Create the per_target_module map
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : target_funcs) {
+    String target = kv.first;
+    Map<GlobalVar, BaseFunc> funcs = kv.second;
+    // Here, we just copy the type defs to every module. Since TIR doesn't use the type defs,
+    // this duplication should be OK.
+    per_target_modules.Set(target, IRModule(funcs, mod->type_definitions));
+  }
+  LoweredModule lowered_module;
+  lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions);
+  lowered_module.per_target_module = per_target_modules;
+
+  // Extract external modules and main func info, add to lowered module if they exist
+  auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+                 std::function<void(Function)> process_fn) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
+                                                                            PassContext ctx) {
+    return LoweredModuleToIRModule(
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn));
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}

Review comment:
       This is a good point. Since these two functions are supposed to be 
inverses of each other, it would in theory be pretty easy to write a unit test 
for them. When I was developing, I actually inserted the conversions in some 
places and ran existing unit tests to make sure that the functions worked, but 
it would be great to have a way to directly write unit tests in C++. That way I 
wouldn't have to remove my tests before merging!
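
       To make the round-trip idea concrete, here is a minimal standalone 
sketch of what such a test could check. It does not use TVM at all: `Func`, 
`Lowered`, `ToUnified`, `FromUnified`, and `RoundTripHolds` are hypothetical 
stand-ins that only mirror the shape of the two conversions (merge per-target 
functions into one map with a target tag, then split them back out by that tag).

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-ins for illustration only; these are not TVM types.
struct Func {
  std::string body;
  std::string target;  // empty for Relay-like functions, set for PrimFunc-like ones
  bool operator==(const Func& other) const {
    return body == other.body && target == other.target;
  }
};

using FuncMap = std::map<std::string, Func>;

struct Lowered {
  FuncMap main_module;                        // functions with no target
  std::map<std::string, FuncMap> per_target;  // target name -> functions
  bool operator==(const Lowered& other) const {
    return main_module == other.main_module && per_target == other.per_target;
  }
};

// Merge all functions into one map, tagging per-target functions with their target.
FuncMap ToUnified(const Lowered& mod) {
  FuncMap unified = mod.main_module;
  for (const auto& kv : mod.per_target) {
    for (const auto& fv : kv.second) {
      unified[fv.first] = Func{fv.second.body, kv.first};
    }
  }
  return unified;
}

// Split a unified map back into main and per-target parts using the tag.
Lowered FromUnified(const FuncMap& unified) {
  Lowered mod;
  for (const auto& kv : unified) {
    if (kv.second.target.empty()) {
      mod.main_module[kv.first] = kv.second;
    } else {
      mod.per_target[kv.second.target][kv.first] = kv.second;
    }
  }
  return mod;
}

// Build a small example and check that splitting the merged map restores it.
bool RoundTripHolds() {
  Lowered original;
  original.main_module["main"] = Func{"relay body", ""};
  original.per_target["llvm"]["fused_add"] = Func{"tir add", "llvm"};
  original.per_target["cuda"]["fused_mul"] = Func{"tir mul", "cuda"};
  return FromUnified(ToUnified(original)) == original;
}
```

A real test would compare the actual module types with whatever structural 
equality is appropriate there, rather than the `operator==` used in this sketch.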



