Kathryn-cat commented on code in PR #18055:
URL: https://github.com/apache/tvm/pull/18055#discussion_r2147024358
##########
python/tvm/tir/build.py:
##########
@@ -44,28 +43,22 @@ def split_host_device_mods(mod):
A dict mapping targets to device modules
"""
- class CallConv(enum.IntEnum):
- """Enum representing different calling conventions.
- Corresponds to the C++ tvm::ir::CallingConv enum.
- """
-
- kDefault = 0
- kCPackedFunc = 1
- kDeviceKernelLaunch = 2
-
- host_mod = tvm.tir.transform.Filter(
- lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
- != int(CallConv.kDeviceKernelLaunch)
- )(mod)
- device_mod = tvm.tir.transform.Filter(
- lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
- == int(CallConv.kDeviceKernelLaunch)
- )(mod)
+ host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in
str(f.attrs.get("target", "cpu")))(mod)
+ device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in
str(f.attrs.get("target", "cpu")))(
+ mod
+ )
+ # TODO(syfeng): Here we use str as key since target hash is not correct
+ target_str2target = {}
+ device_func_dict = {}
device_mod_dict = {}
for gv, func in device_mod.functions.items():
- device_mod_dict.setdefault(func.attrs.get("target", None),
dict()).update({gv: func})
- for target, funcs in device_mod_dict.items():
- device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
+ target = func.attrs.get("target", None)
+ target_str = str(target) if target is not None else ""
+ target_str2target[target_str] = target # This might be overridden by
the last one
Review Comment:
We want to make sure we understand in which cases different `Target` objects
might have the same string representation `target_str`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]