jroesch commented on a change in pull request #9103:
URL: https://github.com/apache/tvm/pull/9103#discussion_r728738291



##########
File path: src/driver/driver_api.cc
##########
@@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
       return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
     });
 
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg,
-                                                const Target& target_host_arg,
-                                                const transform::PassContext& pass_ctx) {
+/**
+ * This function takes the input module that contains both the device and host opts.
+ * Then, it applies transformation on the original module before splitting into separate modules for
+ * device and host. Then it also applies transformations on the new splitted modules.
+ */
+std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
+                                               const Target& target_host_arg) {
   Target target = target_arg, target_host = target_host_arg;
   CheckAndUpdateHostConsistency(&target, &target_host);
-  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-                                                 tir::transform::VerifyMemory()};
 
-  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
-  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
-    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
-  }
-  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
-  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
-  mixed_pass_list.push_back(tir::transform::InferFragment());
-  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+  ICHECK(mod_mixed.defined()) << "This module must be defined";
 
-  if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
-    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
-  } else {
-    mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
-  }
+  mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
 
-  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+  IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
 
-  auto opt_mixed = transform::Sequential(mixed_pass_list);
-  mod_mixed = opt_mixed(std::move(mod_mixed));
-
-  // We make an assumption here that the overriden host target
-  // can be used alongside the default host codegen based on device type
-  // this is so the correct code generator is used later instead of overriding the target.
-  // We need better support for inserting multiple kDLCPU targets as our current options
-  // are kDeviceKernelLaunch or not
-  Target overriden_host_target = target_host;
-  if (target->kind->device_type == target_host->kind->device_type) {
-    overriden_host_target = target;
-  }
-  auto host_pass_list = {
-      Filter([](const tir::PrimFunc& f) {
-        return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
-               CallingConv::kDeviceKernelLaunch;
-      }),
-      BindTarget(overriden_host_target),
-      tir::transform::LowerTVMBuiltin(),
-      tir::transform::LowerCustomDatatypes(),
-      tir::transform::LowerIntrin(),
-      tir::transform::LowerDeviceStorageAccessInfo(),
-      tir::transform::CombineContextCall(),
-  };
-  auto opt_host = transform::Sequential(host_pass_list);
-  ICHECK(mod_mixed.defined()) << "This module must be defined";
-  auto mhost = opt_host(mod_mixed);
-
-  // device pipeline
-  auto device_pass_list = {
-      Filter([](const tir::PrimFunc& f) {
-        return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
-               CallingConv::kDeviceKernelLaunch;
-      }),
-      BindTarget(target),
-      tir::transform::LowerWarpMemory(),
-      tir::transform::Simplify(),
-      tir::transform::LowerCustomDatatypes(),
-      tir::transform::LowerIntrin(),
-      tir::transform::LowerDeviceStorageAccessInfo(),
-  };
-  auto opt_device = transform::Sequential(device_pass_list);
-  auto mdevice = opt_device(mod_mixed);
+  IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
 
-  // some final misc checks.
   auto keys = target->GetKeys();
+
+  CheckAndUpdateHostConsistency(&target, &target_host);
+
   bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
-  if (target_is_gpu && mdevice->functions.size() == 0) {
-    LOG(WARNING) << "Specified target " << target->str()
-                 << " but cannot find device code. Did you forget to bind?";
+  if (target_is_gpu && device_mod->functions.size() == 0) {
+    DLOG(WARNING) << "Specified target " << target->str()
+                  << " but cannot find device code. Did you forget to bind?";
+  }
+
+  return {host_mod, device_mod};
+}
+
+runtime::Module FinalizeModule(const Map<Target, IRModule>& inputs_arg, const Target& host_target) {

Review comment:
       I think we should come up with a better name for this function: "finalize" has too many
technical meanings, so it would be better to call it something that describes what it actually
does. cc @mbs-octoml. I'm going to land this as is; maybe we can make a note to follow up on
this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to