electriclilies commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692615158



##########
File path: include/tvm/relay/function.h
##########
@@ -144,6 +144,8 @@ constexpr const char* kComposite = "Composite";
 constexpr const char* kInline = "Inline";
 /*! \brief Indicate the function was created by the Pattern Partitioning Pass. 
*/
 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
+/*! \brief Indicate the target that the function should be lowered to. */
+constexpr const char* kTarget = "Target";

Review comment:
       Yup, will do

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to 
annotate with here?)
+  for (const auto& kv : mod.main_module->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    ICHECK(!func->IsInstance<tir::PrimFuncNode>());
+    unified_funcs.Set(var, func);
+  }
+
+  // copy the type definitions for the main module
+  for (const auto& kv : mod.main_module->type_definitions) {
+    const GlobalTypeVar& ty_var = kv.first;
+    const TypeData& ty_data = kv.second;
+    unified_type_defs.Set(ty_var, ty_data);
+  }
+  // Move functions in per target IRModule into unified module
+  // Also move the type definitions
+  for (const auto& kv : mod.per_target_module) {
+    const String target = kv.first;
+    const IRModule target_module = kv.second;
+    // Move the per module functions, and annotate the funcs with their target
+    for (const auto& kv : target_module->functions) {
+      const GlobalVar& var = kv.first;
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this 
point, but got "
+          << func->GetTypeKey();
+      tir::PrimFunc primFunc = 
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget,
+                                        runtime::String(target));

Review comment:
       Yup, I can also make this change
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to