This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9c8e5a6037 [TIR] Handle Bind in LowerDeviceKernelLaunch (#18912)
9c8e5a6037 is described below

commit 9c8e5a60376ff16a7a88dc841271befd9f32bf96
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Mar 27 09:30:37 2026 -0400

    [TIR] Handle Bind in LowerDeviceKernelLaunch (#18912)
    
    DeviceInfoCollector did not track Bind statements, so when CSE (or
    any other pass) inserted a Bind before a thread_extent AttrStmt, the
    collected extent referenced a locally-bound variable instead of
    function parameters.  LowerDeviceKernelLaunch then produced dangling
    references in the host function.
    
    Fix: record Bind definitions in DeviceInfoCollector and inline them
    when extracting thread_extent values and dynamic shared memory sizes.
---
 src/tirx/transform/lower_device_kernel_launch.cc   | 22 ++++++++-
 .../test_tir_transform_device_kernel_launch.py     | 52 ++++++++++++++++++++++
 2 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc
index 3ff4cf17c5..fea8d458b9 100644
--- a/src/tirx/transform/lower_device_kernel_launch.cc
+++ b/src/tirx/transform/lower_device_kernel_launch.cc
@@ -104,6 +104,17 @@ class DeviceInfoCollector : public StmtVisitor {
     return extent.value();
   }
 
+  void VisitStmt_(const BindNode* op) final {
+    // Track Bind definitions so that thread_extent values and
+    // dyn_shmem_size expressions that reference locally-bound
+    // variables (e.g. CSE variables) can be inlined back to
+    // expressions over function parameters.  Substitute earlier
+    // bindings into the value to handle chains (cse_v2 = f(cse_v1)).
+    PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
+    bind_map_.Set(op->var, value);
+    StmtVisitor::VisitStmt_(op);
+  }
+
   void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
@@ -113,7 +124,10 @@ class DeviceInfoCollector : public StmtVisitor {
       if (!defined_thread.count(iv.get())) {
         defined_thread.insert(iv.get());
         info_.launch_params.push_back(iv->thread_tag);
-        thread_extent.Set(iv->thread_tag, op->value);
+        // Inline any locally-bound variables (e.g. from CSE) so
+        // that the extent is expressible in terms of function params.
+        PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : 
op->value;
+        thread_extent.Set(iv->thread_tag, value);
       }
     }
 
@@ -133,6 +147,10 @@ class DeviceInfoCollector : public StmtVisitor {
       }
       dyn_size *= op->buffer->dtype.bytes();
 
+      // Inline any locally-bound variables (e.g. from CSE).
+      if (bind_map_.size()) {
+        dyn_size = Substitute(dyn_size, bind_map_);
+      }
       dyn_shmem_size = dyn_size;
     }
     StmtVisitor::VisitStmt_(op);
@@ -146,6 +164,8 @@ class DeviceInfoCollector : public StmtVisitor {
   ffi::Map<ffi::String, PrimExpr> thread_extent;
   // The amount of dynamic shared memory used
   ffi::Optional<PrimExpr> dyn_shmem_size{std::nullopt};
+  // Accumulated Bind definitions for inlining into extent/size expressions.
+  ffi::Map<Var, PrimExpr> bind_map_;
 };
 
 class ReturnRemover : public StmtExprMutator {
diff --git a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
index 6d77d7e871..3dab487ab5 100644
--- a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
+++ b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
@@ -223,5 +223,57 @@ def test_same_device_different_target():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+def test_bind_before_thread_extent():
+    """DeviceInfoCollector inlines Bind-defined variables in thread extents.
+
+    When CSE (or another pass) inserts Bind statements before
+    thread_extent AttrStmts, the extent value may reference a
+    locally-bound variable instead of function parameters.
+    LowerDeviceKernelLaunch must inline these bindings so that the
+    launch argument is expressible in terms of the caller's arguments.
+    """
+
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32"), n: T.int32):
+            T.func_attr({"target": T.target("llvm")})
+            Before.kernel(A.data, n)
+
+        @T.prim_func
+        def kernel(A_data: T.handle("float32"), n: T.int32):
+            T.func_attr({"target": T.target("cuda"), "global_symbol": 
"kernel"})
+            A = T.decl_buffer(16, dtype="float32", data=A_data)
+            v: T.int32 = n + 1
+            i = T.launch_thread("threadIdx.x", v)
+            A[i] = 0.0
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32"), n: T.int32):
+            T.func_attr({"target": T.target("llvm")})
+            T.call_packed("kernel", A.data, n, n + 1)
+
+        @T.prim_func
+        def kernel(A_data: T.handle("float32"), n: T.int32):
+            T.func_attr(
+                {
+                    "target": T.target("cuda"),
+                    "calling_conv": 2,
+                    "tirx.kernel_launch_params": ["threadIdx.x"],
+                    "global_symbol": "kernel",
+                    "tirx.is_global_func": True,
+                }
+            )
+            A = T.decl_buffer(16, dtype="float32", data=A_data)
+            v: T.int32 = n + 1
+            i = T.launch_thread("threadIdx.x", v)
+            A[i] = 0.0
+
+    After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to