mikepapadim commented on a change in pull request #9103:
URL: https://github.com/apache/tvm/pull/9103#discussion_r728774603
##########
File path: src/driver/driver_api.cc
##########
@@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
});
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const
Target& target_arg,
- const Target& target_host_arg,
- const transform::PassContext&
pass_ctx) {
+/**
+ * This function takes the input module that contains both the device and host opts.
+ * Then, it applies transformation on the original module before splitting into separate modules for
+ * device and host. Then it also applies transformations on the new split modules.
+ */
+std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const
Target& target_arg,
+ const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
- Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-
tir::transform::VerifyMemory()};
-
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
- if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier",
Bool(false)).value()) {
- mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
- }
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
- mixed_pass_list.push_back(tir::transform::InferFragment());
- mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ ICHECK(mod_mixed.defined()) << "This module must be defined";
- if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
- mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
- } else {
- mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
- }
+ mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed,
target));
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed,
target_host));
- auto opt_mixed = transform::Sequential(mixed_pass_list);
- mod_mixed = opt_mixed(std::move(mod_mixed));
-
- // We make an assumption here that the overriden host target
- // can be used alongside the default host codegen based on device type
- // this is so the correct code generator is used later instead of overriding
the target.
- // We need better support for inserting multiple kDLCPU targets as our
current options
- // are kDeviceKernelLaunch or not
- Target overriden_host_target = target_host;
- if (target->kind->device_type == target_host->kind->device_type) {
- overriden_host_target = target;
- }
- auto host_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) !=
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(overriden_host_target),
- tir::transform::LowerTVMBuiltin(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- tir::transform::CombineContextCall(),
- };
- auto opt_host = transform::Sequential(host_pass_list);
- ICHECK(mod_mixed.defined()) << "This module must be defined";
- auto mhost = opt_host(mod_mixed);
-
- // device pipeline
- auto device_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) ==
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target),
- tir::transform::LowerWarpMemory(),
- tir::transform::Simplify(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- };
- auto opt_device = transform::Sequential(device_pass_list);
- auto mdevice = opt_device(mod_mixed);
+ IRModule device_mod = ApplyPasses(mod_mixed,
DeviceModulePassManager(mod_mixed, target));
- // some final misc checks.
auto keys = target->GetKeys();
+
+ CheckAndUpdateHostConsistency(&target, &target_host);
+
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") !=
keys.end();
- if (target_is_gpu && mdevice->functions.size() == 0) {
- LOG(WARNING) << "Specified target " << target->str()
- << " but cannot find device code. Did you forget to bind?";
+ if (target_is_gpu && device_mod->functions.size() == 0) {
+ DLOG(WARNING) << "Specified target " << target->str()
+ << " but cannot find device code. Did you forget to bind?";
+ }
+
+ return {host_mod, device_mod};
+}
+
+runtime::Module FinalizeModule(const Map<Target, IRModule>& inputs_arg, const
Target& host_target) {
Review comment:
Sure, IMO there is some more cleanup to be done in `build_module`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]