This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a77aaaadd [TIRx] Phase out flat device-intrinsic op aliases (#19838)
2a77aaaadd is described below

commit 2a77aaaadd8421484bf8d8b754e91d2ba7142658
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 19 12:25:41 2026 -0400

    [TIRx] Phase out flat device-intrinsic op aliases (#19838)
    
    PR #19677 registered every CUDA / Trainium device intrinsic under two Op
    names: a flat `tirx.<ns>_<name>` alias plus the canonical
    `tirx.<ns>.<name>`. The flat aliases were a migration shim; passes and
    codegen that match an intrinsic had to check both spellings (the
    dual-name IsOp pattern). The Python builders and TVMScript parser
    already canonicalize, so every real Call already carries the canonical
    op and the flat aliases were dead weight.
    
    This PR removes the flat device-intrinsic aliases, keeping only the
    canonical namespaced ops:
    
    - RegisterDeviceIntrinsic (backend/cuda) and RegisterNKIIntrinsic
    (backend/trn) register only the canonical name.
    - Drop the flat-only macro registrations for device intrinsics; the
    canonical op with all attrs is registered from the alias table. The WMMA
    tvm_*_sync / mma_store / mma_fill builtins and the profiling
    timer_*_cuda builtins keep their flat names (no namespace / canonical
    form, category "builtin").
    - Remove the redundant flat tirx.ptx_fetch_register registration.
    - C++ consumers that resolved a flat op by name string now use the
    canonical name; the ptx_elect_sync / cuda_func_call dual-name matchers
    collapse to the canonical check.
    - Python: the InjectPTXAsyncCopy round-trip Op.get and the matching test
    assertion use the canonical name. call_intrin keeps its flat->canonical
    rewrite for back-compat, so user-facing wrappers are unchanged.
    - test_op_namespace_cleanup asserts device_intrin op names are canonical
    so a flat alias cannot silently reappear.
    
    Generated CUDA is byte-identical: helper names are literals and codegen
    dispatches by op name, with the registry resolving the canonical name to
    the same helper.
---
 python/tvm/backend/cuda/op.py                      | 322 ++++++++++-----------
 .../backend/cuda/operator/intrinsics/cp_async.py   |   3 +
 .../operator/tile_primitive/gemm_async/tcgen05.py  |   2 +-
 python/tvm/backend/cuda/script.py                  |   8 +-
 python/tvm/backend/trn/op.py                       |  32 +-
 src/backend/cuda/codegen/codegen_cuda.cc           |  18 +-
 src/backend/cuda/op/target_builtin.cc              | 278 +-----------------
 src/backend/trn/codegen/codegen_trn.cc             |  32 +-
 src/backend/trn/op/target_builtin.cc               |  63 +---
 src/s_tir/transform/inject_ptx_async_copy.cc       |   6 +-
 .../transform/merge_shared_memory_allocations.cc   |   2 +-
 src/tirx/analysis/filter_canonical.cc              |  10 +-
 src/tirx/analysis/filter_canonical.h               |   8 +-
 src/tirx/script/printer/expr.cc                    |   2 +-
 src/tirx/transform/tile_primitive_dispatch.cc      |   8 +-
 .../test_s_tir_transform_inject_ptx_async_copy.py  |  17 +-
 16 files changed, 253 insertions(+), 558 deletions(-)

diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py
index bb3c59599e..fc3efbf5f5 100644
--- a/python/tvm/backend/cuda/op.py
+++ b/python/tvm/backend/cuda/op.py
@@ -70,7 +70,7 @@ def cuda_func_call(func_name, *args, source_code, 
return_type="void"):
     return_type: str
         The return type of the CUDA function.
     """
-    return call_intrin(return_type, "tirx.cuda_func_call", func_name, *args, 
source_code)
+    return call_intrin(return_type, "tirx.cuda.func_call", func_name, *args, 
source_code)
 
 
 def cuda_warp_reduce(value, op, width=32):
@@ -97,7 +97,7 @@ def cuda_warp_reduce(value, op, width=32):
     call : PrimExpr
         The reduced value (same dtype as *value*).
     """
-    return call_intrin(value.dtype, "tirx.cuda_warp_reduce", value, op, width)
+    return call_intrin(value.dtype, "tirx.cuda.warp_reduce", value, op, width)
 
 
 def cuda_warp_sum(value, width=32):
@@ -141,7 +141,7 @@ def cuda_cta_reduce(value, op, num_warps, scratch):
     call : PrimExpr
         The reduced value broadcast to all threads (same dtype as *value*).
     """
-    return call_intrin(value.dtype, "tirx.cuda_cta_reduce", value, op, 
num_warps, scratch)
+    return call_intrin(value.dtype, "tirx.cuda.cta_reduce", value, op, 
num_warps, scratch)
 
 
 def cuda_cta_sum(value, num_warps, scratch):
@@ -182,7 +182,7 @@ def cuda_copy_bytes(dst, src, num_bytes):
     call : PrimExpr
         A void call expression.
     """
-    return call_intrin("void", "tirx.cuda_copy_bytes", dst, src, num_bytes)
+    return call_intrin("void", "tirx.cuda.copy_bytes", dst, src, num_bytes)
 
 
 def cuda_copy_128b(dst, src):
@@ -220,7 +220,7 @@ def cuda_warp_sync():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_warp_sync")
+    return call_intrin("", "tirx.cuda.warp_sync")
 
 
 def cuda_cta_sync():
@@ -231,7 +231,7 @@ def cuda_cta_sync():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_cta_sync")
+    return call_intrin("", "tirx.cuda.cta_sync")
 
 
 def cuda_grid_sync():
@@ -242,7 +242,7 @@ def cuda_grid_sync():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_grid_sync")
+    return call_intrin("", "tirx.cuda.grid_sync")
 
 
 def cuda_cluster_sync():
@@ -253,7 +253,7 @@ def cuda_cluster_sync():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_cluster_sync")
+    return call_intrin("", "tirx.cuda.cluster_sync")
 
 
 def cuda_thread_rank():
@@ -271,7 +271,7 @@ def cuda_thread_rank():
     call : PrimExpr
         The call expression (``int32``).
     """
-    return call_intrin("int32", "tirx.cuda_thread_rank")
+    return call_intrin("int32", "tirx.cuda.thread_rank")
 
 
 def cuda_half2float(src):
@@ -287,7 +287,7 @@ def cuda_half2float(src):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("float32", "tirx.cuda_half2float", src)
+    return call_intrin("float32", "tirx.cuda.half2float", src)
 
 
 def cuda_bfloat162float(src):
@@ -303,7 +303,7 @@ def cuda_bfloat162float(src):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("float32", "tirx.cuda_bfloat162float", src)
+    return call_intrin("float32", "tirx.cuda.bfloat162float", src)
 
 
 def cuda_float22half2(dst, src):
@@ -322,7 +322,7 @@ def cuda_float22half2(dst, src):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_float22half2", dst, src)
+    return call_intrin("", "tirx.cuda.float22half2", dst, src)
 
 
 def cuda_trap_when_assert_failed(cond):
@@ -338,7 +338,7 @@ def cuda_trap_when_assert_failed(cond):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_trap_when_assert_failed", cond)
+    return call_intrin("", "tirx.cuda.trap_when_assert_failed", cond)
 
 
 def cuda_runtime_instr_desc(desc, sf_id):
@@ -357,7 +357,7 @@ def cuda_runtime_instr_desc(desc, sf_id):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_runtime_instr_desc", desc, sf_id)
+    return call_intrin("", "tirx.cuda.runtime_instr_desc", desc, sf_id)
 
 
 def cuda_half8tofloat8(src_addr, dst_addr):
@@ -376,7 +376,7 @@ def cuda_half8tofloat8(src_addr, dst_addr):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_half8tofloat8", src_addr, dst_addr)
+    return call_intrin("", "tirx.cuda.half8tofloat8", src_addr, dst_addr)
 
 
 def cuda_float8tohalf8(src_addr, dst_addr):
@@ -395,7 +395,7 @@ def cuda_float8tohalf8(src_addr, dst_addr):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_float8tohalf8", src_addr, dst_addr)
+    return call_intrin("", "tirx.cuda.float8tohalf8", src_addr, dst_addr)
 
 
 def ptx_mma_sp(
@@ -480,7 +480,7 @@ def ptx_mma_sp(
     """
     return call_intrin(
         dtype,
-        "tirx.ptx_mma_sp",
+        "tirx.ptx.mma_sp",
         shape,
         A_layout,
         B_layout,
@@ -536,7 +536,7 @@ def ptx_cp_async_bulk(
     """
     return call_intrin(
         dtype,
-        "tirx.ptx_cp_async_bulk",
+        "tirx.ptx.cp_async_bulk",
         shared_ptr,
         shared_offset,
         global_ptr,
@@ -572,7 +572,7 @@ def ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, 
size, mbar):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_cp_async_bulk_shared_to_cluster", 
dst_ptr, src_ptr, size, mbar)
+    return call_intrin("", "tirx.ptx.cp_async_bulk_shared_to_cluster", 
dst_ptr, src_ptr, size, mbar)
 
 
 def ptx_cp_async_mbarrier_arrive(barrier_id):
@@ -589,7 +589,7 @@ def ptx_cp_async_mbarrier_arrive(barrier_id):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_cp_async_mbarrier_arrive", barrier_id)
+    return call_intrin("", "tirx.ptx.cp_async_mbarrier_arrive", barrier_id)
 
 
 def ptx_fence(sem: str, scope: str):
@@ -611,7 +611,7 @@ def ptx_fence(sem: str, scope: str):
     """
     _choice("sem", sem, _FENCE_SEM)
     _choice("scope", scope, _FENCE_SCOPE)
-    return call_intrin("", "tirx.ptx_fence", sem, scope)
+    return call_intrin("", "tirx.ptx.fence", sem, scope)
 
 
 def ptx_fence_proxy_async(space: str = ""):
@@ -631,7 +631,7 @@ def ptx_fence_proxy_async(space: str = ""):
         The call expression.
     """
     _choice("space", space, _FENCE_PROXY_ASYNC_SPACE)
-    return call_intrin("", "tirx.ptx_fence_proxy_async", space)
+    return call_intrin("", "tirx.ptx.fence_proxy_async", space)
 
 
 def ptx_mbarrier_init(bar, thread_count):
@@ -650,7 +650,7 @@ def ptx_mbarrier_init(bar, thread_count):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count)
+    return call_intrin("", "tirx.ptx.mbarrier_init", bar, thread_count)
 
 
 def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None):
@@ -677,11 +677,11 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, 
count=None):
         ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``.
     """
     if cta_id is None and pred is None:
-        return call_intrin("", "tirx.ptx_mbarrier_arrive", bar)
+        return call_intrin("", "tirx.ptx.mbarrier_arrive", bar)
     assert cta_id is not None and pred is not None
     if count is None:
-        return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
-    return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred, 
count)
+        return call_intrin("", "tirx.ptx.mbarrier_arrive", bar, cta_id, pred)
+    return call_intrin("", "tirx.ptx.mbarrier_arrive", bar, cta_id, pred, 
count)
 
 
 def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
@@ -691,7 +691,7 @@ def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
     ``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64 
_,
     [addr], count`` with the guard defaulted to 1.
     """
-    return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, 
count)
+    return call_intrin("", "tirx.ptx.mbarrier_arrive", bar, cta_id, True, 
count)
 
 
 def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
@@ -722,13 +722,13 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, 
cta_id=None, pred=None):
         The call expression.
     """
     if cta_id is None and pred is None:
-        return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, 
byte_count)
+        return call_intrin("", "tirx.ptx.mbarrier_arrive_expect_tx", bar, 
byte_count)
     assert cta_id is not None
     # Cross-CTA expect_tx from an already-elected thread: default the guard to 
1
     # (the caller has elected a single lane), so callers can pass cta_id alone.
     if pred is None:
         pred = True
-    return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, 
byte_count, cta_id, pred)
+    return call_intrin("", "tirx.ptx.mbarrier_arrive_expect_tx", bar, 
byte_count, cta_id, pred)
 
 
 def ptx_mbarrier_try_wait(bar, phase):
@@ -747,7 +747,7 @@ def ptx_mbarrier_try_wait(bar, phase):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase)
+    return call_intrin("", "tirx.ptx.mbarrier_try_wait", bar, phase)
 
 
 def ptx_mbarrier_try_wait_acquire_cluster(bar, phase):
@@ -764,7 +764,7 @@ def ptx_mbarrier_try_wait_acquire_cluster(bar, phase):
     phase : int
         The phase of the barrier.
     """
-    return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar, 
phase)
+    return call_intrin("", "tirx.ptx.mbarrier_try_wait_acquire_cluster", bar, 
phase)
 
 
 def ptx_mbarrier_try_wait_once(bar, phase, ticks):
@@ -774,7 +774,7 @@ def ptx_mbarrier_try_wait_once(bar, phase, ticks):
     This is intended for bounded debug waits; production waits should use
     :func:`ptx_mbarrier_try_wait`.
     """
-    return call_intrin("uint32", "tirx.ptx_mbarrier_try_wait_once", bar, 
phase, ticks)
+    return call_intrin("uint32", "tirx.ptx.mbarrier_try_wait_once", bar, 
phase, ticks)
 
 
 def ptx_bar_arrive(name_bar_id, thread_count):
@@ -793,7 +793,7 @@ def ptx_bar_arrive(name_bar_id, thread_count):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_bar_arrive", name_bar_id, thread_count)
+    return call_intrin("", "tirx.ptx.bar_arrive", name_bar_id, thread_count)
 
 
 def ptx_bar_sync(name_bar_id, thread_count):
@@ -812,7 +812,7 @@ def ptx_bar_sync(name_bar_id, thread_count):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_bar_sync", name_bar_id, thread_count)
+    return call_intrin("", "tirx.ptx.bar_sync", name_bar_id, thread_count)
 
 
 def ptx_cp_async(
@@ -870,7 +870,7 @@ def ptx_cp_async(
     _choice("fill_mode", fill_mode, _CP_ASYNC_FILL_MODE)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async",
+        "tirx.ptx.cp_async",
         dst_ptr,
         src_ptr,
         cp_size,
@@ -921,7 +921,7 @@ def ptx_cp_async_commit_group():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_cp_async_commit_group")
+    return call_intrin("", "tirx.ptx.cp_async_commit_group")
 
 
 def ptx_cp_async_wait_group(num=0):
@@ -938,7 +938,7 @@ def ptx_cp_async_wait_group(num=0):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_cp_async_wait_group", num)
+    return call_intrin("", "tirx.ptx.cp_async_wait_group", num)
 
 
 def ptx_cp_async_bulk_tensor_global_to_cluster(
@@ -985,7 +985,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster(
         has_cache_policy, *coords = coords
         return call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_global_to_cluster",
+            "tirx.ptx.cp_async_bulk_tensor_global_to_cluster",
             dim,
             dst_ptr,
             bar,
@@ -999,7 +999,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_tensor_global_to_cluster",
+        "tirx.ptx.cp_async_bulk_tensor_global_to_cluster",
         dim,
         dst_ptr,
         bar,
@@ -1054,7 +1054,7 @@ def 
ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster(
         has_cache_policy, *coords = coords
         return call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster",
+            "tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster",
             dim,
             dst_ptr,
             bar,
@@ -1068,7 +1068,7 @@ def 
ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster",
+        "tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster",
         dim,
         dst_ptr,
         bar,
@@ -1112,7 +1112,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global(
         has_cache_policy, *coords = coords
         return call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_shared_to_global",
+            "tirx.ptx.cp_async_bulk_tensor_shared_to_global",
             dim,
             src_ptr,
             tensormap_addr,
@@ -1123,7 +1123,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_tensor_shared_to_global",
+        "tirx.ptx.cp_async_bulk_tensor_shared_to_global",
         dim,
         src_ptr,
         tensormap_addr,
@@ -1161,7 +1161,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(
         has_cache_policy, *coords = coords
         return call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch",
+            "tirx.ptx.cp_async_bulk_tensor_global_to_cluster_prefetch",
             dim,
             tensormap_addr,
             cache_hint,
@@ -1171,7 +1171,7 @@ def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch",
+        "tirx.ptx.cp_async_bulk_tensor_global_to_cluster_prefetch",
         dim,
         tensormap_addr,
         cache_policy,
@@ -1216,7 +1216,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global_reduce(
         _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP)
         return call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce",
+            "tirx.ptx.cp_async_bulk_tensor_shared_to_global_reduce",
             dim,
             src_ptr,
             tensormap_addr,
@@ -1229,7 +1229,7 @@ def ptx_cp_async_bulk_tensor_shared_to_global_reduce(
     _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce",
+        "tirx.ptx.cp_async_bulk_tensor_shared_to_global_reduce",
         dim,
         src_ptr,
         tensormap_addr,
@@ -1248,7 +1248,7 @@ def ptx_cp_async_bulk_commit_group():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_cp_async_bulk_commit_group")
+    return call_intrin("", "tirx.ptx.cp_async_bulk_commit_group")
 
 
 def ptx_cp_async_bulk_wait_group(n=0, read=True):
@@ -1267,7 +1267,7 @@ def ptx_cp_async_bulk_wait_group(n=0, read=True):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_cp_async_bulk_wait_group", n, read)
+    return call_intrin("", "tirx.ptx.cp_async_bulk_wait_group", n, read)
 
 
 def ptx_barrier_cluster_arrive(sem="", aligned=True):
@@ -1282,7 +1282,7 @@ def ptx_barrier_cluster_arrive(sem="", aligned=True):
         Whether all threads in the warp must execute the same instruction.
     """
     _choice("sem", sem, _CLUSTER_BARRIER_SEM)
-    return call_intrin("", "tirx.ptx_barrier_cluster_arrive", sem, aligned)
+    return call_intrin("", "tirx.ptx.barrier_cluster_arrive", sem, aligned)
 
 
 def ptx_barrier_cluster_wait(acquire=False, aligned=True):
@@ -1296,7 +1296,7 @@ def ptx_barrier_cluster_wait(acquire=False, aligned=True):
     aligned : bool
         Whether all threads in the warp must execute the same instruction.
     """
-    return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned)
+    return call_intrin("", "tirx.ptx.barrier_cluster_wait", acquire, aligned)
 
 
 def ptx_clc_try_cancel(handle, mbar):
@@ -1314,7 +1314,7 @@ def ptx_clc_try_cancel(handle, mbar):
     mbar : PrimExpr
         Pointer to the mbarrier signalled when the handle lands.
     """
-    return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar)
+    return call_intrin("", "tirx.ptx.clc_try_cancel", handle, mbar)
 
 
 def ptx_clc_query_cancel(handle):
@@ -1328,12 +1328,12 @@ def ptx_clc_query_cancel(handle):
     handle : PrimExpr
         Pointer to the 16B (uint4) smem response handle.
     """
-    return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle)
+    return call_intrin("uint32", "tirx.ptx.clc_query_cancel", handle)
 
 
 def ptx_elect_sync():
     """TVM intrinsic to call elect.sync"""
-    return call_intrin("uint32", "tirx.ptx_elect_sync")
+    return call_intrin("uint32", "tirx.ptx.elect_sync")
 
 
 def ptx_fence_mbarrier_init():
@@ -1346,7 +1346,7 @@ def ptx_fence_mbarrier_init():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_fence_mbarrier_init")
+    return call_intrin("", "tirx.ptx.fence_mbarrier_init")
 
 
 def ptx_fetch_register(bits, reg_name):
@@ -1365,7 +1365,7 @@ def ptx_fetch_register(bits, reg_name):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("int" + str(bits), "tirx.ptx_fetch_register", bits, 
reg_name)
+    return call_intrin("int" + str(bits), "tirx.ptx.fetch_register", bits, 
reg_name)
 
 
 def ptx_mma(
@@ -1460,7 +1460,7 @@ def ptx_mma(
 
     base = [
         "",
-        "tirx.ptx_mma",
+        "tirx.ptx.mma",
         shape,
         a_layout,
         b_layout,
@@ -1552,7 +1552,7 @@ def ptx_mma_legacy(*all_args, operator=None):
     ]
     if operator is not None:
         call_args.append(operator)
-    return call_intrin("", "tirx.ptx_mma_legacy", *call_args)
+    return call_intrin("", "tirx.ptx.mma_legacy", *call_args)
 
 
 def ptx_mma_sp_legacy(*all_args):
@@ -1679,7 +1679,7 @@ def ptx_ldmatrix(trans, num, dtype, smem_ptr, 
*dst_handles):
             f"ldmatrix .x{int(num)}.{dtype_bare} expects {n_regs} destination "
             f"handles, got {len(dst_handles)}"
         )
-    return call_intrin("", "tirx.ptx_ldmatrix", trans, num, dtype, smem_ptr, 
*dst_handles)
+    return call_intrin("", "tirx.ptx.ldmatrix", trans, num, dtype, smem_ptr, 
*dst_handles)
 
 
 _PTX_TO_NUMPY_DTYPE = {
@@ -1773,7 +1773,7 @@ def ptx_ldmatrix_legacy(*all_args):
     # int8+trans manual-loop fallback (ldmatrix can't transpose int8).
     return call_intrin(
         elem_dtype,
-        "tirx.ptx_ldmatrix_legacy",
+        "tirx.ptx.ldmatrix_legacy",
         trans,
         num,
         dtype,
@@ -1826,7 +1826,7 @@ def ptx_stmatrix(trans, num, dtype, smem_ptr, 
*src_handles, shape="m8n8", space=
             f"handles, got {len(src_handles)}"
         )
     return call_intrin(
-        "", "tirx.ptx_stmatrix", trans, num, dtype, shape, space, smem_ptr, 
*src_handles
+        "", "tirx.ptx.stmatrix", trans, num, dtype, shape, space, smem_ptr, 
*src_handles
     )
 
 
@@ -1850,7 +1850,7 @@ def ptx_wgmma_encode_matrix_descriptor(desc, addr, ldo, 
sdo, swizzle):
     swizzle : int
         The swizzle value (CUtensorMapSwizzle_enum).
     """
-    return call_intrin("", "tirx.ptx_wgmma_encode_matrix_descriptor", desc, 
addr, ldo, sdo, swizzle)
+    return call_intrin("", "tirx.ptx.wgmma_encode_matrix_descriptor", desc, 
addr, ldo, sdo, swizzle)
 
 
 def ptx_wgmma_noop_barrier(reg):
@@ -1866,7 +1866,7 @@ def ptx_wgmma_noop_barrier(reg):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_wgmma_noop_barrier", reg)
+    return call_intrin("", "tirx.ptx.wgmma_noop_barrier", reg)
 
 
 def ptx_wgmma_mma_async_ss(
@@ -1917,7 +1917,7 @@ def ptx_wgmma_mma_async_ss(
     """  # noqa: E501
     return call_intrin(
         "",
-        "tirx.ptx_wgmma_mma_async_ss",
+        "tirx.ptx.wgmma_mma_async_ss",
         M,
         N,
         K,
@@ -1980,7 +1980,7 @@ def ptx_wgmma_mma_async_rs(
     """
     return call_intrin(
         "",
-        "tirx.ptx_wgmma_mma_async_rs",
+        "tirx.ptx.wgmma_mma_async_rs",
         M,
         N,
         K,
@@ -2004,7 +2004,7 @@ def ptx_wgmma_fence():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_wgmma_fence")
+    return call_intrin("", "tirx.ptx.wgmma_fence")
 
 
 def ptx_wgmma_commit_group():
@@ -2015,7 +2015,7 @@ def ptx_wgmma_commit_group():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_wgmma_commit_group")
+    return call_intrin("", "tirx.ptx.wgmma_commit_group")
 
 
 def ptx_wgmma_wait_group(n):
@@ -2031,7 +2031,7 @@ def ptx_wgmma_wait_group(n):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_wgmma_wait_group", n)
+    return call_intrin("", "tirx.ptx.wgmma_wait_group", n)
 
 
 def ptx_setmaxnreg(inc: bool, reg_count):
@@ -2045,7 +2045,7 @@ def ptx_setmaxnreg(inc: bool, reg_count):
     reg_count : int
         The register count.
     """
-    return call_intrin("", "tirx.ptx_setmaxnreg", inc, reg_count)
+    return call_intrin("", "tirx.ptx.setmaxnreg", inc, reg_count)
 
 
 def ptx_tcgen05_alloc(dst_ptr, n_cols, cta_group=1):
@@ -2068,7 +2068,7 @@ def ptx_tcgen05_alloc(dst_ptr, n_cols, cta_group=1):
         one warp from each of the peer CTAs perform the allocation.
     """
     _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP)
-    return call_intrin("", "tirx.ptx_tcgen05_alloc", dst_ptr, n_cols, 
cta_group)
+    return call_intrin("", "tirx.ptx.tcgen05_alloc", dst_ptr, n_cols, 
cta_group)
 
 
 def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1):
@@ -2090,7 +2090,7 @@ def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1):
         one warp from each of the peer CTAs perform the deallocation.
     """
     _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP)
-    return call_intrin("", "tirx.ptx_tcgen05_dealloc", taddr, n_cols, 
cta_group)
+    return call_intrin("", "tirx.ptx.tcgen05_dealloc", taddr, n_cols, 
cta_group)
 
 
 def ptx_tcgen05_relinquish_alloc_permit(cta_group=1):
@@ -2106,7 +2106,7 @@ def ptx_tcgen05_relinquish_alloc_permit(cta_group=1):
         one warp from each of the peer CTAs perform the relinquishing.
     """
     _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP)
-    return call_intrin("", "tirx.ptx_tcgen05_relinquish_alloc_permit", 
cta_group)
+    return call_intrin("", "tirx.ptx.tcgen05_relinquish_alloc_permit", 
cta_group)
 
 
 def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle):
@@ -2130,7 +2130,7 @@ def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, 
sdo, swizzle):
         The swizzle value (CUtensorMapSwizzle_enum).
     """
     return call_intrin(
-        "", "tirx.ptx_tcgen05_encode_matrix_descriptor", desc, addr, ldo, sdo, 
swizzle
+        "", "tirx.ptx.tcgen05_encode_matrix_descriptor", desc, addr, ldo, sdo, 
swizzle
     )
 
 
@@ -2202,7 +2202,7 @@ def ptx_tcgen05_encode_instr_descriptor(
     _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP)
     return call_intrin(
         "",
-        "tirx.ptx_tcgen05_encode_instr_descriptor",
+        "tirx.ptx.tcgen05_encode_instr_descriptor",
         desc,
         d_dtype,
         a_dtype,
@@ -2300,7 +2300,7 @@ def ptx_tcgen05_encode_instr_descriptor_block_scaled(
     _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP)
     return call_intrin(
         "",
-        "tirx.ptx_tcgen05_encode_instr_descriptor_block_scaled",
+        "tirx.ptx.tcgen05_encode_instr_descriptor_block_scaled",
         desc,
         d_dtype,
         a_dtype,
@@ -2407,7 +2407,7 @@ def ptx_tcgen05_mma(
     ]
     if pred is not None:
         args.append(pred)
-    return call_intrin("", "tirx.ptx_tcgen05_mma", *args)
+    return call_intrin("", "tirx.ptx.tcgen05_mma", *args)
 
 
 def ptx_tcgen05_mma_block_scale(
@@ -2481,7 +2481,7 @@ def ptx_tcgen05_mma_block_scale(
     _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP)
     return call_intrin(
         "",
-        "tirx.ptx_tcgen05_mma_block_scale",
+        "tirx.ptx.tcgen05_mma_block_scale",
         d_dtype,
         a_dtype,
         b_dtype,
@@ -2569,7 +2569,7 @@ def ptx_tcgen05_mma_sp(
 
     return call_intrin(
         "",
-        "tirx.ptx_tcgen05_mma_sp",
+        "tirx.ptx.tcgen05_mma_sp",
         d_dtype,
         a_dtype,
         b_dtype,
@@ -2660,7 +2660,7 @@ def ptx_tcgen05_mma_sp_block_scale(
     _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP)
     return call_intrin(
         "",
-        "tirx.ptx_tcgen05_mma_sp_block_scale",
+        "tirx.ptx.tcgen05_mma_sp_block_scale",
         d_dtype,
         a_dtype,
         b_dtype,
@@ -2683,14 +2683,14 @@ def ptx_tcgen05_fence_before_thread_sync():
     """TVM intrinsic to call tcgen05.fence::before_thread_sync
     Orders all prior asynchronous tcgen05 operations relative to subsequent 
operations.
     """
-    return call_intrin("", "tirx.ptx_tcgen05_fence_before_thread_sync")
+    return call_intrin("", "tirx.ptx.tcgen05_fence_before_thread_sync")
 
 
 def ptx_tcgen05_fence_after_thread_sync():
     """TVM intrinsic to call tcgen05.fence::after_thread_sync
     Orders all subsequent asynchronous tcgen05 operations relative to previous 
operations.
     """
-    return call_intrin("", "tirx.ptx_tcgen05_fence_after_thread_sync")
+    return call_intrin("", "tirx.ptx.tcgen05_fence_after_thread_sync")
 
 
 def _choice(name: str, value, options):
@@ -2770,7 +2770,7 @@ def ptx_tcgen05_cp(
 
     return call_intrin(
         "",
-        "tirx.ptx_tcgen05_cp",
+        "tirx.ptx.tcgen05_cp",
         taddr,
         src_desc,
         shape,
@@ -2798,7 +2798,7 @@ def ptx_tcgen05_shift(taddr, cta_group=1):
         the peer CTA.
     """
     _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP)
-    return call_intrin("", "tirx.ptx_tcgen05_shift", taddr, cta_group)
+    return call_intrin("", "tirx.ptx.tcgen05_shift", taddr, cta_group)
 
 
 def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False):
@@ -2828,7 +2828,7 @@ def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, 
col=0, pack=False):
         Pack two 16-bit chunks into a single 32-bit register.
     """
     _choice("shape", shape, _TCGEN05_LDST_SHAPES)
-    return call_intrin("", "tirx.ptx_tcgen05_ld", src_addr, row, col, shape, 
num, pack, *regs)
+    return call_intrin("", "tirx.ptx.tcgen05_ld", src_addr, row, col, shape, 
num, pack, *regs)
 
 
 def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False):
@@ -2858,21 +2858,21 @@ def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, 
col=0, unpack=False):
         Unpack a 32-bit register into two 16-bit chunks.
     """
     _choice("shape", shape, _TCGEN05_LDST_SHAPES)
-    return call_intrin("", "tirx.ptx_tcgen05_st", dst_addr, row, col, shape, 
num, unpack, *regs)
+    return call_intrin("", "tirx.ptx.tcgen05_st", dst_addr, row, col, shape, 
num, unpack, *regs)
 
 
 def ptx_tcgen05_wait_ld():
     """TVM intrinsic to call tcgen05.wait::ld.sync.aligned
     Wait for the completion of all prior async tcgen05.ld operations.
     """
-    return call_intrin("", "tirx.ptx_tcgen05_wait_ld")
+    return call_intrin("", "tirx.ptx.tcgen05_wait_ld")
 
 
 def ptx_tcgen05_wait_st():
     """TVM intrinsic to call tcgen05.wait::st.sync.aligned
     Wait for the completion of all prior async tcgen05.st operations.
     """
-    return call_intrin("", "tirx.ptx_tcgen05_wait_st")
+    return call_intrin("", "tirx.ptx.tcgen05_wait_st")
 
 
 def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None):
@@ -2904,7 +2904,7 @@ def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, 
pred=None):
     args = [bar, cta_group, cta_mask]
     if pred is not None:
         args.append(pred)
-    return call_intrin("", "tirx.ptx_tcgen05_commit", *args)
+    return call_intrin("", "tirx.ptx.tcgen05_commit", *args)
 
 
 def timer_init_cuda(profiler_buffer, profiler_tag, profiler_write_offset, 
num_groups, group_id):
@@ -3100,7 +3100,7 @@ def cuda_atomic_add(res_addr, value):
         The call expression.
     """
     value = tir.convert(value)
-    return call_intrin(value.dtype, "tirx.cuda_atomic_add", res_addr, value)
+    return call_intrin(value.dtype, "tirx.cuda.atomic_add", res_addr, value)
 
 
 def cuda_thread_fence():
@@ -3111,7 +3111,7 @@ def cuda_thread_fence():
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_thread_fence")
+    return call_intrin("", "tirx.cuda.thread_fence")
 
 
 def cuda_warpgroup_sync(bar_no):
@@ -3131,7 +3131,7 @@ def cuda_warpgroup_sync(bar_no):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_warpgroup_sync", bar_no)
+    return call_intrin("", "tirx.cuda.warpgroup_sync", bar_no)
 
 
 def cuda_syncthreads_and(cond):
@@ -3147,7 +3147,7 @@ def cuda_syncthreads_and(cond):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("int64", "tirx.cuda_syncthreads_and", cond)
+    return call_intrin("int64", "tirx.cuda.syncthreads_and", cond)
 
 
 def cuda_syncthreads_or(cond):
@@ -3163,7 +3163,7 @@ def cuda_syncthreads_or(cond):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("int64", "tirx.cuda_syncthreads_or", cond)
+    return call_intrin("int64", "tirx.cuda.syncthreads_or", cond)
 
 
 def cuda_nano_sleep(time):
@@ -3179,7 +3179,7 @@ def cuda_nano_sleep(time):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_nano_sleep", time)
+    return call_intrin("", "tirx.cuda.nano_sleep", time)
 
 
 def cuda_printf(fmt, *args):
@@ -3198,7 +3198,7 @@ def cuda_printf(fmt, *args):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.cuda_printf", fmt, *args)
+    return call_intrin("", "tirx.cuda.printf", fmt, *args)
 
 
 def cuda_ldg(addr, dtype):
@@ -3214,7 +3214,7 @@ def cuda_ldg(addr, dtype):
 
     Returns
     """
-    return call_intrin(dtype, "tirx.cuda_ldg", addr, dtype)
+    return call_intrin(dtype, "tirx.cuda.ldg", addr, dtype)
 
 
 def cuda_get_tmem_addr(addr, row_offset, col_offset):
@@ -3236,7 +3236,7 @@ def cuda_get_tmem_addr(addr, row_offset, col_offset):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("uint32", "tirx.cuda_get_tmem_addr", addr, row_offset, 
col_offset)
+    return call_intrin("uint32", "tirx.cuda.get_tmem_addr", addr, row_offset, 
col_offset)
 
 
 def cuda_cvta_generic_to_shared(ptr):
@@ -3246,7 +3246,7 @@ def cuda_cvta_generic_to_shared(ptr):
     precompute the shared-memory address at the wrapper layer instead of
     inside the asm helper body.
     """
-    return call_intrin("uint32", "tirx.cuda_cvta_generic_to_shared", ptr)
+    return call_intrin("uint32", "tirx.cuda.cvta_generic_to_shared", ptr)
 
 
 def cuda_smem_addr_from_uint64(cluster_addr):
@@ -3255,7 +3255,7 @@ def cuda_smem_addr_from_uint64(cluster_addr):
     Wraps ``static_cast<unsigned int>(cluster_addr)``. Used by
     cp.async.bulk.shared::cluster.* op-wrappers.
     """
-    return call_intrin("uint32", "tirx.cuda_smem_addr_from_uint64", 
cluster_addr)
+    return call_intrin("uint32", "tirx.cuda.smem_addr_from_uint64", 
cluster_addr)
 
 
 def cuda_sm100_tma_2sm_mbarrier_addr(bar):
@@ -3276,7 +3276,7 @@ def ptx_exp2(x):
     call : PrimExpr
         The call expression returning 2^x (approximate).
     """
-    return call_intrin("float32", "tirx.ptx_exp2", x)
+    return call_intrin("float32", "tirx.ptx.exp2", x)
 
 
 def ptx_rcp(x):
@@ -3292,7 +3292,7 @@ def ptx_rcp(x):
     call : PrimExpr
         The call expression returning 1/x (approximate).
     """
-    return call_intrin("float32", "tirx.ptx_rcp", x)
+    return call_intrin("float32", "tirx.ptx.rcp", x)
 
 
 def ptx_any_sync(mask, pred):
@@ -3310,7 +3310,7 @@ def ptx_any_sync(mask, pred):
     call : PrimExpr
         The call expression returning 1 if any thread in mask has pred != 0.
     """
-    return call_intrin("int32", "tirx.ptx_any_sync", mask, pred)
+    return call_intrin("int32", "tirx.ptx.any_sync", mask, pred)
 
 
 def ptx_reduce3_max_f32(a, b, c):
@@ -3326,7 +3326,7 @@ def ptx_reduce3_max_f32(a, b, c):
     call : PrimExpr
         The call expression returning max(a, b, c).
     """
-    return call_intrin("float32", "tirx.ptx_reduce3_max_f32", a, b, c)
+    return call_intrin("float32", "tirx.ptx.reduce3_max_f32", a, b, c)
 
 
 def ptx_reduce3_min_f32(a, b, c):
@@ -3342,7 +3342,7 @@ def ptx_reduce3_min_f32(a, b, c):
     call : PrimExpr
         The call expression returning min(a, b, c).
     """
-    return call_intrin("float32", "tirx.ptx_reduce3_min_f32", a, b, c)
+    return call_intrin("float32", "tirx.ptx.reduce3_min_f32", a, b, c)
 
 
 def _ptx_binary_arith(op_name, dtype, d, a, b, *, rounding="rn", ftz=False, 
sat=False):
@@ -3354,7 +3354,7 @@ def _ptx_binary_arith(op_name, dtype, d, a, b, *, 
rounding="rn", ftz=False, sat=
         raise ValueError(f"PTX {op_name}.f32x2 does not accept .sat")
     return call_intrin(
         "",
-        f"tirx.ptx_{op_name}_{dtype}",
+        f"tirx.ptx.{op_name}_{dtype}",
         d,
         a,
         b,
@@ -3373,7 +3373,7 @@ def _ptx_fma(dtype, d, a, b, c, *, rounding="rn", 
ftz=False, sat=False):
         raise ValueError("PTX fma.f32x2 does not accept .sat")
     return call_intrin(
         "",
-        f"tirx.ptx_fma_{dtype}",
+        f"tirx.ptx.fma_{dtype}",
         d,
         a,
         b,
@@ -3466,7 +3466,7 @@ def ptx_max_f32(a, b, *, ftz=False, nan=False):
     nan : bool
         If True, propagate NaN inputs (``.NaN``).
     """
-    return call_intrin("float32", "tirx.ptx_max_f32", a, b, int(ftz), int(nan))
+    return call_intrin("float32", "tirx.ptx.max_f32", a, b, int(ftz), int(nan))
 
 
 def ptx_griddepcontrol_wait():
@@ -3476,7 +3476,7 @@ def ptx_griddepcontrol_wait():
     :func:`ptx_griddepcontrol_launch_dependents` have finished. Acts as a
     full memory barrier.
     """
-    return call_intrin("", "tirx.ptx_griddepcontrol_wait")
+    return call_intrin("", "tirx.ptx.griddepcontrol_wait")
 
 
 def ptx_griddepcontrol_launch_dependents():
@@ -3485,7 +3485,7 @@ def ptx_griddepcontrol_launch_dependents():
     Signals that the current grid has reached a point where dependent
     grids may begin execution.
     """
-    return call_intrin("", "tirx.ptx_griddepcontrol_launch_dependents")
+    return call_intrin("", "tirx.ptx.griddepcontrol_launch_dependents")
 
 
 _PTX_LD_SCOPE = {"cta", "cluster", "gpu", "sys"}
@@ -3565,7 +3565,7 @@ def ptx_ld_acquire(addr, return_type, ptx_type, *, 
scope="gpu", space="global"):
     _choice("space", space, _PTX_LD_SPACE)
     _choice("ptx_type", ptx_type, _PTX_LD_TYPE)
     return call_intrin(
-        return_type, "tirx.ptx_ld_acquire", addr, return_type, ptx_type, 
scope, space
+        return_type, "tirx.ptx.ld_acquire", addr, return_type, ptx_type, 
scope, space
     )
 
 
@@ -3591,7 +3591,7 @@ def ptx_ld(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         return_type,
-        "tirx.ptx_ld",
+        "tirx.ptx.ld",
         addr,
         cache_policy,
         return_type,
@@ -3610,7 +3610,7 @@ def ptx_ld_volatile(addr, return_type, ptx_type, *, 
space="global"):
     """
     _choice("space", space, _PTX_LD_VOLATILE_SPACE)
     _choice("ptx_type", ptx_type, _PTX_LD_TYPE)
-    return call_intrin(return_type, "tirx.ptx_ld_volatile", addr, return_type, 
ptx_type, space)
+    return call_intrin(return_type, "tirx.ptx.ld_volatile", addr, return_type, 
ptx_type, space)
 
 
 def ptx_ld_global_acquire(res, addr):
@@ -3629,7 +3629,7 @@ def ptx_ld_global_acquire(res, addr):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("", "tirx.ptx_ld_global_acquire", res, addr)
+    return call_intrin("", "tirx.ptx.ld_global_acquire", res, addr)
 
 
 def ptx_red_scalar(
@@ -3655,7 +3655,7 @@ def ptx_red_scalar(
         raise ValueError(f"Unsupported PTX red sem {sem!r}")
     return call_intrin(
         "",
-        "tirx.ptx_red_scalar",
+        "tirx.ptx.red_scalar",
         address,
         value,
         cache_policy,
@@ -3689,7 +3689,7 @@ def ptx_atom_scalar(
         raise ValueError(f"Unsupported PTX atom sem {sem!r}")
     return call_intrin(
         _PTX_SCALAR_RETURN_TYPE[ptx_type],
-        "tirx.ptx_atom_scalar",
+        "tirx.ptx.atom_scalar",
         address,
         value,
         cache_policy,
@@ -3720,7 +3720,7 @@ def ptx_st(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_st",
+        "tirx.ptx.st",
         address,
         *values,
         cache_policy,
@@ -3736,12 +3736,12 @@ def ptx_st(
 def ptx_st_bulk(ptr, num_bytes, *, weak=False, space="shared::cta"):
     if space not in ("", "shared::cta"):
         raise ValueError(f"Unsupported PTX st.bulk space {space!r}")
-    return call_intrin("", "tirx.ptx_st_bulk", ptr, num_bytes, 
int(bool(weak)), space)
+    return call_intrin("", "tirx.ptx.st_bulk", ptr, num_bytes, 
int(bool(weak)), space)
 
 
 def ptx_prefetch_tensormap(tensormap_addr, space=""):
     _choice("space", space, _PTX_PREFETCH_TENSORMAP_SPACE)
-    return call_intrin("", "tirx.ptx_prefetch_tensormap", tensormap_addr, 
space)
+    return call_intrin("", "tirx.ptx.prefetch_tensormap", tensormap_addr, 
space)
 
 
 def ptx_mbarrier_test_wait_parity(barrier, phase, *, sem="", scope="", 
space="shared::cta"):
@@ -3754,7 +3754,7 @@ def ptx_mbarrier_test_wait_parity(barrier, phase, *, 
sem="", scope="", space="sh
     if space not in ("shared", "shared::cta"):
         raise ValueError(f"Unsupported mbarrier.test_wait.parity space 
{space!r}")
     return call_intrin(
-        "uint32", "tirx.ptx_mbarrier_test_wait_parity", barrier, phase, sem, 
scope, space
+        "uint32", "tirx.ptx.mbarrier_test_wait_parity", barrier, phase, sem, 
scope, space
     )
 
 
@@ -3773,7 +3773,7 @@ def ptx_cp_async_bulk_g2s_cta(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_g2s_cta",
+        "tirx.ptx.cp_async_bulk_g2s_cta",
         dst_ptr,
         src_ptr,
         num_bytes,
@@ -3800,7 +3800,7 @@ def ptx_cp_async_bulk_g2s_cluster(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_g2s_cluster",
+        "tirx.ptx.cp_async_bulk_g2s_cluster",
         dst_ptr,
         src_ptr,
         num_bytes,
@@ -3814,7 +3814,7 @@ def ptx_cp_async_bulk_g2s_cluster(
 
 def ptx_cp_async_bulk_s2s_cluster(dst_ptr, src_ptr, num_bytes, mbarrier):
     return call_intrin(
-        "", "tirx.ptx_cp_async_bulk_s2s_cluster", dst_ptr, src_ptr, num_bytes, 
mbarrier
+        "", "tirx.ptx.cp_async_bulk_s2s_cluster", dst_ptr, src_ptr, num_bytes, 
mbarrier
     )
 
 
@@ -3824,7 +3824,7 @@ def ptx_cp_async_bulk_s2g(
     cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, 
cache_policy)
     return call_intrin(
         "",
-        "tirx.ptx_cp_async_bulk_s2g",
+        "tirx.ptx.cp_async_bulk_s2g",
         dst_ptr,
         src_ptr,
         num_bytes,
@@ -3836,83 +3836,83 @@ def ptx_cp_async_bulk_s2g(
 
 
 def ptx_fns_b32(mask, base, offset):
-    return call_intrin("uint32", "tirx.ptx_fns_b32", mask, base, offset)
+    return call_intrin("uint32", "tirx.ptx.fns_b32", mask, base, offset)
 
 
 def ptx_add_rn_f32_bf16(acc, x):
-    return call_intrin("float32", "tirx.ptx_add_rn_f32_bf16", acc, x)
+    return call_intrin("float32", "tirx.ptx.add_rn_f32_bf16", acc, x)
 
 
 def cuda_uint_as_float(bits):
-    return call_intrin("float32", "tirx.cuda_uint_as_float", bits)
+    return call_intrin("float32", "tirx.cuda.uint_as_float", bits)
 
 
 def cuda_float_as_uint(x):
-    return call_intrin("uint32", "tirx.cuda_float_as_uint", x)
+    return call_intrin("uint32", "tirx.cuda.float_as_uint", x)
 
 
 def cuda_ballot_sync(mask, pred):
-    return call_intrin("uint32", "tirx.cuda_ballot_sync", mask, pred)
+    return call_intrin("uint32", "tirx.cuda.ballot_sync", mask, pred)
 
 
 def cuda_ffs_u32(value):
-    return call_intrin("int32", "tirx.cuda_ffs_u32", value)
+    return call_intrin("int32", "tirx.cuda.ffs_u32", value)
 
 
 def cuda_reduce_add_sync_u32(mask, value):
-    return call_intrin("uint32", "tirx.cuda_reduce_add_sync_u32", mask, value)
+    return call_intrin("uint32", "tirx.cuda.reduce_add_sync_u32", mask, value)
 
 
 def cuda_reduce_min_sync_u32(mask, value):
-    return call_intrin("uint32", "tirx.cuda_reduce_min_sync_u32", mask, value)
+    return call_intrin("uint32", "tirx.cuda.reduce_min_sync_u32", mask, value)
 
 
 def cuda_clock64():
-    return call_intrin("uint64", "tirx.cuda_clock64")
+    return call_intrin("uint64", "tirx.cuda.clock64")
 
 
 def cuda_make_float2(x, y):
-    return call_intrin("uint64", "tirx.cuda_make_float2", x, y)
+    return call_intrin("uint64", "tirx.cuda.make_float2", x, y)
 
 
 def cuda_float2_x(packed):
-    return call_intrin("float32", "tirx.cuda_float2_x", packed)
+    return call_intrin("float32", "tirx.cuda.float2_x", packed)
 
 
 def cuda_float2_y(packed):
-    return call_intrin("float32", "tirx.cuda_float2_y", packed)
+    return call_intrin("float32", "tirx.cuda.float2_y", packed)
 
 
 def cuda_fmul2_rn(a, b):
-    return call_intrin("uint64", "tirx.cuda_fmul2_rn", a, b)
+    return call_intrin("uint64", "tirx.cuda.fmul2_rn", a, b)
 
 
 def cuda_fadd2_rn(a, b):
-    return call_intrin("uint64", "tirx.cuda_fadd2_rn", a, b)
+    return call_intrin("uint64", "tirx.cuda.fadd2_rn", a, b)
 
 
 def cuda_float22bfloat162_rn(v0, v1):
-    return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn", v0, v1)
+    return call_intrin("uint32", "tirx.cuda.float22bfloat162_rn", v0, v1)
 
 
 def cuda_float22bfloat162_rn_from_float2(packed):
-    return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn_from_float2", 
packed)
+    return call_intrin("uint32", "tirx.cuda.float22bfloat162_rn_from_float2", 
packed)
 
 
 def cuda_bfloat1622float2(packed):
-    return call_intrin("uint64", "tirx.cuda_bfloat1622float2", packed)
+    return call_intrin("uint64", "tirx.cuda.bfloat1622float2", packed)
 
 
 def cuda_hmin2(a, b):
-    return call_intrin("uint32", "tirx.cuda_hmin2", a, b)
+    return call_intrin("uint32", "tirx.cuda.hmin2", a, b)
 
 
 def cuda_hmax2(a, b):
-    return call_intrin("uint32", "tirx.cuda_hmax2", a, b)
+    return call_intrin("uint32", "tirx.cuda.hmax2", a, b)
 
 
 def cuda_fp8x4_e4m3_from_float4(x, y, z, w):
-    return call_intrin("uint32", "tirx.cuda_fp8x4_e4m3_from_float4", x, y, z, 
w)
+    return call_intrin("uint32", "tirx.cuda.fp8x4_e4m3_from_float4", x, y, z, 
w)
 
 
 def ptx_map_shared_rank(ptr, rank):
@@ -3941,7 +3941,7 @@ def ptx_mapa(ptr, rank, *, space="", ptx_type="u64", 
return_type="uint64"):
         raise ValueError(f"Unsupported mapa space {space!r}")
     if ptx_type not in ("u32", "u64"):
         raise ValueError(f"Unsupported mapa type {ptx_type!r}")
-    return call_intrin(return_type, "tirx.ptx_mapa", ptr, rank, space, 
ptx_type, return_type)
+    return call_intrin(return_type, "tirx.ptx.mapa", ptr, rank, space, 
ptx_type, return_type)
 
 
 def cuda_atomic_cas(ptr, old_val, new_val):
@@ -3964,7 +3964,7 @@ def cuda_atomic_cas(ptr, old_val, new_val):
         The call expression.
     """
     old_val = tir.convert(old_val)
-    return call_intrin(old_val.dtype, "tirx.cuda_atomic_cas", ptr, old_val, 
new_val)
+    return call_intrin(old_val.dtype, "tirx.cuda.atomic_cas", ptr, old_val, 
new_val)
 
 
 ########################################################
@@ -3981,7 +3981,7 @@ def nvshmem_my_pe():
         The call expression.
     """
 
-    return call_intrin("int32", "tirx.nvshmem_my_pe")
+    return call_intrin("int32", "tirx.nvshmem.my_pe")
 
 
 def nvshmem_n_pes():
@@ -3993,7 +3993,7 @@ def nvshmem_n_pes():
         The call expression.
     """
 
-    return call_intrin("int32", "tirx.nvshmem_n_pes")
+    return call_intrin("int32", "tirx.nvshmem.n_pes")
 
 
 def nvshmem_getmem_nbi(dst, src, nelems, pe):
@@ -4019,7 +4019,7 @@ def nvshmem_getmem_nbi(dst, src, nelems, pe):
         The call expression.
     """  # noqa: E501
 
-    return call_intrin("", "tirx.nvshmem_getmem_nbi", dst, src, nelems, pe)
+    return call_intrin("", "tirx.nvshmem.getmem_nbi", dst, src, nelems, pe)
 
 
 def nvshmem_putmem_nbi(dst, src, nelems, pe):
@@ -4045,7 +4045,7 @@ def nvshmem_putmem_nbi(dst, src, nelems, pe):
         The call expression.
     """
 
-    return call_intrin("", "tirx.nvshmem_putmem_nbi", dst, src, nelems, pe)
+    return call_intrin("", "tirx.nvshmem.putmem_nbi", dst, src, nelems, pe)
 
 
 def nvshmem_getmem_nbi_warp(dst, src, nelems, pe):
@@ -4071,7 +4071,7 @@ def nvshmem_getmem_nbi_warp(dst, src, nelems, pe):
         The call expression.
     """  # noqa: E501
 
-    return call_intrin("", "tirx.nvshmem_getmem_nbi_warp", dst, src, nelems, 
pe)
+    return call_intrin("", "tirx.nvshmem.getmem_nbi_warp", dst, src, nelems, 
pe)
 
 
 def nvshmem_putmem_nbi_warp(dst, src, nelems, pe):
@@ -4097,7 +4097,7 @@ def nvshmem_putmem_nbi_warp(dst, src, nelems, pe):
         The call expression.
     """
 
-    return call_intrin("", "tirx.nvshmem_putmem_nbi_warp", dst, src, nelems, 
pe)
+    return call_intrin("", "tirx.nvshmem.putmem_nbi_warp", dst, src, nelems, 
pe)
 
 
 def nvshmem_getmem_nbi_block(dst, src, nelems, pe):
@@ -4123,7 +4123,7 @@ def nvshmem_getmem_nbi_block(dst, src, nelems, pe):
         The call expression.
     """  # noqa: E501
 
-    return call_intrin("", "tirx.nvshmem_getmem_nbi_block", dst, src, nelems, 
pe)
+    return call_intrin("", "tirx.nvshmem.getmem_nbi_block", dst, src, nelems, 
pe)
 
 
 def nvshmem_putmem_nbi_block(dst, src, nelems, pe):
@@ -4149,7 +4149,7 @@ def nvshmem_putmem_nbi_block(dst, src, nelems, pe):
         The call expression.
     """
 
-    return call_intrin("", "tirx.nvshmem_putmem_nbi_block", dst, src, nelems, 
pe)
+    return call_intrin("", "tirx.nvshmem.putmem_nbi_block", dst, src, nelems, 
pe)
 
 
 def nvshmem_signal_op(sig_addr, signal, sig_op, pe):
@@ -4176,7 +4176,7 @@ def nvshmem_signal_op(sig_addr, signal, sig_op, pe):
     """
 
     _choice("sig_op", sig_op, _NVSHMEM_SIG_OP)
-    return call_intrin("", "tirx.nvshmem_signal_op", sig_addr, signal, sig_op, 
pe)
+    return call_intrin("", "tirx.nvshmem.signal_op", sig_addr, signal, sig_op, 
pe)
 
 
 def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"):
@@ -4203,7 +4203,7 @@ def nvshmem_wait_until(ivar, cmp, cmp_value, 
type="uint64_t"):
     """
 
     _choice("cmp", cmp, _NVSHMEM_CMP)
-    return call_intrin("", "tirx.nvshmem_wait_until", ivar, cmp, cmp_value, 
type)
+    return call_intrin("", "tirx.nvshmem.wait_until", ivar, cmp, cmp_value, 
type)
 
 
 def nvshmem_quiet():
@@ -4215,7 +4215,7 @@ def nvshmem_quiet():
         The call expression.
     """
 
-    return call_intrin("", "tirx.nvshmem_quiet")
+    return call_intrin("", "tirx.nvshmem.quiet")
 
 
 def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe):
@@ -4251,7 +4251,7 @@ def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, 
signal, sig_op, pe):
     """  # noqa: E501
 
     return call_intrin(
-        "", "tirx.nvshmem_putmem_signal_nbi", dst, src, nelems, sig_addr, 
signal, sig_op, pe
+        "", "tirx.nvshmem.putmem_signal_nbi", dst, src, nelems, sig_addr, 
signal, sig_op, pe
     )
 
 
@@ -4288,7 +4288,7 @@ def nvshmem_putmem_signal_nbi_warp(dst, src, nelems, 
sig_addr, signal, sig_op, p
     """  # noqa: E501
 
     return call_intrin(
-        "", "tirx.nvshmem_putmem_signal_nbi_warp", dst, src, nelems, sig_addr, 
signal, sig_op, pe
+        "", "tirx.nvshmem.putmem_signal_nbi_warp", dst, src, nelems, sig_addr, 
signal, sig_op, pe
     )
 
 
@@ -4325,7 +4325,7 @@ def nvshmem_putmem_signal_nbi_block(dst, src, nelems, 
sig_addr, signal, sig_op,
     """  # noqa: E501
 
     return call_intrin(
-        "", "tirx.nvshmem_putmem_signal_nbi_block", dst, src, nelems, 
sig_addr, signal, sig_op, pe
+        "", "tirx.nvshmem.putmem_signal_nbi_block", dst, src, nelems, 
sig_addr, signal, sig_op, pe
     )
 
 
@@ -4338,7 +4338,7 @@ def nvshmem_fence():
         The call expression.
     """
 
-    return call_intrin("", "tirx.nvshmem_fence")
+    return call_intrin("", "tirx.nvshmem.fence")
 
 
 def nvshmem_barrier_all():
@@ -4350,4 +4350,4 @@ def nvshmem_barrier_all():
         The call expression.
     """
 
-    return call_intrin("", "tirx.nvshmem_barrier_all")
+    return call_intrin("", "tirx.nvshmem.barrier_all")
diff --git a/python/tvm/backend/cuda/operator/intrinsics/cp_async.py 
b/python/tvm/backend/cuda/operator/intrinsics/cp_async.py
index 3e6bc015e8..a63c3e784f 100644
--- a/python/tvm/backend/cuda/operator/intrinsics/cp_async.py
+++ b/python/tvm/backend/cuda/operator/intrinsics/cp_async.py
@@ -383,6 +383,9 @@ def codegen_ptx_cp_async(*args):
     return result[0] if isinstance(result, tuple) else result
 
 
+CODEGEN_REGISTRY["tirx.ptx.cp_async_raw"] = 
CODEGEN_REGISTRY["tirx.ptx.cp_async"]
+
+
 # =============================================================================
 # cp.async.bulk.tensor (TMA) — one device_intrinsic per arity variant of each
 # PTX form. Per-dim coord operands materialise via the ``c_signature`` 
callable.
diff --git 
a/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py 
b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py
index 0afa042b5e..55beaaa69a 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py
@@ -813,7 +813,7 @@ def gemm_async_tcgen05_impl(op_call: TilePrimitiveCall, 
sctx: DispatchContext) -
         """Build: { AllocBuffer(desc); encode(desc, smem); krp }"""
         encode_call = tvm.tirx.call_intrin(
             "",
-            "tirx.ptx_tcgen05_encode_matrix_descriptor",
+            "tirx.ptx.tcgen05_encode_matrix_descriptor",
             tvm.tirx.address_of(desc_buf[0]),
             smem_buf.ptr_to(base),
             ldo,
diff --git a/python/tvm/backend/cuda/script.py 
b/python/tvm/backend/cuda/script.py
index 76ba87344b..effea0c885 100644
--- a/python/tvm/backend/cuda/script.py
+++ b/python/tvm/backend/cuda/script.py
@@ -136,7 +136,7 @@ class CpAsyncNamespace:
     def __call__(self, *args, **kwds):
         # Accept the legacy 6-arg form ``(elem_dtype, dst, dst_off, src,
         # src_off, cp_size)`` that the printer round-trips for the raw
-        # ``tirx.ptx_cp_async`` Call emitted by
+        # ``tirx.ptx.cp_async`` Call emitted by
         # ``tvm.backend.cuda.transform.InjectPTXAsyncCopy``. The pass-emitted
         # Call has 5 args (no ``tvm_access_ptr`` fold) and a
         # per-element-dtype Call.dtype, so build it directly.
@@ -146,7 +146,7 @@ class CpAsyncNamespace:
             elem_dtype, dst, dst_off, src, src_off, cp_size = args
             return tvm.tirx.Call(
                 tvm.DataType(elem_dtype),
-                tvm.ir.Op.get("tirx.ptx_cp_async"),
+                tvm.ir.Op.get("tirx.ptx.cp_async_raw"),
                 [dst, dst_off, src, src_off, cp_size],
             )
         return _dtype_forward(_cuda_op.ptx_cp_async)(*args, **kwds)
@@ -201,7 +201,7 @@ class CpAsyncBulkTensorNamespace:
         cache_policy, has_cache_policy = 
_cuda_op._resolve_cache_policy(cache_hint, cache_policy)
         return _tir_op.call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_global_to_cluster",
+            "tirx.ptx.cp_async_bulk_tensor_global_to_cluster",
             dim,
             dst_ptr,
             bar_addr,
@@ -230,7 +230,7 @@ class CpAsyncBulkTensorNamespace:
         cache_policy, has_cache_policy = 
_cuda_op._resolve_cache_policy(cache_hint, cache_policy)
         return _tir_op.call_intrin(
             "",
-            "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster",
+            "tirx.ptx.cp_async_bulk_tensor_tile_gather4_global_to_cluster",
             dim,
             dst_ptr,
             bar_addr,
diff --git a/python/tvm/backend/trn/op.py b/python/tvm/backend/trn/op.py
index d919e4fb85..88cdd048fa 100644
--- a/python/tvm/backend/trn/op.py
+++ b/python/tvm/backend/trn/op.py
@@ -22,49 +22,49 @@ from tvm.tirx.op import call_intrin
 
 
 def nki_load(res, data):
-    return call_intrin("", "tirx.nki_load", res, data)
+    return call_intrin("", "tirx.nki.load", res, data)
 
 
 def nki_store(res, data):
-    return call_intrin("", "tirx.nki_store", res, data)
+    return call_intrin("", "tirx.nki.store", res, data)
 
 
 def nki_tensor_copy(res, data):
-    return call_intrin("", "tirx.nki_tensor_copy", res, data)
+    return call_intrin("", "tirx.nki.tensor_copy", res, data)
 
 
 def nki_matmul(res, lhs, rhs, accum=True):
-    return call_intrin("", "tirx.nki_matmul", res, lhs, rhs, accum)
+    return call_intrin("", "tirx.nki.matmul", res, lhs, rhs, accum)
 
 
 def nki_activation(result, data, opcode, bias=0.0, scale=1.0):
-    return call_intrin("", "tirx.nki_activation", result, data, opcode, bias, 
scale)
+    return call_intrin("", "tirx.nki.activation", result, data, opcode, bias, 
scale)
 
 
 def nki_reciprocal(result, data):
-    return call_intrin("", "tirx.nki_reciprocal", result, data)
+    return call_intrin("", "tirx.nki.reciprocal", result, data)
 
 
 def nki_tensorreduce(result, data, opcode, negate, *axes):
-    return call_intrin("", "tirx.nki_tensorreduce", result, data, opcode, 
negate, *axes)
+    return call_intrin("", "tirx.nki.tensorreduce", result, data, opcode, 
negate, *axes)
 
 
 def nki_tensortensor(result, operand0, operand1, opcode):
-    return call_intrin("", "tirx.nki_tensortensor", result, operand0, 
operand1, opcode)
+    return call_intrin("", "tirx.nki.tensortensor", result, operand0, 
operand1, opcode)
 
 
 def nki_tensorscalar(result, operand0, operand1, opcode, reverse=False):
-    return call_intrin("", "tirx.nki_tensorscalar", result, operand0, 
operand1, opcode, reverse)
+    return call_intrin("", "tirx.nki.tensorscalar", result, operand0, 
operand1, opcode, reverse)
 
 
 def nki_memset(result, value):
-    return call_intrin("", "tirx.nki_memset", result, value)
+    return call_intrin("", "tirx.nki.memset", result, value)
 
 
 def nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, 
bias=0.0, scale=1.0):
     return call_intrin(
         "",
-        "tirx.nki_activation_reduce",
+        "tirx.nki.activation_reduce",
         reduce_res,
         act_res,
         data,
@@ -80,7 +80,7 @@ def nki_tensorscalar_reduce(
 ):
     return call_intrin(
         "",
-        "tirx.nki_tensorscalar_reduce",
+        "tirx.nki.tensorscalar_reduce",
         reduce_res,
         tensorscalar_res,
         operand0,
@@ -92,7 +92,7 @@ def nki_tensorscalar_reduce(
 
 
 def nki_identity(result, size):
-    return call_intrin("", "tirx.nki_identity", result, size)
+    return call_intrin("", "tirx.nki.identity", result, size)
 
 
 def nki_scalar_tensor_tensor(
@@ -100,7 +100,7 @@ def nki_scalar_tensor_tensor(
 ):
     return call_intrin(
         "",
-        "tirx.nki_scalar_tensor_tensor",
+        "tirx.nki.scalar_tensor_tensor",
         result,
         data,
         operand0,
@@ -117,7 +117,7 @@ def nki_scalar_tensor_scalar(
 ):
     return call_intrin(
         "",
-        "tirx.nki_scalar_tensor_scalar",
+        "tirx.nki.scalar_tensor_scalar",
         result,
         data,
         operand0,
@@ -130,7 +130,7 @@ def nki_scalar_tensor_scalar(
 
 
 def nki_affine_select(result, pred, true_value, false_value):
-    return call_intrin("", "tirx.nki_affine_select", result, pred, true_value, 
false_value)
+    return call_intrin("", "tirx.nki.affine_select", result, pred, true_value, 
false_value)
 
 
 __all__ = [
diff --git a/src/backend/cuda/codegen/codegen_cuda.cc 
b/src/backend/cuda/codegen/codegen_cuda.cc
index 357f2c9585..034620bf1c 100644
--- a/src/backend/cuda/codegen/codegen_cuda.cc
+++ b/src/backend/cuda/codegen/codegen_cuda.cc
@@ -958,18 +958,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
   static const Op& tvm_store_matrix_sync_op = 
Op::Get("tirx.tvm_store_matrix_sync");
   static const Op& tvm_mma_sync_op = Op::Get("tirx.tvm_mma_sync");
   static const Op& tvm_bmma_sync_op = Op::Get("tirx.tvm_bmma_sync");
-  static const Op& ptx_mma_op = Op::Get("tirx.ptx_mma");
-  static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx_mma_sp");
+  static const Op& ptx_mma_op = Op::Get("tirx.ptx.mma");
+  static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx.mma_sp");
   static const Op& mma_store_op = Op::Get("tirx.mma_store");
   static const Op& mma_fill_op = Op::Get("tirx.mma_fill");
-  static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx_mma_legacy");
-  static const Op& ptx_ldmatrix_legacy_op = 
Op::Get("tirx.ptx_ldmatrix_legacy");
+  static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx.mma_legacy");
+  static const Op& ptx_ldmatrix_legacy_op = 
Op::Get("tirx.ptx.ldmatrix_legacy");
   static const Op& mma_store_legacy_op = Op::Get("tirx.mma_store_legacy");
   static const Op& mma_fill_legacy_op = Op::Get("tirx.mma_fill_legacy");
-  static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx_cp_async_bulk");
-  static const Op& ptx_cp_async_mbarrier_arrive_op = 
Op::Get("tirx.ptx_cp_async_mbarrier_arrive");
+  static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx.cp_async_bulk");
+  static const Op& ptx_cp_async_mbarrier_arrive_op = 
Op::Get("tirx.ptx.cp_async_mbarrier_arrive");
   static const Op& ptx_ldg32_op = Op::Get("tirx.ptx.ldg32");
-  static const Op& cuda_func_call_op = Op::Get("tirx.cuda_func_call");
+  static const Op& cuda_func_call_op = Op::Get("tirx.cuda.func_call");
 
   if (op->op.same_as(tvm_fill_fragment_op)) {
     codegen_tags_.insert("mma");
@@ -1571,7 +1571,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
     TVM_FFI_ICHECK(queue_id && queue_id->value == 0)
         << "For CUDA, the index of an async queue must be 0.";
     this->VisitStmt(op->body);
-    static const Op& ptx_cp_async_commit_group_op = 
Op::Get("tirx.ptx_cp_async_commit_group");
+    static const Op& ptx_cp_async_commit_group_op = 
Op::Get("tirx.ptx.cp_async_commit_group");
     auto commit_group = Call(DataType::Void(), ptx_cp_async_commit_group_op, 
{});
     this->PrintIndent();
     this->VisitExpr(commit_group, this->stream);
@@ -1583,7 +1583,7 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
     TVM_FFI_ICHECK(queue_id && queue_id->value == 0)
         << "For CUDA, the index of an async queue must be 0.";
     auto wait_cnt = wait_attrs.second;
-    static const Op& ptx_cp_async_wait_group_op = 
Op::Get("tirx.ptx_cp_async_wait_group");
+    static const Op& ptx_cp_async_wait_group_op = 
Op::Get("tirx.ptx.cp_async_wait_group");
     auto wait_group = Call(DataType::Void(), ptx_cp_async_wait_group_op, 
{wait_cnt});
     this->PrintIndent();
     this->VisitExpr(wait_group, this->stream);
diff --git a/src/backend/cuda/op/target_builtin.cc 
b/src/backend/cuda/op/target_builtin.cc
index 353c04b501..5c5ad0b12d 100644
--- a/src/backend/cuda/op/target_builtin.cc
+++ b/src/backend/cuda/op/target_builtin.cc
@@ -66,26 +66,11 @@ TIRX_DEFINE_BUILTIN_FUNC(tvm_fill_fragment)
 TIRX_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
 
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mma)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
 // Siblings of ptx_mma / ptx_ldmatrix / mma_store / mma_fill that accept
 // (ptr_var, offset) pairs. Codegen emits `ptr + offset` C-pointer
 // arithmetic and lower_warp_memory rewrites the offset's group component
 // to its thread-local index. Used by the s_tir tensor_intrin tensorize
 // path so per-thread fragment offsets stay element-accurate.
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mma_legacy)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_ldmatrix_legacy)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
 TIRX_DEFINE_BUILTIN_FUNC(mma_store_legacy)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
 
@@ -100,100 +85,6 @@ OpRegEntry::RegisterOrGet("tirx.ptx.ldg32")
     .set_attr<TIRxOpCategory>("TIRxOpCategory", ffi::String("device_intrin"), 
10)
     .set_attr<TDeviceIntrinsicNamespace>("TDeviceIntrinsicNamespace", 
ffi::String("ptx"), 10);
 
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_shared_to_cluster)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
-    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
-                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_commit_group)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_wait_group)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_mbarrier_arrive)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_fence).set_attr<TCallEffectKind>(
-    "TCallEffectKind", static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_fence_proxy_async)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_init)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive_expect_tx)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait_acquire_cluster)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_arrive)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_sync)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_global_to_cluster)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_shared_to_global)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_global_to_cluster_prefetch)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_tensor_shared_to_global_reduce)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_commit_group)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_cp_async_bulk_wait_group)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_barrier_cluster_arrive)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_barrier_cluster_wait)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_elect_sync)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_fence_mbarrier_init)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
 OpRegEntry::RegisterOrGet("tirx.ptx.fetch_register")
     .set_name()
     .set_num_inputs(-1)
@@ -202,21 +93,17 @@ OpRegEntry::RegisterOrGet("tirx.ptx.fetch_register")
     .set_attr<TDeviceIntrinsicNamespace>("TDeviceIntrinsicNamespace", 
ffi::String("ptx"))
     .set_attr<TScriptPrinterName>("TScriptPrinterName", 
ffi::String("ptx.fetch_register"));
 
-OpRegEntry::RegisterOrGet("tirx.ptx_fetch_register")
+// Raw legacy cp.async form emitted by InjectPTXAsyncCopy (and round-tripped by
+// the T.ptx.cp_async 6-arg surface). It carries the element dtype in Call.dtype
+// and prints it dtype-first; the fork-native tirx.ptx.cp_async form does not.
+OpRegEntry::RegisterOrGet("tirx.ptx.cp_async_raw")
     .set_name()
-    .set_num_inputs(-1)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kPure))
+    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
     .set_attr<TIRxOpCategory>("TIRxOpCategory", ffi::String("device_intrin"))
     .set_attr<TDeviceIntrinsicNamespace>("TDeviceIntrinsicNamespace", 
ffi::String("ptx"))
-    .set_attr<TScriptPrinterName>("TScriptPrinterName", 
ffi::String("ptx.fetch_register"));
-
-// griddepcontrol — programmatic dependent launch synchronization (sm_90+).
-// Both are memory barriers; mark kOpaque to prevent CSE/reordering.
-TIRX_DEFINE_BUILTIN_FUNC(ptx_griddepcontrol_wait)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_griddepcontrol_launch_dependents)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
+    .set_attr<TScriptPrinterName>("TScriptPrinterName", 
ffi::String("ptx.cp_async"))
+    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
+                                         
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
 
 TIRX_DEFINE_BUILTIN_FUNC(mma_store)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque))
@@ -228,99 +115,6 @@ TIRX_DEFINE_BUILTIN_FUNC(mma_fill)
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                          
static_cast<int64_t>(ScriptDtypePrintLocation::kFirst));
 
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_encode_matrix_descriptor)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_noop_barrier)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_mma_async_ss)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_mma_async_rs)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_fence)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_commit_group)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_wgmma_wait_group)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_stmatrix)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_setmaxnreg)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_ld_global_acquire)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_alloc)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_dealloc)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_relinquish_alloc_permit)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_fence_before_thread_sync)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_fence_after_thread_sync)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_ld)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_st)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_wait_ld)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_wait_st)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_encode_matrix_descriptor)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_encode_instr_descriptor)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_encode_instr_descriptor_block_scaled)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma_block_scale)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma_sp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_mma_sp_block_scale)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_commit)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_cp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_tcgen05_shift)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(ptx_map_shared_rank)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(cuda_func_call)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
 TIRX_DEFINE_BUILTIN_FUNC(timer_init_cuda)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
 
@@ -333,54 +127,6 @@ TIRX_DEFINE_BUILTIN_FUNC(timer_end_cuda)
 TIRX_DEFINE_BUILTIN_FUNC(timer_finalize_cuda)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
 
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_my_pe)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_n_pes)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_getmem_nbi)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_nbi)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_getmem_nbi_warp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_nbi_warp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_getmem_nbi_block)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_nbi_block)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_signal_op)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_wait_until)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_quiet)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_signal_nbi)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_signal_nbi_warp)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_putmem_signal_nbi_block)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_fence)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nvshmem_barrier_all)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
 RegisterDeviceIntrinsicAliases();
   // clang-format on
 }
@@ -388,21 +134,20 @@ RegisterDeviceIntrinsicAliases();
 namespace {
 
 struct DeviceIntrinsicRegistration {
-  const char* flat_name;
+  const char* name;
   const char* namespace_name;
   CallEffectKind effect_kind;
 };
 
 void RegisterDeviceIntrinsic(const DeviceIntrinsicRegistration& reg) {
-  std::string flat_name(reg.flat_name);
+  std::string name(reg.name);
   std::string namespace_name(reg.namespace_name);
   std::string prefix = namespace_name + "_";
-  std::string suffix = flat_name;
+  std::string suffix = name;
   if (suffix.rfind(prefix, 0) == 0) {
     suffix = suffix.substr(prefix.size());
   }
 
-  std::string flat_op_name = "tirx." + flat_name;
   std::string canonical_op_name = "tirx." + namespace_name + "." + suffix;
   ffi::String namespace_attr(namespace_name);
   ffi::String printer_name(namespace_name + "." + suffix);
@@ -419,7 +164,6 @@ void RegisterDeviceIntrinsic(const 
DeviceIntrinsicRegistration& reg) {
         .set_attr<TScriptPrinterName>("TScriptPrinterName", printer_name, 
/*plevel=*/15);
   };
 
-  register_one(flat_op_name);
   register_one(canonical_op_name);
 }
 
diff --git a/src/backend/trn/codegen/codegen_trn.cc 
b/src/backend/trn/codegen/codegen_trn.cc
index 9b798c3dc8..b057cb1509 100644
--- a/src/backend/trn/codegen/codegen_trn.cc
+++ b/src/backend/trn/codegen/codegen_trn.cc
@@ -360,22 +360,22 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, 
std::ostream& os) {  // NOL
   auto is_op = [&](const Op& compat, const char* canonical_name) {
     return op->op.same_as(compat) || (op_node != nullptr && op_node->name == 
canonical_name);
   };
-  static const Op& nki_matmul_op = Op::Get("tirx.nki_matmul");
-  static const Op& nki_load_op = Op::Get("tirx.nki_load");
-  static const Op& nki_store_op = Op::Get("tirx.nki_store");
-  static const Op& nki_tensor_copy_op = Op::Get("tirx.nki_tensor_copy");
-  static const Op& nki_activation_op = Op::Get("tirx.nki_activation");
-  static const Op& nki_reciprocal_op = Op::Get("tirx.nki_reciprocal");
-  static const Op& nki_tensortensor_op = Op::Get("tirx.nki_tensortensor");
-  static const Op& nki_tensorscalar_op = Op::Get("tirx.nki_tensorscalar");
-  static const Op& nki_memset_op = Op::Get("tirx.nki_memset");
-  static const Op& nki_tensorreduce_op = Op::Get("tirx.nki_tensorreduce");
-  static const Op& nki_activation_reduce_op = 
Op::Get("tirx.nki_activation_reduce");
-  static const Op& nki_tensorscalar_reduce_op = 
Op::Get("tirx.nki_tensorscalar_reduce");
-  static const Op& nki_identity_op = Op::Get("tirx.nki_identity");
-  static const Op& nki_scalar_tensor_tensor_op = 
Op::Get("tirx.nki_scalar_tensor_tensor");
-  static const Op& nki_scalar_tensor_scalar_op = 
Op::Get("tirx.nki_scalar_tensor_scalar");
-  static const Op& nki_affine_select_op = Op::Get("tirx.nki_affine_select");
+  static const Op& nki_matmul_op = Op::Get("tirx.nki.matmul");
+  static const Op& nki_load_op = Op::Get("tirx.nki.load");
+  static const Op& nki_store_op = Op::Get("tirx.nki.store");
+  static const Op& nki_tensor_copy_op = Op::Get("tirx.nki.tensor_copy");
+  static const Op& nki_activation_op = Op::Get("tirx.nki.activation");
+  static const Op& nki_reciprocal_op = Op::Get("tirx.nki.reciprocal");
+  static const Op& nki_tensortensor_op = Op::Get("tirx.nki.tensortensor");
+  static const Op& nki_tensorscalar_op = Op::Get("tirx.nki.tensorscalar");
+  static const Op& nki_memset_op = Op::Get("tirx.nki.memset");
+  static const Op& nki_tensorreduce_op = Op::Get("tirx.nki.tensorreduce");
+  static const Op& nki_activation_reduce_op = 
Op::Get("tirx.nki.activation_reduce");
+  static const Op& nki_tensorscalar_reduce_op = 
Op::Get("tirx.nki.tensorscalar_reduce");
+  static const Op& nki_identity_op = Op::Get("tirx.nki.identity");
+  static const Op& nki_scalar_tensor_tensor_op = 
Op::Get("tirx.nki.scalar_tensor_tensor");
+  static const Op& nki_scalar_tensor_scalar_op = 
Op::Get("tirx.nki.scalar_tensor_scalar");
+  static const Op& nki_affine_select_op = Op::Get("tirx.nki.affine_select");
 
   if (is_op(nki_matmul_op, "tirx.nki.matmul")) {
     TVM_FFI_ICHECK_EQ(op->args.size(), 4);
diff --git a/src/backend/trn/op/target_builtin.cc 
b/src/backend/trn/op/target_builtin.cc
index c0d915bb2a..a73057e609 100644
--- a/src/backend/trn/op/target_builtin.cc
+++ b/src/backend/trn/op/target_builtin.cc
@@ -35,12 +35,6 @@ namespace tvm {
 namespace tirx {
 namespace builtin {
 
-#define TIRX_DEFINE_BUILTIN_FUNC(OpName)                                       
    \
-  OpRegEntry::RegisterOrGet("tirx." #OpName)                                   
    \
-      .set_name()                                                              
    \
-      .set_attr<TScriptPrinterName>("TScriptPrinterName", 
ffi::String(#OpName), 1) \
-      .set_attr<TIRxOpCategory>("TIRxOpCategory", ffi::String("builtin"), 
/*plevel=*/1)
-
 namespace {
 void RegisterNKIIntrinsicAliases();
 }
@@ -51,69 +45,19 @@ static bool registered = false;
 if (registered) return;
 registered = true;
 
-TIRX_DEFINE_BUILTIN_FUNC(nki_load).set_attr<TCallEffectKind>(
-    "TCallEffectKind", static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_store).set_attr<TCallEffectKind>(
-    "TCallEffectKind", static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_tensor_copy)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_matmul)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_activation)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_reciprocal)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_tensortensor)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_tensorscalar)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_memset)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_tensorreduce)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_activation_reduce)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_tensorscalar_reduce)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_identity)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_scalar_tensor_tensor)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_scalar_tensor_scalar)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
-TIRX_DEFINE_BUILTIN_FUNC(nki_affine_select)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
static_cast<int64_t>(CallEffectKind::kOpaque));
-
 RegisterNKIIntrinsicAliases();
   // clang-format on
 }
 
 namespace {
 
-void RegisterNKIIntrinsic(const char* flat_name) {
-  std::string flat(flat_name);
+void RegisterNKIIntrinsic(const char* name) {
   std::string prefix = "nki_";
-  std::string suffix = flat;
+  std::string suffix(name);
   if (suffix.rfind(prefix, 0) == 0) {
     suffix = suffix.substr(prefix.size());
   }
 
-  std::string flat_op_name = "tirx." + flat;
   std::string canonical_op_name = "tirx.nki." + suffix;
   ffi::String namespace_attr("nki");
   ffi::String printer_name("nki." + suffix);
@@ -130,7 +74,6 @@ void RegisterNKIIntrinsic(const char* flat_name) {
         .set_attr<TScriptPrinterName>("TScriptPrinterName", printer_name, 
/*plevel=*/15);
   };
 
-  register_one(flat_op_name);
   register_one(canonical_op_name);
 }
 
@@ -161,8 +104,6 @@ void RegisterNKIIntrinsicAliases() {
 
 }  // namespace
 
-#undef TIRX_DEFINE_BUILTIN_FUNC
-
 TVM_FFI_STATIC_INIT_BLOCK() { RegisterTRNTargetBuiltins(); }
 
 }  // namespace builtin
diff --git a/src/s_tir/transform/inject_ptx_async_copy.cc 
b/src/s_tir/transform/inject_ptx_async_copy.cc
index 3a0f113499..500c2623be 100644
--- a/src/s_tir/transform/inject_ptx_async_copy.cc
+++ b/src/s_tir/transform/inject_ptx_async_copy.cc
@@ -90,7 +90,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
           if (predicated) {
             args.push_back(predicate_value);
           }
-          static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async");
+          static const Op& ptx_cp_async_op = Op::Get("tirx.ptx.cp_async_raw");
           return Evaluate(Call(store->buffer->dtype, ptx_cp_async_op, args));
         }
 
@@ -119,7 +119,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
             return PrimExpr();
           }();
           if (src_offset.defined() && dst_offset.defined()) {
-            static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async");
+            static const Op& ptx_cp_async_op = 
Op::Get("tirx.ptx.cp_async_raw");
             return Evaluate(Call(store->buffer->dtype, ptx_cp_async_op,
                                  {store->buffer->data, mul(dst_offset, 
PrimExpr(index_factor)),
                                   load->buffer->data, src_offset, 
PrimExpr(bytes)}));
@@ -149,7 +149,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
           }();
 
           if (src_offset.defined() && dst_offset.defined()) {
-            static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async");
+            static const Op& ptx_cp_async_op = 
Op::Get("tirx.ptx.cp_async_raw");
             return Evaluate(
                 Call(store->buffer->dtype, ptx_cp_async_op,
                      {store->buffer->data, mul(dst_offset, 
PrimExpr(index_factor)),
diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc 
b/src/s_tir/transform/merge_shared_memory_allocations.cc
index 4b61c8994c..1626d02e3f 100644
--- a/src/s_tir/transform/merge_shared_memory_allocations.cc
+++ b/src/s_tir/transform/merge_shared_memory_allocations.cc
@@ -487,7 +487,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
-    static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async");
+    static const Op& ptx_cp_async_op = Op::Get("tirx.ptx.cp_async_raw");
     if (op->op.same_as(builtin::tvm_access_ptr())) {
       TVM_FFI_ICHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
diff --git a/src/tirx/analysis/filter_canonical.cc 
b/src/tirx/analysis/filter_canonical.cc
index dfefdd51c0..61af4812f7 100644
--- a/src/tirx/analysis/filter_canonical.cc
+++ b/src/tirx/analysis/filter_canonical.cc
@@ -45,12 +45,8 @@ bool IsBitwiseAndCall(const CallNode* call) {
 }
 
 bool IsPtxElectSyncCall(const CallNode* call) {
-  static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx_elect_sync");
-  if (call->op.same_as(ptx_elect_sync_op)) return true;
-  if (auto op = call->op.as<Op>()) {
-    return op.value()->name == "tirx.ptx.elect_sync";
-  }
-  return false;
+  static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx.elect_sync");
+  return call->op.same_as(ptx_elect_sync_op);
 }
 
 // Strip implicit Cast wrappers from a predicate. Bool-vs-int mixing in the
@@ -194,7 +190,7 @@ bool TryParseCompareAtom(const PrimExpr& expr, const 
ScopeIdPredicate& is_scope_
   return true;
 }
 
-// Try to read `expr` as a direct `Call("tirx.ptx_elect_sync")` atom.
+// Try to read `expr` as a direct `Call("tirx.ptx.elect_sync")` atom.
 // Composed forms like `elect_sync() != 0` or `not elect_sync()` are NOT
 // accepted -- the canonical grammar requires a bare elect_sync call.
 bool TryParseElectSyncAtom(const PrimExpr& expr, FilterAtom* out) {
diff --git a/src/tirx/analysis/filter_canonical.h 
b/src/tirx/analysis/filter_canonical.h
index f3eb579214..6dfb6bff72 100644
--- a/src/tirx/analysis/filter_canonical.h
+++ b/src/tirx/analysis/filter_canonical.h
@@ -27,7 +27,7 @@
  *
  *   pred := atom (AND atom)*        // pure n-ary conjunction (no OR/NOT)
  *   atom := scopeid_var <op> const  // op in {==, <, <=, >, >=}
- *         | Call("tirx.ptx_elect_sync")
+ *         | Call("tirx.ptx.elect_sync")
  *
  * Consumers:
  *   1. tile_primitive_dispatch routes a bare `if cond:` to atom-based
@@ -62,7 +62,7 @@ namespace tirx {
  */
 enum class FilterAtomKind {
   kRange,      // scopeid_var in [lo, hi); covers ==, <, <=, >, >=
-  kElectSync,  // Call("tirx.ptx_elect_sync")
+  kElectSync,  // Call("tirx.ptx.elect_sync")
 };
 
 /*!
@@ -77,7 +77,7 @@ enum class FilterAtomKind {
  *   - `elect_sync_call` is unset.
  *
  * For `kElectSync`:
- *   - `elect_sync_call`: the original `Call("tirx.ptx_elect_sync")` PrimExpr,
+ *   - `elect_sync_call`: the original `Call("tirx.ptx.elect_sync")` PrimExpr,
  *     preserved verbatim so downstream consumers (e.g. selector construction
  *     in tile_primitive_dispatch) can reuse it without re-synthesizing.
  *   - `scopeid_var`, `lo`, `hi` are unset.
@@ -123,7 +123,7 @@ using ScopeIdPredicate = std::function<bool(const Var&)>;
  * Grammar (see file header):
  *   pred := atom (AND atom)*
  *   atom := scopeid_var <op> const  (op in {==, <, <=, >, >=})
- *         | Call("tirx.ptx_elect_sync")
+ *         | Call("tirx.ptx.elect_sync")
  *
  * Returns:
  *   - `std::nullopt` if `cond` does not match the grammar. The caller should
diff --git a/src/tirx/script/printer/expr.cc b/src/tirx/script/printer/expr.cc
index f732aa619b..d4212bbb10 100644
--- a/src/tirx/script/printer/expr.cc
+++ b/src/tirx/script/printer/expr.cc
@@ -315,7 +315,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         }
         // cuda_func_call: last arg is source_code (keyword-only in the Python 
API).
         // Print it as source_code=... to enable TVMScript round-trip.
-        if (op->name == "tirx.cuda_func_call" || op->name == 
"tirx.cuda.func_call") {
+        if (op->name == "tirx.cuda.func_call") {
           int n_args = call->args.size();
           ffi::Array<ExprDoc> args;
           // All args except the last (source_code) are positional.
diff --git a/src/tirx/transform/tile_primitive_dispatch.cc 
b/src/tirx/transform/tile_primitive_dispatch.cc
index 80cde75e7e..eceb897893 100644
--- a/src/tirx/transform/tile_primitive_dispatch.cc
+++ b/src/tirx/transform/tile_primitive_dispatch.cc
@@ -125,12 +125,8 @@ class ElectSyncFinder : public StmtExprVisitor {
 
   void VisitExpr_(const CallNode* op) final {
     auto is_canonical_elect_sync = [&]() {
-      static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx_elect_sync");
-      if (op->op.same_as(ptx_elect_sync_op)) return true;
-      if (auto call_op = op->op.as<Op>()) {
-        return call_op.value()->name == "tirx.ptx.elect_sync";
-      }
-      return false;
+      static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx.elect_sync");
+      return op->op.same_as(ptx_elect_sync_op);
     };
     if (is_canonical_elect_sync()) {
       found_ = true;
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
index 1ed69262a4..e71df73e71 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py
@@ -28,11 +28,26 @@ from tvm.script import tirx as T
 from tvm.testing import env
 
 
+def test_cp_async_raw_dtype_round_trips():
+    # The raw cp.async form emitted by InjectPTXAsyncCopy carries the element
+    # dtype in Call.dtype and must survive a TVMScript print -> parse round-trip
+    # (it prints dtype-first via tirx.ptx.cp_async_raw). Guards the regression
+    # where the element dtype was dropped after the flat op was phased out.
+    @T.prim_func
+    def f(A: T.Buffer((128,), "float16"), B: T.Buffer((128,), "float16")):
+        T.func_attr({"global_symbol": "f"})
+        for i in T.serial(8):
+            T.ptx.cp_async("float16", B.data, i * 16, A.data, i * 16, 16)
+
+    reparsed = tvm.script.from_source(f.script())
+    tvm.ir.assert_structural_equal(f, reparsed)
+
+
 def count_cp_async(stmt):
     num_alloc = [0]
 
     def verify(n):
-        if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx_cp_async":
+        if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx.cp_async_raw":
             num_alloc[0] += 1
 
     tvm.tirx.stmt_functor.post_order_visit(stmt, verify)


Reply via email to