This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 81889decfa [Fix] Replace str(target.kind) with target.kind.name for Target objects (#18959)
81889decfa is described below

commit 81889decfaa5061e2bdc53ae75a7b395ecd649ce
Author: Akaash Parthasarathy <[email protected]>
AuthorDate: Thu Apr 2 14:17:02 2026 -0400

    [Fix] Replace str(target.kind) with target.kind.name for Target objects (#18959)
    
    Replace `str(target.kind)` with `target.kind.name` when checking the kind
    of a `Target` object: `target.kind` is a `TargetKind` object, whereas
    `target.kind.name` is the plain string name of that kind (e.g. "webgpu"
    or "llvm"), which is what these comparisons expect.
---
 python/tvm/relax/frontend/nn/llm/kv_cache.py  | 6 +++---
 python/tvm/relax/frontend/nn/llm/tree_attn.py | 4 ++--
 python/tvm/runtime/module.py                  | 4 ++--
 python/tvm/testing/utils.py                   | 4 ++--
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py 
b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index f66ff6b663..749707cb29 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -73,7 +73,7 @@ def check_thread_limits(target: Target, bdx: int, bdy: int, 
bdz: int, gdz: int):
         f"{target.kind} max num threads exceeded: 
{bdx}*{bdy}*{bdz}>{max_num_threads_per_block}"
     )
 
-    if str(target.kind) == "webgpu":
+    if target.kind.name == "webgpu":
         # 
https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez
         assert bdz <= 64, f"webgpu's threadIdx.z cannot exceed 64, but got 
bdz={bdz}"
         assert gdz == 1, f"webgpu's blockIdx.z should be 1, but got gdz={gdz}"
@@ -623,7 +623,7 @@ class TIRPagedKVCache(PagedKVCache):  # pylint: 
disable=too-few-public-methods
             # pylint: enable=line-too-long
         ]
 
-        if str(target.kind) == "llvm":
+        if target.kind.name == "llvm":
             if attn_kind_single == "mla":
                 raise ValueError("MLA is not supported in TIR kernels for 
now.")
             # pylint: disable=line-too-long
@@ -1098,7 +1098,7 @@ def _get_prefill_kernel_config(h_kv, h_q, d, dtype, 
target: Target):
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
-        str(target.kind) == "webgpu"
+        target.kind.name == "webgpu"
         and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
     ):
         tile_z = 8
diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py 
b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index c55aa3eceb..2c7f66cf37 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -331,7 +331,7 @@ def tree_attn(h_kv, h_q, d, dtype, rope_scaling: dict[str, 
Any], target: Target)
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
-        str(target.kind) == "webgpu"
+        target.kind.name == "webgpu"
         and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
     ):
         tile_z = 8
@@ -898,7 +898,7 @@ def tree_attn_with_paged_kv_cache(
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
-        str(target.kind) == "webgpu"
+        target.kind.name == "webgpu"
         and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
     ):
         tile_z = 8
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 7586e0df15..df78faa596 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -473,7 +473,7 @@ def enabled(target):
 
     Parameters
     ----------
-    target : str
+    target : str or Dict[str, Any] or tvm.target.Target
         The target device type.
 
     Returns
@@ -490,7 +490,7 @@ def enabled(target):
     if isinstance(target, dict):
         target = target.get("kind", "")
     elif hasattr(target, "kind"):
-        target = str(target.kind)
+        target = target.kind.name
     return _ffi_api.RuntimeEnabled(target)
 
 
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index c5937eec4e..125fa9586b 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -463,7 +463,7 @@ def device_enabled(target):
 
     Parameters
     ----------
-    target : str
+    target : str or Dict[str, Any] or tvm.target.Target
         Target string to check against
 
     Returns
@@ -485,7 +485,7 @@ def device_enabled(target):
     if isinstance(target, dict):
         target_kind = target["kind"]
     elif hasattr(target, "kind"):
-        target_kind = str(target.kind)
+        target_kind = target.kind.name
     else:
         target_kind = target
     return any(target_kind == t["target_kind"] for t in _get_targets() if 
t["is_runnable"])

Reply via email to