This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab25b49225 [TIR] Fix InjectPTXLDG32 segfaults and skip non-CUDA 
targets (#18671)
ab25b49225 is described below

commit ab25b49225cfc4b91171f111578bfb5906cabae1
Author: YinHanke <[email protected]>
AuthorDate: Wed Jan 28 00:51:17 2026 +0800

    [TIR] Fix InjectPTXLDG32 segfaults and skip non-CUDA targets (#18671)
    
    ### Motivation
    InjectPTXLDG32 rewrites BufferStore when it encounters if_then_else,
    but it only initializes its temporary buffers when an Allocate node
    exists. For functions without an Allocate, this leads to
    uninitialized buffers and a hard segfault during compilation.
    In addition, the PTX-only pass can run on CPU/LLVM targets when
    tir.ptx_ldg32=1,
    injecting PTX intrinsics that are invalid for non-CUDA codegen.
    
    This PR ensures temporary buffers are created even when no Allocate
    exists, and
    skips InjectPTXLDG32 on non-CUDA targets, preventing segfaults and
    invalid PTX
    intrinsics on CPU.
    
    ### Changes
    - Ensure temp buffers are created when the rewrite path is taken without
    Allocate
    - Insert allocations at the function level when needed
    - Guard InjectPTXLDG32 so it only runs on CUDA targets
    - Add tests for CUDA (insertion) and CPU (skip) behavior
    
    ### Testing
    Ran tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
    
    ### Fixes
    - [#18612](https://github.com/apache/tvm/issues/18612)
    - [#18617](https://github.com/apache/tvm/issues/18617)
    - [#18599](https://github.com/apache/tvm/issues/18599)
---
 src/tir/transforms/inject_ptx_ldg32.cc             | 44 +++++++++---
 .../test_tir_transform_inject_ptx_ldg32.py         | 80 ++++++++++++++++++++++
 2 files changed, 115 insertions(+), 9 deletions(-)

diff --git a/src/tir/transforms/inject_ptx_ldg32.cc 
b/src/tir/transforms/inject_ptx_ldg32.cc
index 8cdef1be44..f52539fa77 100644
--- a/src/tir/transforms/inject_ptx_ldg32.cc
+++ b/src/tir/transforms/inject_ptx_ldg32.cc
@@ -35,16 +35,22 @@ namespace tir {
 
 class PTXRewriter : public StmtMutator {
  public:
-  Stmt VisitStmt_(const AllocateNode* allocate) final {
-    if (!has_buffer_1) {
-      has_buffer_1 = true;
-      // addr[0] -> global_addr /  addr[1] -> local_addr
-      addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, 
DataType::Int(32), "addr", "local");
-      predicate_buffer =
-          decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), 
"predicate", "local");
+  Stmt AddAllocationsIfNeeded(Stmt body) {
+    if (!needs_buffer || has_buffer_2) {
+      return body;
     }
+    EnsureBuffers();
+    body = Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, 
Bool(true), body);
+    body = Allocate(predicate_buffer->data, predicate_buffer->dtype, 
predicate_buffer->shape,
+                    Bool(true), body);
+    has_buffer_2 = true;
+    return body;
+  }
+
+  Stmt VisitStmt_(const AllocateNode* allocate) final {
     Stmt result = StmtMutator::VisitStmt_(allocate);
-    if (!has_buffer_2) {
+    if (needs_buffer && !has_buffer_2) {
+      EnsureBuffers();
       has_buffer_2 = true;
       result =
           Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, 
Bool(true), result);
@@ -82,6 +88,8 @@ class PTXRewriter : public StmtMutator {
         if (ramp != nullptr) {
           return result;
         }
+        EnsureBuffers();
+        needs_buffer = true;
         local_addr = store->indices[0];
         BufferStore addr_store(addr_buffer, global_addr, 
{IntImm(DataType::Int(32), 0)});
         BufferStore local_addr_store(addr_buffer, local_addr, 
{IntImm(DataType::Int(32), 1)});
@@ -104,7 +112,19 @@ class PTXRewriter : public StmtMutator {
     return result;
   }
 
+  void EnsureBuffers() {
+    if (has_buffer_1) {
+      return;
+    }
+    has_buffer_1 = true;
+    // addr[0] -> global_addr /  addr[1] -> local_addr
+    addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, 
DataType::Int(32), "addr", "local");
+    predicate_buffer =
+        decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), 
"predicate", "local");
+  }
+
   bool has_buffer_1 = false, has_buffer_2 = false;
+  bool needs_buffer = false;
   Buffer addr_buffer, predicate_buffer;
 };
 
@@ -113,8 +133,14 @@ namespace transform {
 Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) {
   auto pass_func = [enable_inject_ptx_intrin](PrimFunc f, IRModule m, 
PassContext ctx) {
     if (enable_inject_ptx_intrin) {
+      auto target = f->GetAttr<Target>("target");
+      if (!target.defined() || target.value()->kind->name != "cuda") {
+        return f;
+      }
       auto* n = f.CopyOnWrite();
-      n->body = PTXRewriter()(n->body);
+      PTXRewriter rewriter;
+      Stmt body = rewriter(n->body);
+      n->body = rewriter.AddAllocationsIfNeeded(body);
       // inject ptx
     }
     return f;
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py 
b/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
new file mode 100644
index 0000000000..55099f252c
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+def _count_alloc(stmt):
+    num_alloc = [0]
+
+    def visit(n):
+        if isinstance(n, tvm.tir.Allocate):
+            num_alloc[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, visit)
+    return num_alloc[0]
+
+
+def _count_ptx_ldg32(stmt):
+    num_call = [0]
+
+    def visit(n):
+        if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_ldg32":
+            num_call[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, visit)
+    return num_call[0]
+
+
[email protected]_func
+def where_no_alloc(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) 
-> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": 
T.target("cuda")})
+    for i in range(4):
+        C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))
+
+
[email protected]_func
+def where_no_alloc_cpu(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), 
"float32")) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": 
T.target("llvm")})
+    for i in range(4):
+        C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))
+
+
+def test_inject_ptx_ldg32_inserts_alloc_for_no_alloc_func():
+    mod = tvm.IRModule.from_expr(where_no_alloc)
+    assert _count_alloc(mod["main"].body) == 0
+
+    mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+    assert _count_alloc(mod["main"].body) > 0
+    assert _count_ptx_ldg32(mod["main"].body) == 1
+
+
+def test_inject_ptx_ldg32_skip_non_cuda_target():
+    mod = tvm.IRModule.from_expr(where_no_alloc_cpu)
+    cpu_target = tvm.target.Target("llvm")
+    mod = tvm.IRModule({"main": mod["main"].with_attr("target", cpu_target)})
+    assert _count_alloc(mod["main"].body) == 0
+
+    mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+    assert _count_alloc(mod["main"].body) == 0
+    assert _count_ptx_ldg32(mod["main"].body) == 0
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to