This is an automated email from the ASF dual-hosted git repository. tlopex pushed a commit to branch investigate-s-tir-nested-pipeline-mma in repository https://gitbox.apache.org/repos/asf/tvm.git
commit f180256f0a51be12924eea6e0c6678fb4038ed38 Author: tlopex <[email protected]> AuthorDate: Thu Jun 11 22:20:35 2026 -0400 Fix software pipeline legacy MMA offsets --- src/s_tir/transform/inject_software_pipeline.cc | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index d9da151f39..ec72da8574 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -28,6 +28,7 @@ #include <tvm/s_tir/transform.h> #include <tvm/target/target.h> #include <tvm/tirx/builtin.h> +#include <tvm/tirx/op.h> #include <map> #include <unordered_set> @@ -42,6 +43,14 @@ using namespace tvm::tirx; namespace software_pipeline { +bool IsOp(const Call& call, const Op& compat_op, const char* canonical_name) { + if (call->op.same_as(compat_op)) { + return true; + } + const auto* op_node = call->op.as<OpNode>(); + return op_node != nullptr && op_node->name == canonical_name; +} + /*! * \brief Create a block and infer the access region with the given body. * @@ -112,6 +121,10 @@ class PipelineOpaqueAccessRewriter { static const auto& access_ptr = builtin::tvm_access_ptr(); static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix(); static const auto& ptx_mma = builtin::ptx_mma(); + static const auto& ptx_ldmatrix_legacy = builtin::ptx_ldmatrix_legacy(); + static const auto& ptx_mma_legacy = builtin::ptx_mma_legacy(); + static const auto& mma_store_legacy = builtin::mma_store_legacy(); + static const auto& mma_fill_legacy = builtin::mma_fill_legacy(); if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0])); auto it = buffer_remap_.find(buffer); @@ -140,6 +153,14 @@ class PipelineOpaqueAccessRewriter { return RewriteBufferAccess(call, {6, 8, 10}); } else if (call->op.same_as(ptx_ldmatrix)) { return RewriteBufferAccess(call, {3}); + } else if (IsOp(call, ptx_mma_legacy, "tirx.ptx.mma_legacy")) { + return RewriteBufferAccess(call, {6, 8, 10}); + } else if (IsOp(call, ptx_ldmatrix_legacy, "tirx.ptx.ldmatrix_legacy")) { + return RewriteBufferAccess(call, {3}); + } else if (call->op.same_as(mma_store_legacy)) { + return RewriteBufferAccess(call, {3}); + } else if (call->op.same_as(mma_fill_legacy)) { + return RewriteBufferAccess(call, {1}); } return call; }
