This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c309e4ea5f [TIR] Add cooperative_tensor builtins and 
metal.cooperative_tensor storage scope (#19423)
c309e4ea5f is described below

commit c309e4ea5f6f5d7bae6ab9753995269d56548263
Author: Yichen Yan <[email protected]>
AuthorDate: Mon May 11 21:16:29 2026 +0800

    [TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage 
scope (#19423)
    
    part of https://github.com/tile-ai/tilelang/pull/1869
    
    ## Summary
    Add TIR builtins and storage scope for Metal cooperative_tensor
    operations (MetalPerformancePrimitives / Metal 4).
    
    ## Motivation
    Apple Metal 4 introduces MetalPerformancePrimitives (MPP) with
    `matmul2d` using `cooperative_tensor` operands. On M5, this routes to
    NAX tensor cores; on M1-M4, it falls back to simdgroup matrix
    instructions. These TIR primitives enable backend codegen to emit MPP
    calls.
    
    ## Changes
    
    ### New TIR builtins
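    - `cooperative_tensor_fill(d, index, value, rows, cols)`
    - `cooperative_tensor_load(d, index, ptr, stride, rows, cols,
    transpose_matrix, mma_M, mma_N, mma_K, operand_role)`
    - `cooperative_tensor_store(d, index, ptr, stride, rows, cols,
    transpose_matrix, mma_M, mma_N, mma_K, operand_role)`
    - `cooperative_tensor_multiply_accumulate(d, index_d, a, index_a, b,
    index_b, c, index_c, M, N, K, transpose_a, transpose_b)`
    
    For load/store, `operand_role` selects which matmul2d operand the tensor
    represents: 0 = left (A), 1 = right (B), 2 = destination (C).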
    
    ### New storage scope
    - `metal.cooperative_tensor` (`StorageRank::kMetalCooperativeTensor`)
    
    ### Files changed
    - `include/tvm/tirx/builtin.h` — Op declarations
    - `src/tirx/op/builtin.cc` — Op registrations
    - `python/tvm/tirx/op.py` — Python wrappers
    - `python/tvm/tirx/script/builder/ir.py` — TVMScript builder exports
    - `src/runtime/thread_storage_scope.h` — `StorageRank` enum value plus
    scope-string parsing and printing
    
    These builtins mirror the existing `simdgroup_*` builtins for the older
    Metal simdgroup matrix API, extended with M/N/K dimension parameters for
    the matmul2d descriptor.
---
 include/tvm/tirx/builtin.h           |  45 +++++++++++++++
 python/tvm/tirx/op.py                | 104 +++++++++++++++++++++++++++++++++++
 python/tvm/tirx/script/builder/ir.py |   8 +++
 src/runtime/thread_storage_scope.h   |   7 +++
 src/tirx/op/builtin.cc               |  12 ++++
 5 files changed, 176 insertions(+)

diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h
index 3339b1aa49..2e69ce80d2 100644
--- a/include/tvm/tirx/builtin.h
+++ b/include/tvm/tirx/builtin.h
@@ -782,6 +782,51 @@ TVM_DLL const Op& simdgroup_store();
   */
  TVM_DLL const Op& simdgroup_multiply_accumulate();
 
+// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)
+
+/*!
+ * \brief Fill a cooperative_tensor with a given value.
+ *
+ * void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
+ *                              int rows, int cols);
+ */
+TVM_DLL const Op& cooperative_tensor_fill();
+
+/*!
+ * \brief Load data from device or threadgroup memory into a cooperative_tensor.
+ *
+ * void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
+ *                              PrimExpr stride, int rows, int cols,
+ *                              bool transpose_matrix,
+ *                              int mma_M, int mma_N, int mma_K,
+ *                              int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_load();
+
+/*!
+ * \brief Store data from a cooperative_tensor to device or threadgroup memory.
+ *
+ * void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
+ *                               PrimExpr stride, int rows, int cols,
+ *                               bool transpose_matrix,
+ *                               int mma_M, int mma_N, int mma_K,
+ *                               int operand_role);
+ * operand_role: 0=left(A), 1=right(B), 2=destination(C)
+ */
+TVM_DLL const Op& cooperative_tensor_store();
+
+/*!
+ * \brief Multiply and accumulate two matrices using cooperative_tensor
+ *        (MetalPerformancePrimitives matmul2d).
+ *
+ * void cooperative_tensor_multiply_accumulate(
+ *     Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *     Var b, PrimExpr index_b, Var c, PrimExpr index_c,
+ *     int M, int N, int K, bool transpose_a, bool transpose_b);
+ */
+TVM_DLL const Op& cooperative_tensor_multiply_accumulate();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py
index 566f5d905b..2cdc6f0b36 100644
--- a/python/tvm/tirx/op.py
+++ b/python/tvm/tirx/op.py
@@ -1793,6 +1793,110 @@ def simdgroup_multiply_accumulate(
     )
 
 
+def cooperative_tensor_fill(
+    d: Var,
+    index: PrimExpr,
+    value: PrimExpr,
+    rows: int,
+    cols: int,
+):
+    return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols)
+
+
+def cooperative_tensor_load(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_load",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_store(
+    d: Var,
+    index: PrimExpr,
+    ptr: PrimExpr,
+    stride: PrimExpr,
+    rows: int,
+    cols: int,
+    transpose_matrix: bool = False,
+    mma_M: int = 0,
+    mma_N: int = 0,
+    mma_K: int = 0,
+    operand_role: int = 0,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_store",
+        d,
+        index,
+        ptr,
+        stride,
+        rows,
+        cols,
+        transpose_matrix,
+        mma_M,
+        mma_N,
+        mma_K,
+        operand_role,
+    )
+
+
+def cooperative_tensor_multiply_accumulate(
+    d: Var,
+    index_d: PrimExpr,
+    a: Var,
+    index_a: PrimExpr,
+    b: Var,
+    index_b: PrimExpr,
+    c: Var,
+    index_c: PrimExpr,
+    M: int,
+    N: int,
+    K: int,
+    transpose_a: bool = False,
+    transpose_b: bool = False,
+):
+    return call_intrin(
+        "handle",
+        "tirx.cooperative_tensor_multiply_accumulate",
+        d,
+        index_d,
+        a,
+        index_a,
+        b,
+        index_b,
+        c,
+        index_c,
+        M,
+        N,
+        K,
+        transpose_a,
+        transpose_b,
+    )
+
+
 def vectorlow(dtype, vec):
     """Get the low level half of the vector
 
diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py
index 76f0397a8e..95f1fbea80 100644
--- a/python/tvm/tirx/script/builder/ir.py
+++ b/python/tvm/tirx/script/builder/ir.py
@@ -1965,6 +1965,10 @@ make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
 simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
 simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
 simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
+cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
+cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
+cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
+cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
 create_barriers = _op_wrapper(_tir_op.create_barriers)
 assume = _op_wrapper(_tir_op.assume)
 undef = _op_wrapper(_tir_op.undef)
@@ -2255,6 +2259,10 @@ __all__ = float_types + [
     "simdgroup_load",
     "simdgroup_store",
     "simdgroup_multiply_accumulate",
+    "cooperative_tensor_fill",
+    "cooperative_tensor_load",
+    "cooperative_tensor_store",
+    "cooperative_tensor_multiply_accumulate",
     "create_barriers",
     "mma_store",
     "mma_fill",
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 313e4cfe48..0155aa1ffd 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -71,6 +71,8 @@ enum class StorageRank {
   kMMAMatrixC = 11,
   /*! \brief Metal SIMD group memory */
   kMetalSimdGroup = 12,
+  /*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
+  kMetalCooperativeTensor = 13,
 };
 
 /*!
@@ -129,6 +131,8 @@ struct StorageScope {
         return "m16n8k8.matrixC" + tag;
       case StorageRank::kMetalSimdGroup:
         return "metal.simdgroup" + tag;
+      case StorageRank::kMetalCooperativeTensor:
+        return "metal.cooperative_tensor" + tag;
       default:
         TVM_FFI_THROW(InternalError) << "unknown storage scope";
         return "";
@@ -182,6 +186,9 @@ struct StorageScope {
     } else if (s.compare(0, 15, "metal.simdgroup") == 0) {
       r.rank = StorageRank::kMetalSimdGroup;
       r.tag = s.substr(15, std::string::npos);
+    } else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
+      r.rank = StorageRank::kMetalCooperativeTensor;
+      r.tag = s.substr(24, std::string::npos);
     } else {
       TVM_FFI_THROW(InternalError) << "unknown storage scope " << s;
     }
diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc
index 4355583d79..7ac487144f 100644
--- a/src/tirx/op/builtin.cc
+++ b/src/tirx/op/builtin.cc
@@ -345,6 +345,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
 TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
     .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",

Reply via email to