mbs-octoml commented on a change in pull request #9103:
URL: https://github.com/apache/tvm/pull/9103#discussion_r715708496
##########
File path: include/tvm/driver/driver_api.h
##########
@@ -44,6 +44,31 @@
namespace tvm {
+/*!
+ * \brief Returns the optimized IRModule for original fused module (pre split)
+ * that contains device and host code.
+ * \param mixed_mod The original mixed module.
+ * \param target The device Target.
+ * \return The result optimized mixed module.
+ */
+IRModule OptimizeMixedModule(IRModule mixed_mod, Target target);
+
+/*!
+ * \brief Returns the optimized IRModule for the device Target after
+ * device/host from mixed module.
Review comment:
nit: after device/host splitting
Worth making clearer that the result is just for target (assuming that's the
case)? Ditto for the host one below.
##########
File path: apps/ios_rpc/README.md
##########
@@ -79,7 +79,7 @@ You can get value of your `team_id` in the following ways:
select target `tvmrpc`. At the bottom of this panel go to `Signing &
Capabilities` tab and in the field `Team` select your local developer profile
(`Your Name (Personal Team)`).
-
+
Review comment:
nit: let's just revert these too to demonstrate good practice.
##########
File path: include/tvm/target/codegen.h
##########
@@ -45,7 +45,7 @@ using runtime::TVMRetValue;
* \param target The target to be built.
Review comment:
nit: Might as well update the comment -- we're compiling the PrimFuncs
in IRModule tagged for target to a runtime::Module.
##########
File path: python/tvm/driver/build_module.py
##########
@@ -234,18 +185,14 @@ def build(
):
"""Build a function with arguments as signature. Code will be generated
for devices coupled with target information.
-
Review comment:
nit: did these lines get deleted by black?
##########
File path: python/tvm/driver/build_module.py
##########
@@ -123,6 +122,7 @@ def lower(
m : IRModule
The result IRModule
"""
+ # ffi.relay.lower_te_pass()
Review comment:
nit: intentional?
##########
File path: python/tvm/driver/build_module.py
##########
@@ -297,28 +235,27 @@ def build(
m1 = tvm.lower(s1, [A, B, C], name="test_add1")
m2 = tvm.lower(s2, [A, B, C], name="test_add2")
rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
-
Note
----
See the note on :any:`tvm.target` on target string format.
"""
- if isinstance(inputs, schedule.Schedule):
- if args is None:
- raise ValueError("args must be given for build from schedule")
+
+ # Lowering
+ if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)):
+ # should this be te_lower instead?
input_mod = lower(inputs, args, name=name, binds=binds)
elif isinstance(inputs, (list, tuple, container.Array)):
merged_mod = tvm.IRModule({})
for x in inputs:
merged_mod.update(lower(x))
input_mod = merged_mod
- elif isinstance(inputs, (tvm.IRModule, PrimFunc)):
- input_mod = lower(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
-            f"Inputs must be Schedule, IRModule or dict of target to IRModule, "
+            f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, "
f"but got {type(inputs)}."
)
+ # starts here
Review comment:
nit: maybe sweep for leftover comments
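For context, the dispatch shape the hunk above introduces can be sketched in plain Python (stand-in classes, not TVM's real APIs): a single `Schedule`/`IRModule`/`PrimFunc` input is lowered directly, a sequence is lowered and merged, and anything else must already be a dict.

```python
# Illustrative sketch of the new tvm.build input dispatch; Schedule, PrimFunc,
# IRModule and lower() here are simplified stand-ins, not the real TVM classes.

class Schedule: ...
class PrimFunc: ...

class IRModule:
    def __init__(self, funcs=None):
        self.funcs = dict(funcs or {})

    def update(self, other):
        self.funcs.update(other.funcs)

def lower(obj, name="main"):
    # Stand-in for tvm.lower: wrap the input in a one-function module.
    return IRModule({name: obj})

def build(inputs, name="main"):
    if isinstance(inputs, (Schedule, IRModule, PrimFunc)):
        input_mod = lower(inputs, name=name)
    elif isinstance(inputs, (list, tuple)):
        merged_mod = IRModule()
        for i, x in enumerate(inputs):
            merged_mod.update(lower(x, name=f"fn{i}"))
        input_mod = merged_mod
    elif not isinstance(inputs, dict):
        raise ValueError(
            f"Inputs must be Schedule, PrimFunc, IRModule or dict of target to IRModule, "
            f"but got {type(inputs)}."
        )
    else:
        input_mod = inputs
    return input_mod
```

The point of the refactor is that the first branch now absorbs the old separate `IRModule`/`PrimFunc` branch, so every single-object input goes through one `lower` call.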
##########
File path: python/tvm/driver/build_module.py
##########
@@ -149,79 +149,30 @@ def _build_for_device(input_mod, target, target_host):
Review comment:
(Is there a way to tell github to allow comments on unchanged sections?)
nit: Still using 'schedule' in the comment above.
##########
File path: src/driver/driver_api.cc
##########
@@ -373,88 +381,119 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
});
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg,
-                                                const Target& target_host_arg,
-                                                const transform::PassContext& pass_ctx) {
+// Splits module into one to run on the device and one to run the host. E.g., CUDA, OpenCL etc
+std::pair<IRModule, IRModule> SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg,
Review comment:
Might as well use the Doxygen-friendly comment format and give this a
proper introduction.
##########
File path: src/driver/driver_api.cc
##########
@@ -530,12 +570,108 @@ runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& tar
}
// Build for homogeneous execution.
+// Where is this called from?]
+// called from compile engine and it accepts lowered functions
runtime::Module build(const IRModule& funcs, const Target& target_arg,
const Target& target_host_arg) {
auto target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
+ // More maps of target and target host
Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host);
}
+IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+
+ Array<transform::Pass> mixed_pass_list;
+
+ mixed_pass_list.push_back(BindTarget(target));
+
+ bool is_entry_func = false;
+ if (mixed_mod->functions.size() == 1) {
+    is_entry_func = pass_ctx->GetConfig<Bool>("tir.is_entry_func", Bool(true)).value();
+ mixed_pass_list.push_back(AnnotateEntryFunc(is_entry_func));
+ }
+
+ mixed_pass_list.push_back(tir::transform::VerifyMemory());
+  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
Review comment:
From here on looks good as it's a mechanical transfer from the py.
##########
File path: src/driver/driver_api.cc
##########
@@ -530,12 +570,108 @@ runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target& tar
}
// Build for homogeneous execution.
+// Where is this called from?]
+// called from compile engine and it accepts lowered functions
runtime::Module build(const IRModule& funcs, const Target& target_arg,
const Target& target_host_arg) {
auto target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
+ // More maps of target and target host
Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host);
}
+IRModule OptimizeMixedModule(IRModule mixed_mod, Target target) {
+ transform::PassContext pass_ctx = transform::PassContext::Current();
+
+ Array<transform::Pass> mixed_pass_list;
+
+ mixed_pass_list.push_back(BindTarget(target));
+
+ bool is_entry_func = false;
+ if (mixed_mod->functions.size() == 1) {
+    is_entry_func = pass_ctx->GetConfig<Bool>("tir.is_entry_func", Bool(true)).value();
Review comment:
I know you've transliterated this from _build_for_device but it doesn't
make any sense to me: if the module has a single function with a single
tir.is_entry_func=true annotation then apply the same annotation to the same
function. I think it's a no-op.
##########
File path: src/driver/driver_api.cc
##########
@@ -373,88 +381,119 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
});
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg,
-                                                const Target& target_host_arg,
-                                                const transform::PassContext& pass_ctx) {
+// Splits module into one to run on the device and one to run the host. E.g., CUDA, OpenCL etc
+std::pair<IRModule, IRModule> SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg,
+                                                      const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
-  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-                                                 tir::transform::VerifyMemory()};
-  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
-  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
- mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
- }
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
- mixed_pass_list.push_back(tir::transform::InferFragment());
- mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ ICHECK(mod_mixed.defined()) << "This module must be defined";
- if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
- mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
- } else {
- mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
- }
+ mod_mixed = OptimizeMixedModule(mod_mixed, target);
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ auto host_mod = OptimizeHostModule(mod_mixed, target_host);
- auto opt_mixed = transform::Sequential(mixed_pass_list);
- mod_mixed = opt_mixed(std::move(mod_mixed));
-
- auto host_pass_list = {
- Filter([](const tir::PrimFunc& f) {
-      return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target_host),
- tir::transform::LowerTVMBuiltin(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- tir::transform::CombineContextCall(),
- };
- auto opt_host = transform::Sequential(host_pass_list);
- ICHECK(mod_mixed.defined()) << "This module must be defined";
- auto mhost = opt_host(mod_mixed);
-
- // device pipeline
- auto device_pass_list = {
- Filter([](const tir::PrimFunc& f) {
-      return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target),
- tir::transform::LowerWarpMemory(),
- tir::transform::Simplify(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- };
- auto opt_device = transform::Sequential(device_pass_list);
- auto mdevice = opt_device(mod_mixed);
+ auto device_mod = OptimizeDeviceModule(mod_mixed, target);
// some final misc checks.
auto keys = target->GetKeys();
  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
- if (target_is_gpu && mdevice->functions.size() == 0) {
+ if (target_is_gpu && device_mod->functions.size() == 0) {
LOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
}
- if (target->kind->device_type == kDLCPU && target_host == target) {
-  // TODO(@jroesch): This check is no longer true we need to figure out if we care about this.
- // We need to relax this check for just TIR functions.
-  //   ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
-  //                                      << "and host_target are both llvm target."
- // << "\n";
+ return {host_mod, device_mod};
+}
+
+std::pair<Target, Target> TargetTypeMangling(const Map<Target, IRModule>& inputs_arg, Target target,
Review comment:
Full comments please.
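To make the requested comments easier to write, here is the split step reduced to a plain-Python sketch (not TVM's real data structures): after the mixed-module passes run, functions are partitioned into host and device modules by calling convention, with `kDeviceKernelLaunch` marking device kernels.

```python
# Hedged sketch of the SplitFuncsToDevHostMods partitioning; modules are
# modeled as name -> attrs dicts and calling conventions as plain strings.

K_DEVICE_KERNEL_LAUNCH = "kDeviceKernelLaunch"
K_DEFAULT = "kDefault"

def split_dev_host(mixed_mod: dict):
    host_mod, device_mod = {}, {}
    for name, func in mixed_mod.items():
        conv = func.get("calling_conv", K_DEFAULT)
        if conv == K_DEVICE_KERNEL_LAUNCH:
            device_mod[name] = func   # launched on the device (CUDA, OpenCL, ...)
        else:
            host_mod[name] = func     # everything else stays on the host
    return host_mod, device_mod
```

In the real code the two halves then get their own pass pipelines (`OptimizeHostModule` / `OptimizeDeviceModule` in this PR), bound to `target_host` and `target` respectively.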
##########
File path: src/driver/driver_api.cc
##########
@@ -373,88 +381,119 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
});
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg,
-                                                const Target& target_host_arg,
-                                                const transform::PassContext& pass_ctx) {
+// Splits module into one to run on the device and one to run the host. E.g., CUDA, OpenCL etc
+std::pair<IRModule, IRModule> SplitFuncsToDevHostMods(IRModule mod_mixed, const Target& target_arg,
+                                                      const Target& target_host_arg) {
Target target = target_arg, target_host = target_host_arg;
CheckAndUpdateHostConsistency(&target, &target_host);
-  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-                                                 tir::transform::VerifyMemory()};
-  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
-  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
- mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
- }
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
- mixed_pass_list.push_back(tir::transform::InferFragment());
- mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ ICHECK(mod_mixed.defined()) << "This module must be defined";
- if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
- mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
- } else {
- mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
- }
+ mod_mixed = OptimizeMixedModule(mod_mixed, target);
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ auto host_mod = OptimizeHostModule(mod_mixed, target_host);
- auto opt_mixed = transform::Sequential(mixed_pass_list);
- mod_mixed = opt_mixed(std::move(mod_mixed));
-
- auto host_pass_list = {
- Filter([](const tir::PrimFunc& f) {
-      return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target_host),
- tir::transform::LowerTVMBuiltin(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- tir::transform::CombineContextCall(),
- };
- auto opt_host = transform::Sequential(host_pass_list);
- ICHECK(mod_mixed.defined()) << "This module must be defined";
- auto mhost = opt_host(mod_mixed);
-
- // device pipeline
- auto device_pass_list = {
- Filter([](const tir::PrimFunc& f) {
-      return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
- CallingConv::kDeviceKernelLaunch;
- }),
- BindTarget(target),
- tir::transform::LowerWarpMemory(),
- tir::transform::Simplify(),
- tir::transform::LowerCustomDatatypes(),
- tir::transform::LowerIntrin(),
- tir::transform::LowerDeviceStorageAccessInfo(),
- };
- auto opt_device = transform::Sequential(device_pass_list);
- auto mdevice = opt_device(mod_mixed);
+ auto device_mod = OptimizeDeviceModule(mod_mixed, target);
// some final misc checks.
auto keys = target->GetKeys();
  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
- if (target_is_gpu && mdevice->functions.size() == 0) {
+ if (target_is_gpu && device_mod->functions.size() == 0) {
LOG(WARNING) << "Specified target " << target->str()
<< " but cannot find device code. Did you forget to bind?";
}
- if (target->kind->device_type == kDLCPU && target_host == target) {
-  // TODO(@jroesch): This check is no longer true we need to figure out if we care about this.
- // We need to relax this check for just TIR functions.
-  //   ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
-  //                                      << "and host_target are both llvm target."
- // << "\n";
+ return {host_mod, device_mod};
+}
+
+std::pair<Target, Target> TargetTypeMangling(const Map<Target, IRModule>& inputs_arg, Target target,
+ Target target_host_arg) {
+ Target target_input_mod, target_host;
+
+ target = !target.defined() ? target.Current() : target;
Review comment:
I've gotten lost at this point: I had no idea we had a global target
context stack and don't understand how that interacts with all the target
arguments we're using here.
    Since you've gone pretty deep into this now, could you write some comments
explaining the mangling and finalization steps? Then I can try again.
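For readers equally lost, the "global target context stack" being referenced follows a common pattern, sketched here in plain Python (names illustrative, not TVM's real implementation): entering a `with Target(...)` scope pushes onto a stack, and an undefined explicit target falls back to the innermost scope, mirroring `target = !target.defined() ? target.Current() : target;`.

```python
# Minimal sketch of a scoped target context stack with a Current()-style lookup.

class Target:
    _stack = []  # innermost scope is last

    def __init__(self, kind):
        self.kind = kind

    def __enter__(self):
        Target._stack.append(self)
        return self

    def __exit__(self, *exc):
        Target._stack.pop()

    @classmethod
    def current(cls):
        return cls._stack[-1] if cls._stack else None

def resolve_target(explicit):
    # An explicit target argument wins; otherwise fall back to the scope stack.
    return explicit if explicit is not None else Target.current()
```

The interaction the reviewer asks about is exactly this precedence rule: explicit target arguments override the ambient scope, and only a fully undefined target consults the stack.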
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]