This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c9a77d6712 [S-TIR][Tests] Fix transform test failures after TIRx bringup (#19735)
c9a77d6712 is described below

commit c9a77d671232a192df5241a8ffd5f15be8274224
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 11 17:34:05 2026 -0400

    [S-TIR][Tests] Fix transform test failures after TIRx bringup (#19735)
    
    This PR fixes 11 test failures in `tests/python/s_tir/transform/`
    introduced as side effects of the TIRx bringup (#19581 / 859498dc01).
    The fixes are split into three independent commits.
    
    ### 1. LowerOpaqueBlock: update expected IR for buffer metadata annotations
    
    `LowerOpaqueBlock` now emits `buffer_allocated_addr` and
    `buffer_data_alignment` annotations on lowered allocations (intentional
    in #19581: the annotations are consumed downstream by `codegen_cuda.cc`
    / `codegen_trn.cc`; the alignment value 64 comes from
    `kAllocAlignment`). The tests' expected IR predates this, so
    `assert_structural_equal` failed on the missing annotations.
    
    Fix: update the expected IR in
    `test_s_tir_transform_lower_opaque_block.py` to carry the annotations
    (`T.decl_buffer(...)` → `T.alloc_buffer(..., annotations={...})`). Fixes
    6 tests.
    
    ### 2. DefaultGPUSchedule: parse scalar-block test in s_tir mode
    
    #19581 added a well-formedness rule rejecting `SBlockRealize` in
    `tirx=True` mode, which is correct — sblocks are s_tir-mode constructs.
    The hand-written `Before`/`Expected` modules in
    `test_scalar_block_no_loops` were the only ones in the file still using
    plain `T.prim_func`, which parses in `tirx=True` mode. Their `T.sblock`
    bodies were therefore rejected at parse time, before the pass under
    test even ran.
    
    Fix: parse both modules with `T.prim_func(s_tir=True)`, consistent with
    every other test in the file. Fixes 1 test.
    
    ### 3. InjectPermutedLayout: match legacy PTX intrinsics by canonical name
    
    #19581 registers device intrinsics under two Op identities: a flat
    builtin name (returned by `builtin::xxx()` in C++) and a canonical
    dotted name (e.g. `tirx.ptx.ldmatrix_legacy`, produced when TVMScript /
    tensor intrinsics are parsed). `InjectPermutedLayout` only matched calls
    via `same_as(builtin::ptx_ldmatrix())` / `same_as(builtin::mma_store())`.
    It therefore recognized neither the `_legacy` builtins nor their
    canonical-name aliases, and silently skipped rewriting the swizzled
    shared-memory offsets of parsed legacy-form calls. As a result, the
    expected swizzle index expressions did not match.
    
    Fix: match `ptx_ldmatrix_legacy` / `mma_store_legacy` by both the
    builtin Op and the canonical name via an `IsOp` helper, following the
    existing pattern in `lower_warp_memory.cc` and `codegen_cuda.cc`. Only
    the legacy intrinsic forms fold shared-memory access into
    `tvm_access_ptr` + offset. Non-legacy forms address shared memory
    through `BufferLoad` and are already handled by the BufferLoad visitor.
    Any other call is therefore returned unchanged: the early-return guard
    and the unreachable `InternalError` branch are replaced by a single
    trailing pass-through. (`mma_store_legacy` has no `ptx.`-namespaced
    alias — its canonical name is `tirx.mma_store_legacy` — hence the
    asymmetric name strings.) Fixes 4 tests.
---
 src/s_tir/transform/inject_permuted_layout.cc      | 28 ++++++++++-----
 .../test_s_tir_transform_default_gpu_schedule.py   |  4 +--
 .../test_s_tir_transform_lower_opaque_block.py     | 41 ++++++++++++++++++----
 3 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/src/s_tir/transform/inject_permuted_layout.cc 
b/src/s_tir/transform/inject_permuted_layout.cc
index fe90f38cec..74e843a6e5 100644
--- a/src/s_tir/transform/inject_permuted_layout.cc
+++ b/src/s_tir/transform/inject_permuted_layout.cc
@@ -246,6 +246,17 @@ class PermutedLayoutInjector : private 
IRMutatorWithAnalyzer {
     return access_ptr_call;
   }
 
+  // Device intrinsics are registered under both a flat name (the builtin Op)
+  // and a canonical dotted name (emitted by TVMScript and the tensor
+  // intrinsics), so compare against both.
+  static bool IsOp(const Call& call, const Op& compat_op, const char* 
canonical_name) {
+    if (call->op.same_as(compat_op)) {
+      return true;
+    }
+    const auto* op_node = call->op.as<OpNode>();
+    return op_node != nullptr && op_node->name == canonical_name;
+  }
+
   PrimExpr VisitExpr_(const CallNode* op) final {
     // Rewrite from/to shared or shared.dyn to/from local
     auto call = Downcast<Call>(IRMutatorWithAnalyzer::VisitExpr_(op));
@@ -254,12 +265,12 @@ class PermutedLayoutInjector : private 
IRMutatorWithAnalyzer {
       return call;
     }
 
-    if (!call->op.same_as(builtin::ptx_ldmatrix()) && 
!call->op.same_as(builtin::mma_store())) {
-      return call;
-    }
-
-    if (call->op.same_as(builtin::ptx_ldmatrix())) {
-      // form: T.ptx_ldmatrix(..., smem_ptr, smem_offset)
+    // Only the legacy intrinsic forms fold the shared memory access into a
+    // tvm_access_ptr + offset, which must be rewritten here. The non-legacy
+    // forms address shared memory through BufferLoad (e.g. via address_of),
+    // which is already handled by the BufferLoad visitor above.
+    if (IsOp(call, builtin::ptx_ldmatrix_legacy(), 
"tirx.ptx.ldmatrix_legacy")) {
+      // form: T.ptx.ldmatrix_legacy(..., smem_ptr, smem_offset)
       // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask)
       auto access_ptr = call->args[5];
       PrimExpr smem_offset = call->args[6];
@@ -268,7 +279,7 @@ class PermutedLayoutInjector : private 
IRMutatorWithAnalyzer {
       new_call->args.Set(5, new_access_ptr);
       new_call->args.Set(6, IntImm(smem_offset->dtype, 0));
       return call;
-    } else if (call->op.same_as(builtin::mma_store())) {
+    } else if (IsOp(call, builtin::mma_store_legacy(), 
"tirx.mma_store_legacy")) {
       // TODO(yixin): mma_store is not fully tested yet
       // because we will directly store result to Buffer instead of calling 
mma_store now
       auto access_ptr = call->args[2];
@@ -276,9 +287,8 @@ class PermutedLayoutInjector : private 
IRMutatorWithAnalyzer {
       auto new_call = call.CopyOnWrite();
       new_call->args.Set(2, new_access_ptr);
       return call;
-    } else {
-      TVM_FFI_THROW(InternalError) << "Invalid call node: " << call;
     }
+    return call;
   }
 
   static constexpr size_t VECTORIZE_FACTOR = 8;
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py 
b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index 891ba3f208..875fe18182 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -575,14 +575,14 @@ def test_scalar_block_no_loops():
     # fmt: off
     @tvm.script.ir_module
     class Before:
-        @T.prim_func
+        @T.prim_func(s_tir=True)
         def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), 
c: T.Buffer((), "float32")):
             with T.sblock("scalar_add"):
                 c[()] = a[()] + b[()]
 
     @tvm.script.ir_module
     class Expected:
-        @T.prim_func
+        @T.prim_func(s_tir=True)
         def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), 
c: T.Buffer((), "float32")):
             T.func_attr({"tirx.is_scheduled": True})
             # with T.sblock("root"):
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py 
b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
index 62ad915a57..441074128e 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py
@@ -56,7 +56,11 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) 
-> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
     for i in T.serial(0, 16):
-        B_new = T.decl_buffer(shape=[1, 16], dtype="float32")
+        B_new = T.alloc_buffer(
+            [1, 16],
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 
64},
+        )
         for j in T.serial(0, 16):
             B_new[0, j] = A[i, j] + 1.0
         for j in T.serial(0, 16):
@@ -98,7 +102,12 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None:
     T.launch_thread(i0, 4)
     T.launch_thread(i1, 2)
     T.launch_thread(i2, 2)
-    B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local")
+    B = T.alloc_buffer(
+        [1, 16],
+        "float32",
+        scope="local",
+        annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 64},
+    )
     for j in range(0, 16):
         B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
     for j in range(0, 16):
@@ -133,7 +142,11 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: 
T.int32, m: T.int32)
     C = T.match_buffer(c, (n, m), "float32")
 
     for i in range(0, n):
-        B = T.decl_buffer(shape=[m], dtype="float32")
+        B = T.alloc_buffer(
+            [m],
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 
64},
+        )
         for j in range(0, m):
             B[j] = A[i, j] + 1.0
         for j in range(0, m):
@@ -206,8 +219,16 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) 
-> None:
     D = T.match_buffer(d, (32), "float32")
 
     for i in range(0, 32):
-        B = T.decl_buffer(shape=(32,), dtype="float32")
-        C = T.decl_buffer(shape=(32,), dtype="float32")
+        B = T.alloc_buffer(
+            (32,),
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 
64},
+        )
+        C = T.alloc_buffer(
+            (32,),
+            "float32",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 
64},
+        )
         B[i] = A[i] + 1.0
         C[i] = A[i] + B[i]
         D[i] = C[i] * 2.0
@@ -242,7 +263,12 @@ def transformed_strided_buffer_func(
 ) -> None:
     # body
     for i0 in T.serial(4):
-        B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1])
+        B = T.alloc_buffer(
+            [4, 16],
+            "float32",
+            strides=[17, 1],
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 
64},
+        )
         for i1, j in T.grid(4, 16):
             B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1)
         for i1, j in T.grid(4, 16):
@@ -275,10 +301,11 @@ def transformed_symbolic_strided_buffer_func(a: T.handle):
     n = T.int32()
     A = T.match_buffer(a, (1, n, 10240))
     for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
-        A_pad_shared_dyn = T.decl_buffer(
+        A_pad_shared_dyn = T.alloc_buffer(
             (1, T.min((n + 63) // 64 * 64, 96), 64),
             strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
             scope="shared.dyn",
+            annotations={"buffer_allocated_addr": [], "buffer_data_alignment": 
64},
         )
         for ax0, ax1 in T.grid(96, 64):
             if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64:

Reply via email to