[GitHub] [tvm] tqchen commented on a change in pull request #8802: Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass

2021-08-20 Thread GitBox


tqchen commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692934660



##
File path: src/relay/backend/te_compiler.cc
##
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+/*!
+ * \brief Flatten a LoweredModule into a single IRModule.
+ *
+ * Relay functions from mod.main_module are copied unchanged. Every PrimFunc from
+ * mod.per_target_module is copied with its target name attached under attr::kTarget
+ * (as a String). Type definitions of all modules are merged into one map.
+ * mod.external_mods and mod.main_func_info are carried over as the module attributes
+ * "external_mods" and "main_func_info".
+ */
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // Merge the type definitions of one module into unified_type_defs.
+  auto merge_type_defs = [&unified_type_defs](const IRModule& source) {
+    for (const auto& kv : source->type_definitions) {
+      unified_type_defs.Set(kv.first, kv.second);
+    }
+  };
+
+  // Copy the main module's functions unchanged.
+  // TODO: decide which target (if any) these functions should be annotated with.
+  for (const auto& kv : mod.main_module->functions) {
+    ICHECK(!kv.second->IsInstance<tir::PrimFuncNode>())
+        << "Expected no PrimFuncs in the main module, but found one bound to "
+        << kv.first->name_hint;
+    unified_funcs.Set(kv.first, kv.second);
+  }
+  merge_type_defs(mod.main_module);
+
+  // Copy the PrimFuncs of every per-target module, tagging each with its target.
+  for (const auto& target_kv : mod.per_target_module) {
+    const String& target = target_kv.first;
+    const IRModule& target_module = target_kv.second;
+    for (const auto& kv : target_module->functions) {
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this point, but got "
+          << func->GetTypeKey();
+      // func is a const reference, so std::move would only have copied it anyway.
+      tir::PrimFunc prim_func = WithAttr(Downcast<tir::PrimFunc>(func), attr::kTarget, target);
+      unified_funcs.Set(kv.first, prim_func);
+    }
+    merge_type_defs(target_module);
+  }
+
+  IRModule ret_mod =
+      WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+/*!
+ * \brief Split an IRModule produced by LoweredModuleToIRModule back into a LoweredModule.
+ *
+ * Relay functions go to the main module. Each PrimFunc goes to the per-target module
+ * named by its attr::kTarget attribute, which must be present. Every resulting module
+ * receives a copy of mod's type definitions. The "external_mods" and "main_func_info"
+ * module attributes are copied into the result when present.
+ */
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  Map<GlobalVar, BaseFunc> main_mod_funcs;
+  Map<String, Map<GlobalVar, BaseFunc>> target_funcs;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod_funcs.Set(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Every PrimFunc must carry the target it was lowered for.
+      Optional<String> target = func->GetAttr<String>(attr::kTarget);
+      ICHECK(target) << "Target should be set at this point";
+      String target_name = target.value();
+
+      // Map is copy-on-write: calling Set() on a copy obtained from at() does not update
+      // target_funcs, so the updated per-target map is always stored back explicitly.
+      Map<GlobalVar, BaseFunc> funcs;
+      if (target_funcs.count(target_name)) {
+        funcs = target_funcs.at(target_name);
+      }
+      funcs.Set(var, func);
+      target_funcs.Set(target_name, funcs);
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+
+  // Build one IRModule per target. TIR does not use the type definitions, so copying
+  // them into every per-target module is harmless.
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : target_funcs) {
+    per_target_modules.Set(kv.first, IRModule(kv.second, mod->type_definitions));
+  }
+
+  LoweredModule lowered_module;
+  lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions);
+  lowered_module.per_target_module = per_target_modules;
+
+  // Copy the external modules and main function info when the attributes exist.
+  auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+/*!
+ * \brief Wrap LowerTE as a module-level pass.
+ *
+ * The pass runs LowerTE on its input with the captured arguments, then flattens the
+ * resulting LoweredModule into a single IRModule with LoweredModuleToIRModule.
+ * It is created at opt level 1 under the name "LowerTE" and requires no other passes.
+ */
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+                 std::function<void(Function)> process_fn) {
+  auto lower_and_flatten = [=](IRModule module, PassContext ctx) -> IRModule {
+    LoweredModule lowered =
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+    return LoweredModuleToIRModule(lowered);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = lower_and_flatten;
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}

Review comment:
   For most of the passes that can be modularized, we encourage the Python … [comment truncated in this archive]

[GitHub] [tvm] tqchen commented on a change in pull request #8802: Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass

2021-08-19 Thread GitBox


tqchen commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692559893



##
File path: src/relay/backend/te_compiler.cc
##
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+/*!
+ * \brief Flatten a LoweredModule into a single IRModule.
+ *
+ * Relay functions from mod.main_module are copied unchanged. Every PrimFunc from
+ * mod.per_target_module is copied with its target name attached under attr::kTarget
+ * (as a String). Type definitions of all modules are merged into one map.
+ * mod.external_mods and mod.main_func_info are carried over as the module attributes
+ * "external_mods" and "main_func_info".
+ */
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // Merge the type definitions of one module into unified_type_defs.
+  auto merge_type_defs = [&unified_type_defs](const IRModule& source) {
+    for (const auto& kv : source->type_definitions) {
+      unified_type_defs.Set(kv.first, kv.second);
+    }
+  };
+
+  // Copy the main module's functions unchanged.
+  // TODO: decide which target (if any) these functions should be annotated with.
+  for (const auto& kv : mod.main_module->functions) {
+    ICHECK(!kv.second->IsInstance<tir::PrimFuncNode>())
+        << "Expected no PrimFuncs in the main module, but found one bound to "
+        << kv.first->name_hint;
+    unified_funcs.Set(kv.first, kv.second);
+  }
+  merge_type_defs(mod.main_module);
+
+  // Copy the PrimFuncs of every per-target module, tagging each with its target.
+  for (const auto& target_kv : mod.per_target_module) {
+    const String& target = target_kv.first;
+    const IRModule& target_module = target_kv.second;
+    for (const auto& kv : target_module->functions) {
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this point, but got "
+          << func->GetTypeKey();
+      // func is a const reference, so std::move would only have copied it anyway.
+      tir::PrimFunc prim_func = WithAttr(Downcast<tir::PrimFunc>(func), attr::kTarget, target);
+      unified_funcs.Set(kv.first, prim_func);
+    }
+    merge_type_defs(target_module);
+  }
+
+  IRModule ret_mod =
+      WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+/*!
+ * \brief Split an IRModule produced by LoweredModuleToIRModule back into a LoweredModule.
+ *
+ * Relay functions go to the main module. Each PrimFunc goes to the per-target module
+ * named by its attr::kTarget attribute, which must be present. Every resulting module
+ * receives a copy of mod's type definitions. The "external_mods" and "main_func_info"
+ * module attributes are copied into the result when present.
+ */
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  Map<GlobalVar, BaseFunc> main_mod_funcs;
+  Map<String, Map<GlobalVar, BaseFunc>> target_funcs;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod_funcs.Set(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Every PrimFunc must carry the target it was lowered for.
+      Optional<String> target = func->GetAttr<String>(attr::kTarget);
+      ICHECK(target) << "Target should be set at this point";
+      String target_name = target.value();
+
+      // Map is copy-on-write: calling Set() on a copy obtained from at() does not update
+      // target_funcs, so the updated per-target map is always stored back explicitly.
+      Map<GlobalVar, BaseFunc> funcs;
+      if (target_funcs.count(target_name)) {
+        funcs = target_funcs.at(target_name);
+      }
+      funcs.Set(var, func);
+      target_funcs.Set(target_name, funcs);
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+
+  // Build one IRModule per target. TIR does not use the type definitions, so copying
+  // them into every per-target module is harmless.
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : target_funcs) {
+    per_target_modules.Set(kv.first, IRModule(kv.second, mod->type_definitions));
+  }
+
+  LoweredModule lowered_module;
+  lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions);
+  lowered_module.per_target_module = per_target_modules;
+
+  // Copy the external modules and main function info when the attributes exist.
+  auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+/*!
+ * \brief Wrap LowerTE as a module-level pass.
+ *
+ * The pass runs LowerTE on its input with the captured arguments, then flattens the
+ * resulting LoweredModule into a single IRModule with LoweredModuleToIRModule.
+ * It is created at opt level 1 under the name "LowerTE" and requires no other passes.
+ */
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+                 std::function<void(Function)> process_fn) {
+  auto lower_and_flatten = [=](IRModule module, PassContext ctx) -> IRModule {
+    LoweredModule lowered =
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+    return LoweredModuleToIRModule(lowered);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = lower_and_flatten;
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}

Review comment:
   Currently, when we unit-test passes by exposing them via a Python API, … [comment truncated in this archive]

[GitHub] [tvm] tqchen commented on a change in pull request #8802: Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass

2021-08-19 Thread GitBox


tqchen commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692559893



##
File path: src/relay/backend/te_compiler.cc
##
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+/*!
+ * \brief Flatten a LoweredModule into a single IRModule.
+ *
+ * Relay functions from mod.main_module are copied unchanged. Every PrimFunc from
+ * mod.per_target_module is copied with its target name attached under attr::kTarget
+ * (as a String). Type definitions of all modules are merged into one map.
+ * mod.external_mods and mod.main_func_info are carried over as the module attributes
+ * "external_mods" and "main_func_info".
+ */
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // Merge the type definitions of one module into unified_type_defs.
+  auto merge_type_defs = [&unified_type_defs](const IRModule& source) {
+    for (const auto& kv : source->type_definitions) {
+      unified_type_defs.Set(kv.first, kv.second);
+    }
+  };
+
+  // Copy the main module's functions unchanged.
+  // TODO: decide which target (if any) these functions should be annotated with.
+  for (const auto& kv : mod.main_module->functions) {
+    ICHECK(!kv.second->IsInstance<tir::PrimFuncNode>())
+        << "Expected no PrimFuncs in the main module, but found one bound to "
+        << kv.first->name_hint;
+    unified_funcs.Set(kv.first, kv.second);
+  }
+  merge_type_defs(mod.main_module);
+
+  // Copy the PrimFuncs of every per-target module, tagging each with its target.
+  for (const auto& target_kv : mod.per_target_module) {
+    const String& target = target_kv.first;
+    const IRModule& target_module = target_kv.second;
+    for (const auto& kv : target_module->functions) {
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this point, but got "
+          << func->GetTypeKey();
+      // func is a const reference, so std::move would only have copied it anyway.
+      tir::PrimFunc prim_func = WithAttr(Downcast<tir::PrimFunc>(func), attr::kTarget, target);
+      unified_funcs.Set(kv.first, prim_func);
+    }
+    merge_type_defs(target_module);
+  }
+
+  IRModule ret_mod =
+      WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+/*!
+ * \brief Split an IRModule produced by LoweredModuleToIRModule back into a LoweredModule.
+ *
+ * Relay functions go to the main module. Each PrimFunc goes to the per-target module
+ * named by its attr::kTarget attribute, which must be present. Every resulting module
+ * receives a copy of mod's type definitions. The "external_mods" and "main_func_info"
+ * module attributes are copied into the result when present.
+ */
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  Map<GlobalVar, BaseFunc> main_mod_funcs;
+  Map<String, Map<GlobalVar, BaseFunc>> target_funcs;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod_funcs.Set(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Every PrimFunc must carry the target it was lowered for.
+      Optional<String> target = func->GetAttr<String>(attr::kTarget);
+      ICHECK(target) << "Target should be set at this point";
+      String target_name = target.value();
+
+      // Map is copy-on-write: calling Set() on a copy obtained from at() does not update
+      // target_funcs, so the updated per-target map is always stored back explicitly.
+      Map<GlobalVar, BaseFunc> funcs;
+      if (target_funcs.count(target_name)) {
+        funcs = target_funcs.at(target_name);
+      }
+      funcs.Set(var, func);
+      target_funcs.Set(target_name, funcs);
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+
+  // Build one IRModule per target. TIR does not use the type definitions, so copying
+  // them into every per-target module is harmless.
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : target_funcs) {
+    per_target_modules.Set(kv.first, IRModule(kv.second, mod->type_definitions));
+  }
+
+  LoweredModule lowered_module;
+  lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions);
+  lowered_module.per_target_module = per_target_modules;
+
+  // Copy the external modules and main function info when the attributes exist.
+  auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+/*!
+ * \brief Wrap LowerTE as a module-level pass.
+ *
+ * The pass runs LowerTE on its input with the captured arguments, then flattens the
+ * resulting LoweredModule into a single IRModule with LoweredModuleToIRModule.
+ * It is created at opt level 1 under the name "LowerTE" and requires no other passes.
+ */
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& module_name,
+                 std::function<void(Function)> process_fn) {
+  auto lower_and_flatten = [=](IRModule module, PassContext ctx) -> IRModule {
+    LoweredModule lowered =
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
+    return LoweredModuleToIRModule(lowered);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = lower_and_flatten;
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}

Review comment:
   Currently, when we unit-test passes by exposing them via a Python API, … [comment truncated in this archive]

[GitHub] [tvm] tqchen commented on a change in pull request #8802: Add LowerTEPass, and convert calls to LowerTE to application of LowerTEPass

2021-08-19 Thread GitBox


tqchen commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692557835



##
File path: include/tvm/relay/function.h
##
@@ -144,6 +144,8 @@ constexpr const char* kComposite = "Composite";
 constexpr const char* kInline = "Inline";
 /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
+/*!
+ * \brief Indicate the target that the function should be lowered to.
+ *
+ * NOTE(review): this key ("Target") is separate from the "target" attribute declared in
+ * include/tvm/ir/function.h, which is used for TIR functions; consider consolidating them.
+ */
+constexpr const char* kTarget = "Target";

Review comment:
   The other "target" attribute is used in
https://github.com/apache/tvm/blob/main/include/tvm/ir/function.h#L170 for TIR
functions; it would be great if we could consolidate the two.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org