This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 81889decfa [Fix] Replace str(target.kind) with target.kind.name for
Target objects (#18959)
81889decfa is described below
commit 81889decfaa5061e2bdc53ae75a7b395ecd649ce
Author: Akaash Parthasarathy <[email protected]>
AuthorDate: Thu Apr 2 14:17:02 2026 -0400
[Fix] Replace str(target.kind) with target.kind.name for Target objects
(#18959)
Replace `str(target.kind)` with `target.kind.name` for `Target` objects:
`target.kind` is a `TargetKind` object, whereas `target.kind.name` is the
target kind's name as a plain string (e.g. "llvm" or "webgpu").
---
python/tvm/relax/frontend/nn/llm/kv_cache.py | 6 +++---
python/tvm/relax/frontend/nn/llm/tree_attn.py | 4 ++--
python/tvm/runtime/module.py | 4 ++--
python/tvm/testing/utils.py | 4 ++--
4 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index f66ff6b663..749707cb29 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -73,7 +73,7 @@ def check_thread_limits(target: Target, bdx: int, bdy: int,
bdz: int, gdz: int):
f"{target.kind} max num threads exceeded:
{bdx}*{bdy}*{bdz}>{max_num_threads_per_block}"
)
- if str(target.kind) == "webgpu":
+ if target.kind.name == "webgpu":
#
https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez
assert bdz <= 64, f"webgpu's threadIdx.z cannot exceed 64, but got
bdz={bdz}"
assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}"
@@ -623,7 +623,7 @@ class TIRPagedKVCache(PagedKVCache): # pylint:
disable=too-few-public-methods
# pylint: enable=line-too-long
]
- if str(target.kind) == "llvm":
+ if target.kind.name == "llvm":
if attn_kind_single == "mla":
raise ValueError("MLA is not supported in TIR kernels for
now.")
# pylint: disable=line-too-long
@@ -1098,7 +1098,7 @@ def _get_prefill_kernel_config(h_kv, h_q, d, dtype,
target: Target):
# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
- str(target.kind) == "webgpu"
+ target.kind.name == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py
b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index c55aa3eceb..2c7f66cf37 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -331,7 +331,7 @@ def tree_attn(h_kv, h_q, d, dtype, rope_scaling: dict[str,
Any], target: Target)
# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
- str(target.kind) == "webgpu"
+ target.kind.name == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
@@ -898,7 +898,7 @@ def tree_attn_with_paged_kv_cache(
# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
- str(target.kind) == "webgpu"
+ target.kind.name == "webgpu"
and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
):
tile_z = 8
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 7586e0df15..df78faa596 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -473,7 +473,7 @@ def enabled(target):
Parameters
----------
- target : str
+ target : str or Dict[str, Any] or tvm.target.Target
The target device type.
Returns
@@ -490,7 +490,7 @@ def enabled(target):
if isinstance(target, dict):
target = target.get("kind", "")
elif hasattr(target, "kind"):
- target = str(target.kind)
+ target = target.kind.name
return _ffi_api.RuntimeEnabled(target)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index c5937eec4e..125fa9586b 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -463,7 +463,7 @@ def device_enabled(target):
Parameters
----------
- target : str
+ target : str or Dict[str, Any] or tvm.target.Target
Target string to check against
Returns
@@ -485,7 +485,7 @@ def device_enabled(target):
if isinstance(target, dict):
target_kind = target["kind"]
elif hasattr(target, "kind"):
- target_kind = str(target.kind)
+ target_kind = target.kind.name
else:
target_kind = target
return any(target_kind == t["target_kind"] for t in _get_targets() if
t["is_runnable"])