Mousius commented on a change in pull request #9103:
URL: https://github.com/apache/tvm/pull/9103#discussion_r715590473
##########
File path: python/tvm/relay/build_module.py
##########
@@ -282,6 +282,8 @@ def get_executor_from_target(target, target_host):
return executor
+# Which build is this one... Relay --> graph executor
+# can params being parsed during run-time?
Review comment:
I don't think these comments should be here?
##########
File path: python/tvm/driver/build_module.py
##########
@@ -123,6 +123,7 @@ def lower(
m : IRModule
The result IRModule
"""
+ # ffi.relay.lower_te_pass()
Review comment:
I don't think this comment is necessary?
##########
File path: python/tvm/target/target.py
##########
@@ -174,6 +174,7 @@ def list_kinds():
"""Returns the list of available target names."""
return list(_ffi_api.ListTargetKinds())
+ # TODO: make this return IRModule? idk it seems
Review comment:
This TODO comment is unfinished — please either complete the thought or remove it before merging.
##########
File path: python/tvm/driver/build_module.py
##########
@@ -297,25 +238,23 @@ def build(
m1 = tvm.lower(s1, [A, B, C], name="test_add1")
m2 = tvm.lower(s2, [A, B, C], name="test_add2")
rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
-
Note
----
See the note on :any:`tvm.target` on target string format.
"""
- if isinstance(inputs, schedule.Schedule):
- if args is None:
- raise ValueError("args must be given for build from schedule")
+
+ # Lowering
+ if isinstance(inputs, (schedule.Schedule, tvm.IRModule, PrimFunc)):
+ # should this be te_lower instead?
Review comment:
These comments should also be removed before merging.
##########
File path: python/tvm/target/codegen.py
##########
@@ -35,8 +35,11 @@ def build_module(mod, target):
module : runtime.Module
The corresponding module.
"""
+ print("In codegen build module")
Review comment:
Remove this `print()`
##########
File path: python/tvm/driver/build_module.py
##########
@@ -149,79 +150,32 @@ def _build_for_device(input_mod, target, target_host):
Returns
-------
- fhost : IRModule
+ host_mod : IRModule
The host IRModule.
- mdev : tvm.module
+ device_mod : tvm.module
A module that contains device code.
"""
- target, target_host = Target.check_and_update_host_consist(target,
target_host)
- device_type = ndarray.device(target.kind.name, 0).device_type
+ from tvm.driver import _ffi_api as _driver_ffi
- mod_mixed = input_mod
- mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target))(mod_mixed)
-
- opt_mixed = [
- tvm.tir.transform.VerifyMemory(),
- tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
- ]
- if len(mod_mixed.functions) == 1:
- opt_mixed += [tvm.tir.transform.Apply(lambda f:
f.with_attr("tir.is_entry_func", True))]
-
- if PassContext.current().config.get("tir.detect_global_barrier", False):
- opt_mixed += [tvm.tir.transform.ThreadSync("global")]
- opt_mixed += [
- tvm.tir.transform.ThreadSync("shared"),
- tvm.tir.transform.ThreadSync("warp"),
- tvm.tir.transform.InferFragment(),
- tvm.tir.transform.LowerThreadAllreduce(),
- tvm.tir.transform.MakePackedAPI(),
- tvm.tir.transform.SplitHostDevice(),
- ]
- mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
-
- # device optimizations
- opt_device = tvm.transform.Sequential(
- [
- tvm.tir.transform.Filter(
- lambda f: "calling_conv" in f.attrs
- and f.attrs["calling_conv"].value ==
CallingConv.DEVICE_KERNEL_LAUNCH
- ),
- tvm.tir.transform.LowerWarpMemory(),
- tvm.tir.transform.Simplify(),
- tvm.tir.transform.LowerDeviceStorageAccessInfo(),
- tvm.tir.transform.LowerCustomDatatypes(),
- tvm.tir.transform.LowerIntrin(),
- ]
- )
- mod_dev = opt_device(mod_mixed)
-
- # host optimizations
- opt_host = tvm.transform.Sequential(
- [
- tvm.tir.transform.Filter(
- lambda f: "calling_conv" not in f.attrs
- or f.attrs["calling_conv"].value !=
CallingConv.DEVICE_KERNEL_LAUNCH
- ),
- tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target_host)),
- tvm.tir.transform.LowerTVMBuiltin(),
- tvm.tir.transform.LowerDeviceStorageAccessInfo(),
- tvm.tir.transform.LowerCustomDatatypes(),
- tvm.tir.transform.LowerIntrin(),
- tvm.tir.transform.CombineContextCall(),
- ]
- )
- mod_host = opt_host(mod_mixed)
+ mod_mixed = _driver_ffi.get_mod_mixed(input_mod, target)
+ device_mod = _driver_ffi.get_device_mod(mod_mixed, target)
+ host_mod = _driver_ffi.get_host_mod(mod_mixed, target_host)
+ device_type = ndarray.device(target.kind.name, 0).device_type
if device_type == ndarray.cpu(0).device_type and target_host == target:
- assert len(mod_dev.functions) == 0
- if "gpu" in target.keys and len(mod_dev.functions) == 0:
+ assert len(device_mod.functions) == 0
+ if "gpu" in target.keys and len(device_mod.functions) == 0:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target
)
- rt_mod_dev = codegen.build_module(mod_dev, target) if
len(mod_dev.functions) != 0 else None
- return mod_host, rt_mod_dev
+ # rt_mod_dev is runtime::Module so this can be moved out maybe?
+ rt_mod_dev = (
+ codegen.build_module(device_mod, target) if len(device_mod.functions)
!= 0 else None
+ )
+ # TIR module for the host, runtime module for devices?
Review comment:
Do we need to resolve these questions?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]