This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch investigate-s-tir-nested-pipeline-mma
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f180256f0a51be12924eea6e0c6678fb4038ed38
Author: tlopex <[email protected]>
AuthorDate: Thu Jun 11 22:20:35 2026 -0400

    Fix software pipeline legacy MMA offsets
---
 src/s_tir/transform/inject_software_pipeline.cc | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc
index d9da151f39..ec72da8574 100644
--- a/src/s_tir/transform/inject_software_pipeline.cc
+++ b/src/s_tir/transform/inject_software_pipeline.cc
@@ -28,6 +28,7 @@
 #include <tvm/s_tir/transform.h>
 #include <tvm/target/target.h>
 #include <tvm/tirx/builtin.h>
+#include <tvm/tirx/op.h>
 
 #include <map>
 #include <unordered_set>
@@ -42,6 +43,14 @@ using namespace tvm::tirx;
 
 namespace software_pipeline {
 
+/*! \brief Whether call targets compat_op itself or an Op named canonical_name. */
+bool IsOp(const Call& call, const Op& compat_op, const char* canonical_name) {
+  const bool is_compat_op = call->op.same_as(compat_op);
+  const auto* named_op = call->op.as<OpNode>();
+  const bool has_canonical_name = named_op != nullptr && named_op->name == canonical_name;
+  return is_compat_op || has_canonical_name;
+}
+
 /*!
  * \brief Create a block and infer the access region with the given body.
  *
@@ -112,6 +121,10 @@ class PipelineOpaqueAccessRewriter {
     static const auto& access_ptr = builtin::tvm_access_ptr();
     static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix();
     static const auto& ptx_mma = builtin::ptx_mma();
+    static const auto& ptx_ldmatrix_legacy = builtin::ptx_ldmatrix_legacy();
+    static const auto& ptx_mma_legacy = builtin::ptx_mma_legacy();
+    static const auto& mma_store_legacy = builtin::mma_store_legacy();
+    static const auto& mma_fill_legacy = builtin::mma_fill_legacy();
     if (call->op.same_as(load_matrix_sync) || 
call->op.same_as(store_matrix_sync)) {
       const Buffer& buffer = 
buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
       auto it = buffer_remap_.find(buffer);
@@ -140,6 +153,14 @@ class PipelineOpaqueAccessRewriter {
       return RewriteBufferAccess(call, {6, 8, 10});
     } else if (call->op.same_as(ptx_ldmatrix)) {
       return RewriteBufferAccess(call, {3});
+    } else if (IsOp(call, ptx_mma_legacy, "tirx.ptx.mma_legacy")) {
+      return RewriteBufferAccess(call, {6, 8, 10});
+    } else if (IsOp(call, ptx_ldmatrix_legacy, "tirx.ptx.ldmatrix_legacy")) {
+      return RewriteBufferAccess(call, {3});
+    } else if (call->op.same_as(mma_store_legacy)) {
+      return RewriteBufferAccess(call, {3});
+    } else if (call->op.same_as(mma_fill_legacy)) {
+      return RewriteBufferAccess(call, {1});
     }
     return call;
   }

Reply via email to