jroesch commented on a change in pull request #9103:
URL: https://github.com/apache/tvm/pull/9103#discussion_r728738291
##########
File path: src/driver/driver_api.cc
##########
@@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
});
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const
Target& target_arg,
- const Target& target_host_arg,
- const transform::PassContext&
pass_ctx) {
+/**
+ * This function takes the input module that contains both the device and host
opts.
+ * Then, it applies transformation on the original module before splitting
into separate modules for
+ * device and host. Then it also applies transformations on the new splitted
modules.
+ */
+std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const
Target& target_arg,
+ const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
- Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-
tir::transform::VerifyMemory()};
-
mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
- if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier",
Bool(false)).value()) {
- mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
- }
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
- mixed_pass_list.push_back(tir::transform::InferFragment());
- mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ ICHECK(mod_mixed.defined()) << "This module must be defined";
- if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
- mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
- } else {
- mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
- }
+ mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed,
target));
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed,
target_host));
- auto opt_mixed = transform::Sequential(mixed_pass_list);
- mod_mixed = opt_mixed(std::move(mod_mixed));
-
- // We make an assumption here that the overriden host target
- // can be used alongside the default host codegen based on device type
- // this is so the correct code generator is used later instead of overriding
the target.
- // We need better support for inserting multiple kDLCPU targets as our
current options
- // are kDeviceKernelLaunch or not
- Target overriden_host_target = target_host;
- if (target->kind->device_type == target_host->kind->device_type) {
- overriden_host_target = target;
- }
- auto host_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) !=
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(overriden_host_target),
- tir::transform::LowerTVMBuiltin(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- tir::transform::CombineContextCall(),
- };
- auto opt_host = transform::Sequential(host_pass_list);
- ICHECK(mod_mixed.defined()) << "This module must be defined";
- auto mhost = opt_host(mod_mixed);
-
- // device pipeline
- auto device_pass_list = {
- Filter([](const tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) ==
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target),
- tir::transform::LowerWarpMemory(),
- tir::transform::Simplify(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- };
- auto opt_device = transform::Sequential(device_pass_list);
- auto mdevice = opt_device(mod_mixed);
+ IRModule device_mod = ApplyPasses(mod_mixed,
DeviceModulePassManager(mod_mixed, target));
- // some final misc checks.
auto keys = target->GetKeys();
+
+ CheckAndUpdateHostConsistency(&target, &target_host);
+
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") !=
keys.end();
- if (target_is_gpu && mdevice->functions.size() == 0) {
- LOG(WARNING) << "Specified target " << target->str()
- << " but cannot find device code. Did you forget to bind?";
+ if (target_is_gpu && device_mod->functions.size() == 0) {
+ DLOG(WARNING) << "Specified target " << target->str()
+ << " but cannot find device code. Did you forget to bind?";
+ }
+
+ return {host_mod, device_mod};
+}
+
+runtime::Module FinalizeModule(const Map<Target, IRModule>& inputs_arg, const
Target& host_target) {
Review comment:
I think we should probably come up with a better name for this function;
"finalize" has too many technical meanings, so it is probably better to call it
something which implies what is going on. cc @mbs-octoml: going to land this as
is; maybe we can make a note to follow up on this.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]