tqchen commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692559893



##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+/*!
+ * \brief Flatten a LoweredModule into a single IRModule.
+ *
+ * Functions from the main module are copied unchanged (they must not be PrimFuncs). Every function
+ * of each per-target module must be a PrimFunc; it is copied with attr::kTarget set to the name of
+ * the per-target module it came from. Type definitions of all modules are merged, and
+ * mod.external_mods / mod.main_func_info are attached as the module attributes "external_mods" and
+ * "main_func_info".
+ *
+ * \param mod The lowered module to flatten.
+ * \return One IRModule holding all functions and type definitions of \p mod.
+ */
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // Copy main module funcs to unified funcs.
+  // TODO: decide which target (if any) these functions should be annotated with.
+  for (const auto& kv : mod.main_module->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    ICHECK(!func->IsInstance<tir::PrimFuncNode>())
+        << "We expect the main_module to contain no PrimFuncs at this point, but got one named "
+        << var->name_hint;
+    unified_funcs.Set(var, func);
+  }
+
+  // Copy the type definitions for the main module.
+  for (const auto& kv : mod.main_module->type_definitions) {
+    unified_type_defs.Set(kv.first, kv.second);
+  }
+
+  // Move functions in each per-target IRModule into the unified module, annotating each one with
+  // its target. Also move the type definitions.
+  for (const auto& target_kv : mod.per_target_module) {
+    const String target = target_kv.first;
+    const IRModule target_module = target_kv.second;
+    for (const auto& func_kv : target_module->functions) {
+      const GlobalVar& var = func_kv.first;
+      const BaseFunc& func = func_kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this point, but got "
+          << func->GetTypeKey();
+      // `func` is a const reference; std::move on it would still copy, so downcast directly.
+      tir::PrimFunc prim_func =
+          WithAttr(Downcast<tir::PrimFunc>(func), attr::kTarget, runtime::String(target));
+      unified_funcs.Set(var, prim_func);
+    }
+
+    // Move the type definitions for the per-target IRModule.
+    for (const auto& ty_kv : target_module->type_definitions) {
+      unified_type_defs.Set(ty_kv.first, ty_kv.second);
+    }
+  }
+
+  IRModule ret_mod =
+      WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+/*!
+ * \brief Split a unified IRModule back into a LoweredModule (inverse of LoweredModuleToIRModule).
+ *
+ * Relay functions go to the main module. Each PrimFunc is routed to the per-target module keyed by
+ * its attr::kTarget string, which must be set. Any other function kind is a fatal error. The
+ * module attributes "external_mods" and "main_func_info" are copied into the LoweredModule when
+ * present.
+ *
+ * \param mod A module containing only relay::Function and target-annotated tir::PrimFunc entries.
+ * \return The equivalent LoweredModule.
+ */
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  Map<GlobalVar, BaseFunc> main_mod_funcs;
+  Map<String, Map<GlobalVar, BaseFunc>> target_funcs;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod_funcs.Set(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Extract target. The check must fail when the attribute is *missing*; the previous
+      // `ICHECK(!target)` was inverted and aborted whenever a target was present.
+      auto target = func->GetAttr<String>(attr::kTarget);
+      ICHECK(target.defined()) << "Target should be set at this point";
+      const String target_str = target.value();
+
+      // tvm Map is copy-on-write: a Map read out of target_funcs shares its node with the stored
+      // entry, so Set() on it updates a private copy. Always write the updated map back, otherwise
+      // every function after the first one for a target is silently dropped.
+      Map<GlobalVar, BaseFunc> funcs;
+      if (target_funcs.count(target_str)) {
+        funcs = target_funcs.at(target_str);
+      }
+      funcs.Set(var, func);
+      target_funcs.Set(target_str, funcs);
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+
+  // Create the per_target_module map.
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : target_funcs) {
+    // Here, we just copy the type defs to every module. Since TIR doesn't use the type defs,
+    // this duplication should be OK.
+    per_target_modules.Set(kv.first, IRModule(kv.second, mod->type_definitions));
+  }
+  LoweredModule lowered_module;
+  lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions);
+  lowered_module.per_target_module = per_target_modules;
+
+  // Extract external modules and main func info; add them to the lowered module if they exist.
+  auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+/*!
+ * \brief Wrap LowerTE as a module-level pass.
+ *
+ * The pass runs LowerTE on its input with the captured arguments and re-packs the resulting
+ * LoweredModule into a single IRModule via LoweredModuleToIRModule. It is registered under the
+ * name "LowerTE" with opt level 1 and no required passes.
+ */
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+                 std::function<void(Function)> process_fn) {
+  // Everything is captured by value so the pass can outlive this call.
+  auto lower_and_unify = [=](IRModule module, PassContext ctx) -> IRModule {
+    LoweredModule lowered =
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+    return LoweredModuleToIRModule(lowered);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = lower_and_unify;
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}

Review comment:
   Currently, we unit-test passes by exposing them via a Python API, constructing the expected
   input and output, and running the tests in Python:
   
   
https://github.com/apache/tvm/blob/main/tests/python/unittest/test_tir_transform_loop_partition.py#L30
   
   There are certainly pros and cons to doing so. The original rationale is that we require most
   of the compiler passes to be accessible from Python, and it is relatively easy to construct and
   extend test cases there.
   
   We could revisit this point when discussing the need for the related test cases.
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to