This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c0abab769f [TIR][DLight] Enable SimdGroup op for Metal (#17112)
c0abab769f is described below
commit c0abab769ff152d87f84963f18a98d2f7c9bdf31
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Jun 24 21:24:32 2024 +0800
[TIR][DLight] Enable SimdGroup op for Metal (#17112)
---
include/tvm/tir/builtin.h | 44 ++-
python/tvm/dlight/gpu/matmul.py | 145 ++++++++++
python/tvm/script/ir_builder/tir/ir.py | 8 +
python/tvm/tir/__init__.py | 6 +
python/tvm/tir/op.py | 191 ++++++++++++-
python/tvm/tir/tensor_intrin/metal.py | 350 +++++++++++++++++++++++
src/runtime/thread_storage_scope.h | 7 +
src/target/source/codegen_metal.cc | 82 +++++-
src/target/source/codegen_metal.h | 3 +
src/tir/op/builtin.cc | 12 +
tests/python/dlight/test_gpu_matmul_tensorize.py | 283 +++++++++++++++++-
11 files changed, 1124 insertions(+), 7 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index 5836eb8ea9..120c1b71be 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -746,7 +746,7 @@ TVM_DLL const Op& create_barriers();
TVM_DLL const Op& mma_store();
/*!
- * \brief tvm intrinsic for zero-initalizing an MMA accumulation registor.
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
 * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
 * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
 * 4 accumulation registers.
@@ -758,6 +758,48 @@ TVM_DLL const Op& mma_store();
*/
TVM_DLL const Op& mma_fill();
+// Metal SimdGroup matrix intrinsics
+
+/*!
+ * \brief tvm intrinsic for initializing a simdgroup matrix with a given value.
+ * \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep
+ *       the shape as params to keep the interface similar to the Metal Spec.
+ *
+ * void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value,
+ *                                   int col = 8, int row = 8);
+ */
+TVM_DLL const Op& make_filled_simdgroup_matrix();
+
+/*!
+ * \brief tvm intrinsic for loading data from device memory or threadgroup memory to a simdgroup matrix.
+ * \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep
+ *       the shape as params to keep the interface similar to the Metal Spec.
+ *
+ * void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
+ *                     int col = 8, int row = 8, bool transpose_matrix = false);
+ */
+TVM_DLL const Op& simdgroup_load();
+
+/*!
+ * \brief tvm intrinsic for storing data from a simdgroup matrix to device memory or threadgroup memory.
+ * \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep
+ *       the shape as params to keep the interface similar to the Metal Spec.
+ *
+ * void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride,
+ *                      int col = 8, int row = 8, bool transpose_matrix = false);
+ */
+TVM_DLL const Op& simdgroup_store();
+
+/*!
+ * \brief tvm intrinsic for multiplying and accumulating two matrices in a simdgroup, i.e. d = a * b + c.
+ * \note only the 8x8 shape is supported by the Metal Spec and TVM, but we still keep
+ *       the shape as params to keep the interface similar to the Metal Spec.
+ *
+ * void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a,
+ *                    Var b, PrimExpr index_b, Var c, PrimExpr index_c);
+ */
+TVM_DLL const Op& simdgroup_multiply_accumulate();
+
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
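Taken together, the four builtins mirror Metal's simdgroup matrix API at the TIR level. A minimal TVMScript sketch of a single 8x8 fp16 fragment flowing through them (the buffer and fragment names here are illustrative assumptions, not part of this commit):

    # zero-fill an 8x8 accumulator fragment
    T.make_filled_simdgroup_matrix(C_frag.data, index=0, value=T.float32(0), col=8, row=8)
    # load A and B fragments from threadgroup (shared) memory
    T.simdgroup_load(A_frag.data, 0, A_shared.access_ptr("r"), stride=32, col=8, row=8)
    T.simdgroup_load(B_frag.data, 0, B_shared.access_ptr("r"), stride=32, col=8, row=8, transpose_matrix=True)
    # d = a * b + c, with every operand addressed as a (fragment array, index) pair
    T.simdgroup_multiply_accumulate(C_frag.data, 0, A_frag.data, 0, B_frag.data, 0, C_frag.data, 0)
    # store the accumulator back to shared memory
    T.simdgroup_store(C_frag.data, 0, C_shared.access_ptr("w"), stride=32, col=8, row=8)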
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index f4ef1f5044..a5759941ca 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -313,6 +313,146 @@ def check_sm_version(arch: str) -> int:
    return int(sm_version) if sm_version.isdigit() else -1
+class MetalMatmul(GPUScheduleRule):
+    """
+    The schedule rule for Metal matmul computation.
+    """
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.metal import (  # pylint: disable=import-outside-toplevel
+            get_simdgroup_intrin_group,
+        )
+
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+            return None
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
+
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Step 0. Configs
+        block_size_x: int = 16
+        block_size_y: int = 16
+        block_size_k: int = 32
+        micro_size: int = 8
+        warp_size: int = 32
+        ty_len: int = 1
+        tz_len: int = 4
+        vector_size: int = 4
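+        # with these configs each threadgroup computes a (ty_len * block_size_x) x
+        # (tz_len * block_size_y) = 16 x 64 tile of C, stepping over K in chunks of block_size_k = 32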
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                ty_len * block_size_x,
+                tz_len * block_size_y,
+                block_size_k,
+            ],
+        )
+
+        # Step 3. Schedule matmul to use simdgroup intrinsics
+        batch, i, j, k = sch.get_loops(main_block)
+        bx, ty, i0, i1 = sch.split(i, [None, ty_len, block_size_x // micro_size, micro_size])
+        by, tz, j0, j1 = sch.split(j, [None, tz_len, block_size_y // micro_size, micro_size])
+        k0, k1, k2 = sch.split(k, [None, block_size_k // micro_size, micro_size])
+        sch.reorder(bx, by, ty, tz, k0, k1, i0, j0, i1, j1, k2)
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tz, "threadIdx.z")
+
+        def fetch_to_shared(block, idx):
+            block_read = sch.cache_read(block, idx, "shared")
+            sch.compute_at(block_read, k0, preserve_unit_loops=True)
+            fused = sch.fuse(*sch.get_loops(block_read)[-2:])
+            _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+
+            sch.bind(_tz, "threadIdx.z")
+            sch.bind(_ty, "threadIdx.y")
+            sch.bind(_tx, "threadIdx.x")
+            sch.vectorize(vec)
+
+            return block_read
+
+        a_g2s = fetch_to_shared(main_block, 0)
+        b_g2s = fetch_to_shared(main_block, 1)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # create read cache to load matrices from shared memory into simdgroup fragments
+        A_simdgroup = sch.cache_read(main_block, 0, "metal.simdgroup")
+        B_simdgroup = sch.cache_read(main_block, 1, "metal.simdgroup")
+        sch.compute_at(A_simdgroup, k1)
+        sch.compute_at(B_simdgroup, k1)
+
+        C_simd2s = sch.cache_write(main_block, 0, "metal.simdgroup")
+        C_s2g = sch.cache_write(C_simd2s, 0, "shared")
+        sch.reverse_compute_at(C_simd2s, tz, preserve_unit_loops=True)
+        sch.reverse_compute_at(C_s2g, by, preserve_unit_loops=True)
+
+        intrin_group = get_simdgroup_intrin_group(
+            load_scope="shared",
+            store_scope="shared",
+            dtype="float16",
+            trans_a=False,
+            trans_b=True,
+        )
+        sch.transform_layout(B_simdgroup, ("write", 0), lambda s, i, j: (s, j, i))
+
+        def tensorize_block(block: tir.schedule.BlockRV, intrin: str):
+            *_, i, j = sch.get_loops(block)
+            io, ii = sch.split(i, [None, micro_size])
+            jo, ji = sch.split(j, [None, micro_size])
+            sch.reorder(io, jo, ii, ji)
+            sch.tensorize(ii, intrin)
+
+        C_init = sch.decompose_reduction(main_block, k0)
+        tensorize_block(A_simdgroup, intrin_group["load_a"])
+        tensorize_block(B_simdgroup, intrin_group["load_b"])
+        tensorize_block(C_simd2s, intrin_group["store"])
+        tensorize_block(C_init, intrin_group["init"])
+
+        *_, i, j, k = sch.get_loops(main_block)
+        sch.tensorize(i, intrin_group["compute"])
+
+        auto_inline_consumer_chain(sch, C_s2g)
+        fused = sch.fuse(*sch.get_loops(C_s2g)[-2:])
+        _, _tz, _ty, _tx, vec = sch.split(fused, [None, tz_len, ty_len, warp_size, vector_size])
+        sch.bind(_tz, "threadIdx.z")
+        sch.bind(_ty, "threadIdx.y")
+        sch.bind(_tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+        return sch
+
+
class MatmulTensorization(GPUScheduleRule):
"""
The schedule rule for float16 tensor core matmul computation.
@@ -848,6 +988,11 @@ class Matmul(GPUScheduleRule):
                tensorize_sch = MatmulTensorization().apply(func, target, _)
                if tensorize_sch is not None:
                    return tensorize_sch
+        elif target.kind.name == "metal":
+            try:
+                return MetalMatmul().apply(func, target, _)
+            except:  # pylint: disable=bare-except
+                pass
        # Step 2. Get schedule config.
        config = self.get_configs(target)
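For reference, the new rule is reached through dlight's default schedule application; a minimal sketch of driving it end to end, mirroring the test setup later in this commit (`mod` is an assumed IRModule containing a float16 matmul):

    from tvm import dlight as dl
    from tvm.target import Target

    with Target("metal"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)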
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 18abc0ca5d..caefc6a6bc 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1887,6 +1887,10 @@ ptx_init_barrier_thread_count = _op_wrapper(_tir_op.ptx_init_barrier_thread_coun
ptx_arrive_barrier = _op_wrapper(_tir_op.ptx_arrive_barrier)
ptx_arrive_barrier_expect_tx = _op_wrapper(_tir_op.ptx_arrive_barrier_expect_tx)
ptx_wait_barrier = _op_wrapper(_tir_op.ptx_wait_barrier)
+make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix)
+simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
+simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
+simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
create_barriers = _op_wrapper(_tir_op.create_barriers)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
@@ -2177,6 +2181,10 @@ __all__ = [
"ptx_arrive_barrier",
"ptx_arrive_barrier_expect_tx",
"ptx_wait_barrier",
+ "make_filled_simdgroup_matrix",
+ "simdgroup_load",
+ "simdgroup_store",
+ "simdgroup_multiply_accumulate",
"create_barriers",
"mma_store",
"mma_fill",
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 0fee976eb1..5360ab2b96 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -73,6 +73,12 @@ from .op import (
ptx_wait_barrier,
create_barriers,
)
+from .op import (
+ make_filled_simdgroup_matrix,
+ simdgroup_load,
+ simdgroup_multiply_accumulate,
+ simdgroup_store,
+)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index 95a85ab77d..81d6604259 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=redefined-builtin, invalid-name
+# pylint: disable=redefined-builtin, invalid-name, too-many-arguments
"""Operators used in TIR expression."""
from typing import Any, Optional, Union
@@ -1567,6 +1567,195 @@ def create_barriers(barrier_count):
return call_intrin("", "tir.create_barriers", barrier_count)
+def make_filled_simdgroup_matrix(
+ d: Var,
+ index: PrimExpr,
+ value: PrimExpr,
+ col: int = 8,
+ row: int = 8,
+):
+ """Create a filled SIMDGroup matrix
+
+ Parameters
+ ----------
+    d : Var
+ The simdgroup var
+
+ index : PrimExpr
+ The index of the matrix.
+
+ value : PrimExpr
+ The value to fill.
+
+ col : int
+ The number of columns.
+
+ row : int
+ The number of rows.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin("handle", "tir.make_filled_simdgroup_matrix", d, index,
value, col, row)
+
+
+def simdgroup_load(
+ d: Var,
+ index: PrimExpr,
+ ptr: PrimExpr,
+ stride: PrimExpr,
+ col: int = 8,
+ row: int = 8,
+ transpose_matrix: bool = False,
+):
+ """Load data from device memory or threadgroup memory to simdgroup
+
+ Parameters
+ ----------
+    d : Var
+ The simdgroup var
+
+ index : PrimExpr
+ The index of the matrix.
+
+ ptr : PrimExpr
+ The pointer.
+
+ stride : PrimExpr
+ The stride.
+
+ col : int
+ The number of columns.
+
+ row : int
+ The number of rows.
+
+ transpose_matrix : bool
+ Whether to transpose the matrix.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin(
+ "handle",
+ "tir.simdgroup_load",
+ d,
+ index,
+ ptr,
+ stride,
+ col,
+ row,
+ transpose_matrix,
+ )
+
+
+def simdgroup_store(
+ d: PrimExpr,
+ index: PrimExpr,
+ ptr: PrimExpr,
+ stride: PrimExpr,
+ col: int = 8,
+ row: int = 8,
+ transpose_matrix: bool = False,
+):
+ """Store data from simdgroup to device memory or threadgroup memory
+
+ Parameters
+ ----------
+ d : PrimExpr
+ The SIMDGroup.
+
+ index : PrimExpr
+ The index of the matrix.
+
+ ptr : PrimExpr
+ The pointer.
+
+ stride : PrimExpr
+ The stride.
+
+ col : int
+ The number of columns.
+
+ row : int
+ The number of rows.
+
+ transpose_matrix : bool
+ Whether to transpose the matrix.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin(
+ "handle", "tir.simdgroup_store", d, index, ptr, stride, col, row,
transpose_matrix
+ )
+
+
+def simdgroup_multiply_accumulate(
+ d: Var,
+ index_d: PrimExpr,
+ a: Var,
+ index_a: PrimExpr,
+ b: Var,
+ index_b: PrimExpr,
+ c: Var,
+ index_c: PrimExpr,
+):
+ """Multiply and accumulate two matrices in simdgroup
+ i.e. d = a * b + c
+
+ Parameters
+ ----------
+ d : Var
+ The destination matrix.
+
+ index_d : PrimExpr
+ The index of the destination matrix.
+
+ a : Var
+ The first matrix.
+
+ index_a : PrimExpr
+ The index of the first matrix.
+
+ b : Var
+ The second matrix.
+
+ index_b : PrimExpr
+ The index of the second matrix.
+
+ c : Var
+ The third matrix.
+
+ index_c : PrimExpr
+ The index of the third matrix.
+
+ Returns
+ -------
+ call : PrimExpr
+ The call expression.
+ """
+ return call_intrin(
+ "handle",
+ "tir.simdgroup_multiply_accumulate",
+ d,
+ index_d,
+ a,
+ index_a,
+ b,
+ index_b,
+ c,
+ index_c,
+ )
+
+
def vectorlow(dtype, vec):
"""Get the low level half of the vector
diff --git a/python/tvm/tir/tensor_intrin/metal.py b/python/tvm/tir/tensor_intrin/metal.py
new file mode 100644
index 0000000000..be34a9e266
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/metal.py
@@ -0,0 +1,350 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring,unused-variable
+"""Intrinsics for tensorization on Apple GPU."""
+from typing import Dict, Literal, Tuple
+
+from tvm.script import tir as T
+from tvm.tir import Buffer, PrimExpr, PrimFunc, TensorIntrin
+
+######## simdgroup matrix intrinsics ########
+
+
+def get_simdgroup_index(buffer: Buffer, stride: PrimExpr, col: int, row: int):
+ """Compute simdgroup index using elem_offset of the buffer"""
+
+    # NOTE: need to further check the usage of `col` and `row`.
+    # Currently, Metal only supports 8x8, which means the values of `col` and `row` are the same.
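+    # e.g. with stride = 32 and col = row = 8, elem_offset = 520 yields
+    # frag_index_m = 2, frag_index_n = 1, num_fragments_per_row = 4, so index = 2 * 4 + 1 = 9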
+ frag_index_m = buffer.elem_offset // stride // col
+ frag_index_n = buffer.elem_offset % stride // row
+
+ num_fragments_per_row = stride // row
+ return frag_index_m * num_fragments_per_row + frag_index_n
+
+
+def get_make_filled_simdgroup_matrix_intrin(
+ dtype: str, col: int = 8, row: int = 8
+) -> Tuple[PrimFunc, PrimFunc]:
+ @T.prim_func
+ def desc(a: T.handle) -> None:
+ A = T.match_buffer(a, (col, row), dtype, scope="metal.simdgroup",
offset_factor=1)
+ with T.block("root"):
+ T.reads()
+ T.writes(A[0:col, 0:row])
+ for i, j in T.grid(col, row):
+ with T.block("init"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ A[vi, vj] = T.float32(0)
+
+ @T.prim_func
+ def impl(a: T.handle) -> None:
+ d0, d1 = T.int32(), T.int32()
+ A = T.match_buffer(
+ a, (col, row), dtype, scope="metal.simdgroup", strides=[d1, d0],
offset_factor=1
+ )
+ with T.block("root"):
+ T.reads()
+ T.writes(A[0:col, 0:row])
+ T.make_filled_simdgroup_matrix(
+ A.data,
+ index=get_simdgroup_index(A, d1, col, row),
+ value=T.float32(0),
+ col=col,
+ row=row,
+ )
+
+ return desc, impl
+
+
+def get_simdgroup_load_intrin(
+ dtype: str,
+ scope: Literal["global", "shared"],
+ col: int = 8,
+ row: int = 8,
+ transpose_matrix: bool = False,
+) -> Tuple[PrimFunc, PrimFunc]:
+ align = col * row
+
+ @T.prim_func
+ def desc(a: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (col, row), dtype, align=align, scope=scope, offset_factor=1)
+        C = T.match_buffer(
+            c, (col, row), dtype, align=align, scope="metal.simdgroup", offset_factor=1
+        )
+ with T.block("root"):
+ T.reads(A[0:col, 0:row])
+ T.writes(C[0:col, 0:row])
+ for i, j in T.grid(col, row):
+ with T.block("load"):
+ vii, vjj = T.axis.remap("SS", [i, j])
+ if transpose_matrix:
+ # C[vii, vjj] = A[vjj, vii]
+ C[vjj, vii] = A[vii, vjj]
+ else:
+ C[vii, vjj] = A[vii, vjj]
+
+ @T.prim_func
+ def impl(a: T.handle, c: T.handle) -> None:
+ s0, s1, d0, d1 = T.int32(), T.int32(), T.int32(), T.int32()
+ A = T.match_buffer(
+ a,
+ (col, row),
+ dtype,
+ align=align,
+ scope=scope,
+ strides=[s1, s0],
+ offset_factor=1,
+ )
+ C = T.match_buffer(
+ c,
+ (col, row),
+ dtype,
+ align=align,
+ scope="metal.simdgroup",
+ strides=[d1, d0],
+ offset_factor=1,
+ )
+ with T.block("root"):
+ T.reads(A[0:col, 0:row])
+ T.writes(C[0:col, 0:row])
+ T.simdgroup_load(
+ C.data,
+ index=get_simdgroup_index(C, d1, col, row),
+ ptr=A.access_ptr("r"),
+ stride=s1,
+ col=col,
+ row=row,
+ transpose_matrix=transpose_matrix,
+ )
+
+ return desc, impl
+
+
+def get_simdgroup_store_intrin(
+ dtype: str,
+ scope: Literal["global", "shared"],
+ col: int = 8,
+ row: int = 8,
+ transpose_matrix: bool = False,
+) -> Tuple[PrimFunc, PrimFunc]:
+ align = col * row
+
+ @T.prim_func
+ def desc(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(
+ a, (col, row), dtype, align=align, scope="metal.simdgroup",
offset_factor=1
+ )
+ C = T.match_buffer(c, (col, row), dtype, align=align, scope=scope,
offset_factor=1)
+ with T.block("root"):
+ T.reads(A[0:col, 0:row])
+ T.writes(C[0:col, 0:row])
+ for i, j in T.grid(col, row):
+ with T.block("store"):
+ vii, vjj = T.axis.remap("SS", [i, j])
+ if transpose_matrix:
+ C[vjj, vii] = A[vii, vjj]
+ else:
+ C[vii, vjj] = A[vii, vjj]
+
+ @T.prim_func
+ def impl(a: T.handle, c: T.handle) -> None:
+ s0, s1, d0, d1 = T.int32(), T.int32(), T.int32(), T.int32()
+ A = T.match_buffer(
+ a,
+ (col, row),
+ dtype,
+ align=align,
+ scope="metal.simdgroup",
+ strides=[s1, s0],
+ offset_factor=1,
+ )
+ C = T.match_buffer(
+            c, (col, row), dtype, align=align, scope=scope, strides=[d1, d0], offset_factor=1
+ )
+ with T.block("root"):
+ T.reads(A[0:col, 0:row])
+ T.writes(C[0:col, 0:row])
+ T.simdgroup_store(
+ A.data,
+ index=get_simdgroup_index(A, s1, col, row),
+ ptr=C.access_ptr("w"),
+ stride=d1,
+ col=col,
+ row=row,
+ transpose_matrix=transpose_matrix,
+ )
+
+ return desc, impl
+
+
+def get_simdgroup_multiply_accumulate_intrin(
+ m_dim: int, n_dim: int, k_dim: int, dtype: str
+) -> Tuple[PrimFunc, PrimFunc]:
+ @T.prim_func
+ def desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (m_dim, k_dim), dtype, scope="metal.simdgroup",
offset_factor=1)
+ B = T.match_buffer(b, (k_dim, n_dim), dtype, scope="metal.simdgroup",
offset_factor=1)
+ C = T.match_buffer(c, (m_dim, n_dim), dtype, scope="metal.simdgroup",
offset_factor=1)
+ with T.block("root"):
+            T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim])
+ T.writes(C[0:m_dim, 0:n_dim])
+ for i, j, k in T.grid(m_dim, n_dim, k_dim):
+ with T.block(""):
+ vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
+ C[vii, vjj] += A[vii, vkk] * B[vkk, vjj]
+
+ @T.prim_func
+ def impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        a0, a1, b0, b1, c0, c1 = T.int32(), T.int32(), T.int32(), T.int32(), T.int32(), T.int32()
+        A = T.match_buffer(
+            a, (m_dim, k_dim), dtype, scope="metal.simdgroup", strides=[a1, a0], offset_factor=1
+        )
+        B = T.match_buffer(
+            b, (k_dim, n_dim), dtype, scope="metal.simdgroup", strides=[b1, b0], offset_factor=1
+        )
+        C = T.match_buffer(
+            c, (m_dim, n_dim), dtype, scope="metal.simdgroup", strides=[c1, c0], offset_factor=1
+        )
+ with T.block("root"):
+            T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim])
+ T.writes(C[0:m_dim, 0:n_dim])
+ T.simdgroup_multiply_accumulate(
+ C.data,
+ get_simdgroup_index(C, c1, m_dim, n_dim),
+ A.data,
+ get_simdgroup_index(A, a1, m_dim, k_dim),
+ B.data,
+ get_simdgroup_index(B, b1, k_dim, n_dim),
+ C.data,
+ get_simdgroup_index(C, c1, m_dim, n_dim),
+ )
+
+ return desc, impl
+
+
+# Make filled simdgroup matrix intrinsics
+
+SIMDGROUP_MAKE_FILLED_8x8x8_f16_INTRIN = "simdgroup_make_filled_8x8x8_f16"
+TensorIntrin.register(
+ SIMDGROUP_MAKE_FILLED_8x8x8_f16_INTRIN,
+ *get_make_filled_simdgroup_matrix_intrin("float16", 8, 8),
+)
+
+SIMDGROUP_FILLED_8x8x8_f32_INTRIN = "simdgroup_fill_8x8x8_f32"
+TensorIntrin.register(
+    SIMDGROUP_FILLED_8x8x8_f32_INTRIN, *get_make_filled_simdgroup_matrix_intrin("float32", 8, 8)
+)
+
+SIMDGROUP_FILLED_8x8x8_bf16_INTRIN = "simdgroup_fill_8x8x8_bf16"
+TensorIntrin.register(
+    SIMDGROUP_FILLED_8x8x8_bf16_INTRIN, *get_make_filled_simdgroup_matrix_intrin("bfloat16", 8, 8)
+)
+
+# Load intrinsics
+
+SIMDGROUP_LOAD_8x8x8_f16_SHARED_INTRIN = "simdgroup_load_8x8x8_f16_shared"
+TensorIntrin.register(
+ SIMDGROUP_LOAD_8x8x8_f16_SHARED_INTRIN,
+ *get_simdgroup_load_intrin("float16", "shared", 8, 8, False),
+)
+
+SIMDGROUP_LOAD_8x8x8_f16_SHARED_TRANS_INTRIN = "simdgroup_load_8x8x8_f16_shared_trans"
+TensorIntrin.register(
+ SIMDGROUP_LOAD_8x8x8_f16_SHARED_TRANS_INTRIN,
+ *get_simdgroup_load_intrin("float16", "shared", 8, 8, True),
+)
+
+# Store intrinsics
+
+SIMDGROUP_STORE_8x8x8_f16_GLOBAL_INTRIN = "simdgroup_store_8x8x8_f16_global"
+TensorIntrin.register(
+ SIMDGROUP_STORE_8x8x8_f16_GLOBAL_INTRIN,
+ *get_simdgroup_store_intrin("float16", "global", 8, 8, False),
+)
+
+SIMDGROUP_STORE_8x8x8_f16_SHARED_INTRIN = "simdgroup_store_8x8x8_f16_shared"
+TensorIntrin.register(
+ SIMDGROUP_STORE_8x8x8_f16_SHARED_INTRIN,
+ *get_simdgroup_store_intrin("float16", "shared", 8, 8, False),
+)
+# Multiply accumulate intrinsics
+
+SIMDGROUP_MULTI_ACC_8x8x8_f16_INTRIN = "simdgroup_multiply_accumulate_8x8x8_f16"
+TensorIntrin.register(
+ SIMDGROUP_MULTI_ACC_8x8x8_f16_INTRIN,
+ *get_simdgroup_multiply_accumulate_intrin(8, 8, 8, "float16"),
+)
+
+
+def get_simdgroup_intrin_group(
+ load_scope: Literal["shared"],
+ store_scope: Literal["global", "shared"],
+ dtype: str,
+ trans_a: bool = False,
+ trans_b: bool = False,
+) -> Dict[str, str]:
+ """Get a group of intrinsics for tensorization on Apple GPU.
+
+ Parameters
+ ----------
+ load_scope : Literal["shared"]
+ The memory scope of the input buffer.
+
+ store_scope : Literal["global", "shared"]
+ The memory scope of the result buffer.
+
+ dtype : str
+ The data type of the input and output buffers.
+
+ trans_a : bool
+ Whether the input matrix A is transposed.
+
+ trans_b : bool
+ Whether the input matrix B is transposed.
+
+ Returns
+ -------
+ ret : Dict[str, str]
+ A group of tensor intrinsics.
+ """
+ assert load_scope in ["shared"]
+ assert store_scope in ["global", "shared"]
+ assert dtype in ["float16", "bfloat16", "float32"]
+
+ shape = "8x8x8"
+ dtype = "f16" if dtype == "float16" else "bf16" if dtype == "bfloat16"
else "f32"
+ trans_a = "_trans" if trans_a else ""
+ trans_b = "_trans" if trans_b else ""
+
+ # e.g. simdgroup_load_8x8x8_f16_shared
+ load_a_intrin = f"simdgroup_load_{shape}_{dtype}_{load_scope}{trans_a}"
+ # e.g. simdgroup_load_8x8x8_f16_shared_trans
+ load_b_intrin = f"simdgroup_load_{shape}_{dtype}_{load_scope}{trans_b}"
+ # e.g. simdgroup_multiply_accumulate_8x8x8_f16
+ compute_intrin = f"simdgroup_multiply_accumulate_{shape}_{dtype}"
+ # e.g. simdgroup_make_filled_8x8x8_f16
+ init_intrin = f"simdgroup_make_filled_{shape}_{dtype}"
+ # e.g. simdgroup_store_8x8x8_f16_global
+ store_intrin = f"simdgroup_store_{shape}_{dtype}_{store_scope}"
+
+ return {
+ "init": init_intrin,
+ "load_a": load_a_intrin,
+ "load_b": load_b_intrin,
+ "compute": compute_intrin,
+ "store": store_intrin,
+ }
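A minimal usage sketch of the helper above (the intrinsic names follow from the registrations in this file; this mirrors how MetalMatmul in dlight consumes the group):

    from tvm.tir.tensor_intrin.metal import get_simdgroup_intrin_group

    intrin_group = get_simdgroup_intrin_group(
        load_scope="shared",
        store_scope="shared",
        dtype="float16",
        trans_a=False,
        trans_b=True,
    )
    # intrin_group["load_b"]  == "simdgroup_load_8x8x8_f16_shared_trans"
    # intrin_group["compute"] == "simdgroup_multiply_accumulate_8x8x8_f16"
    # each entry names a registered TensorIntrin, e.g. sch.tensorize(loop_rv, intrin_group["load_a"])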
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 747b905812..d1af2cb701 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -70,6 +70,8 @@ enum class StorageRank {
kMMAMatrixB = 10,
/*! \brief mma scope memory of accumulator */
kMMAMatrixC = 11,
+ /*! \brief Metal SIMD group memory */
+ kMetalSimdGroup = 12,
};
/*!
@@ -126,6 +128,8 @@ struct StorageScope {
return "m16n8k8.matrixB" + tag;
case StorageRank::kMMAMatrixC:
return "m16n8k8.matrixC" + tag;
+ case StorageRank::kMetalSimdGroup:
+ return "metal.simdgroup" + tag;
default:
LOG(FATAL) << "unknown storage scope";
}
@@ -175,6 +179,9 @@ struct StorageScope {
} else if (s.compare(0, 15, "m16n8k8.matrixC") == 0) {
r.rank = StorageRank::kMMAMatrixC;
r.tag = s.substr(15, std::string::npos);
+ } else if (s.compare(0, 15, "metal.simdgroup") == 0) {
+ r.rank = StorageRank::kMetalSimdGroup;
+ r.tag = s.substr(15, std::string::npos);
} else {
LOG(FATAL) << "unknown storage scope " << s;
}
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index e729af417c..2908514988 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -25,10 +25,10 @@
#include <tvm/tir/transform.h>
#include <algorithm>
+#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
-#include <vector>
#include "../../runtime/metal/metal_module.h"
#include "../../runtime/thread_storage_scope.h"
@@ -262,6 +262,9 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
os << lanes;
return;
}
+ } else if (t.is_bfloat16()) {
+ os << "bfloat";
+ return;
}
LOG(FATAL) << "Cannot convert type " << t << " to Metal type";
}
@@ -296,9 +299,43 @@ void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os)
os << "device ";
} else if (scope == "shared") {
os << "threadgroup ";
- } else {
+ } else if (scope == "local") {
os << "thread ";
+ } else {
+ LOG(FATAL) << "Unknown storage scope `" << scope << "`";
+ }
+}
+
+void CodeGenMetal::VisitStmt_(const AllocateNode* op) {
+  ICHECK(!is_zero(op->condition));
+  std::string vid = AllocVarID(op->buffer_var.get());
+
+  this->PrintIndent();
+  size_t constant_size = op->ConstantAllocationSize();
+  ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
+
+  auto scope = GetPtrStorageScope(op->buffer_var);
+  alloc_storage_scope_[op->buffer_var.get()] = scope;
+  if (scope == "metal.simdgroup") {
+    ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
+           op->dtype == DataType::BFloat(16))
+        << "Only float16, float32, and bfloat16 are supported, but got " << op->dtype;
+    ICHECK(constant_size % 64 == 0)
+        << "Only 8x8 matrix is supported, but got " << constant_size << " bytes\n";
+
+    std::ostringstream dtype_os;
+    PrintType(op->dtype, dtype_os);
+    std::string dtype_str = dtype_os.str();
+    simdgroup_dtype_[op->buffer_var.get()] = dtype_str;
+    stream << "simdgroup_" << dtype_str << "8x8 " << vid << '[' << constant_size / 64 << "];\n";
+  } else {
+    PrintStorageScope(scope, stream);
+    PrintType(op->dtype, stream);
+    stream << ' ' << vid << '[' << constant_size << "];\n";
+  }
+
+  RegisterHandleType(op->buffer_var.get(), op->dtype);
+  this->PrintStmt(op->body);
+}
void CodeGenMetal::VisitExpr_(const SelectNode* op, std::ostream& os) {  // NOLINT(*)
@@ -322,7 +359,46 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT
CHECK(!op->op.as<GlobalVarNode>())
<< "CodegenMetal does not support inter-function calls, "
<< "but expression " << GetRef<Call>(op) << " calls PrimFunc " << op->op;
- if (op->op.same_as(builtin::reinterpret())) {
+  auto f_check_simdgroup_shape = [](PrimExpr col, PrimExpr row) {
+    ICHECK(col->IsInstance<IntImmNode>() && row->IsInstance<IntImmNode>())
+        << "Only constant shape is supported for simdgroup matrix, but got " << col << "x" << row;
+    int col_val = col.as<IntImmNode>()->value;
+    int row_val = row.as<IntImmNode>()->value;
+    ICHECK(col_val == 8 && row_val == 8)
+        << "Only 8x8 matrix is supported, but got " << col_val << "x" << row_val;
+  };
+  if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) {
+    ICHECK_EQ(op->args.size(), 5);
+    Var var = runtime::Downcast<Var>(op->args[0]);
+    // Get the data type of the simdgroup matrix
+    auto it = simdgroup_dtype_.find(var.get());
+    ICHECK(it != simdgroup_dtype_.end())
+        << "Cannot find variable allocation for simdgroup: " << var;
+    const std::string& dtype_str = it->second;
+    f_check_simdgroup_shape(op->args[3], op->args[4]);
+    os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) << "] = make_filled_simdgroup_matrix<"
+       << dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">("
+       << PrintExpr(op->args[2]) << ")";
+  } else if (op->op.same_as(builtin::simdgroup_load())) {
+    ICHECK_EQ(op->args.size(), 7);
+    f_check_simdgroup_shape(op->args[4], op->args[5]);
+    os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], "
+       << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, "
+       << PrintExpr(op->args[6]) << ")";
+  } else if (op->op.same_as(builtin::simdgroup_store())) {
+    ICHECK_EQ(op->args.size(), 7);
+    f_check_simdgroup_shape(op->args[4], op->args[5]);
+    os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], "
+       << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, "
+       << PrintExpr(op->args[6]) << ")";
+  } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) {
+    ICHECK_EQ(op->args.size(), 8);
+    os << "simdgroup_multiply_accumulate("                                  //
+       << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], "  //
+       << PrintExpr(op->args[2]) << "[" << PrintExpr(op->args[3]) << "], "  //
+       << PrintExpr(op->args[4]) << "[" << PrintExpr(op->args[5]) << "], "  //
+       << PrintExpr(op->args[6]) << "[" << PrintExpr(op->args[7]) << "])";
+  } else if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
this->PrintType(op->dtype, os);
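For illustration, with the lowering above, an 8x8 half-precision allocation plus one load / multiply-accumulate / store round trip would be emitted roughly as the following Metal source (a hand-written sketch with assumed variable names, not actual compiler output):

    simdgroup_half8x8 A_frag[1];
    simdgroup_half8x8 B_frag[1];
    simdgroup_half8x8 C_frag[1];
    C_frag[0] = make_filled_simdgroup_matrix<half, 8, 8>(0.0f);
    simdgroup_load(A_frag[0], A_shared, 32, 0, false);
    simdgroup_load(B_frag[0], B_shared, 32, 0, true);
    simdgroup_multiply_accumulate(C_frag[0], A_frag[0], B_frag[0], C_frag[0]);
    simdgroup_store(C_frag[0], C_out, 64, 0, false);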
diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h
index 9cff3211ce..9bc0e15d15 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -27,6 +27,7 @@
#include <tvm/target/codegen.h>
#include <string>
+#include <unordered_map>
#include "codegen_c.h"
@@ -50,6 +51,7 @@ class CodeGenMetal final : public CodeGenC {
// print store of single element.
  void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
+  void VisitStmt_(const AllocateNode* op) final;                     // NOLINT(*)
   void VisitExpr_(const SelectNode* op, std::ostream& os) final;     // NOLINT(*)
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const CallNode* op, std::ostream& os) final;       // NOLINT(*)
@@ -59,6 +61,7 @@ class CodeGenMetal final : public CodeGenC {
using CodeGenC::PrintType;
private:
+ std::unordered_map<const VarNode*, std::string> simdgroup_dtype_;
int thread_index_bits_{32};
int thread_work_dim_{0};
Target target_;
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 67d01aa923..0404fd2823 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -328,6 +328,18 @@ TIR_DEFINE_BUILTIN_FUNC(mma_fill)
    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst));
+TIR_DEFINE_BUILTIN_FUNC(make_filled_simdgroup_matrix)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(simdgroup_load)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py
index 095447766e..59ccfec55c 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=missing-docstring
+# pylint: disable=missing-docstring, unused-variable, invalid-name
+# flake8: noqa: E501
import pytest
import tvm.testing
from tvm import dlight as dl
-from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
@@ -698,5 +698,284 @@ class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter):
# fmt: on
+class MetalBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("metal"):
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+        return transform
+
+
+class TestMatmulMetal(MetalBeforeAfter):
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(
+        var_A: T.handle,
+        B: T.Buffer((28672, 4096), "float16"),
+        var_C: T.handle,
+    ):
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096):
+            with T.block("C"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.writes(C[v_i0, v_i1, v_i2])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float16(0)
+                C[v_i0, v_i1, v_i2] += A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+
+    @T.prim_func
+    def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.handle):
+        T.func_attr({"tir.is_scheduled": 1})
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        # with T.block("root"):
+        A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared")
+        A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup")
+        B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0 in T.thread_binding((batch_size + 15) // 16, thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(448, thread="blockIdx.y"):
+                    for ax1_1 in T.thread_binding(1, thread="threadIdx.y"):
+                        for ax2_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1):
+                                with T.block("C_init_o"):
+                                    v0_o = T.axis.spatial(1, ax0)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0)
+                                    T.reads()
+                                    T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                    T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8)
+                            for ax3_0 in range(128):
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in T.vectorized(4):
+                                                    with T.block("A_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, ax0_1)
+                                                        v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(A[v1, 0, v2])
+                                                        T.writes(A_reindex_pad_shared[v0, v1, v2])
+                                                        A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0))
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in T.vectorized(4):
+                                                    with T.block("B_reindex_shared"):
+                                                        v0 = T.axis.spatial(1, ax0_1)
+                                                        v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(B[v1, v2])
+                                                        T.writes(B_reindex_shared[v0, v1, v2])
+                                                        B_reindex_shared[v0, v1, v2] = B[v1, v2]
+                                for ax3_1 in range(4):
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with T.block("A_reindex_pad_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1)
+                                            T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1)
+                                            C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False))
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with T.block("B_reindex_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1)
+                                            T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8])
+                                            A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1)
+                                            C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True))
+                                    for ax1_2, ax2_2 in T.grid(2, 2):
+                                        with T.block("C_update_o"):
+                                            v0_o = T.axis.spatial(1, ax0)
+                                            v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2)
+                                            v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2)
+                                            v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1)
+                                            T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            B_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) [...]
+                            for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2):
+                                with T.block("C_reindex_pad_metal.simdgroup_o"):
+                                    v0_o = T.axis.spatial(1, ax0_1)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1)
+                                    T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                    C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1)
+                                    T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False))
+                    for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2):
+                        for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                    for ax1_ax2_fused_4 in T.vectorized(4):
+                                        with T.block("C_reindex_pad_shared"):
+                                            v0 = T.axis.spatial(1, ax0_1)
+                                            v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
+                                            v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+                                            T.reads(C_reindex_pad_shared[v0, v1, v2])
+                                            T.writes(C[v1, 0, v2])
+                                            if v1 < batch_size:
+                                                C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
+    # fmt: on
+
+
+class TestMatmulMetalInt4Quant(MetalBeforeAfter):
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(
+        B0: T.Buffer((28672, 512), "uint32"),
+        B1: T.Buffer((28672, 128), "float16"),
+        var_A: T.handle,
+        var_C: T.handle
+    ):
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        compute = T.alloc_buffer((28672, 4096), "float16")
+        B = T.alloc_buffer((28672, 4096), "float16")
+        for i0, i1 in T.grid(28672, 4096):
+            with T.block("compute"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(B0[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15)))
+        for i0, i1 in T.grid(28672, 4096):
+            with T.block("dequantize"):
+                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                B[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * B1[v_i0, v_i1 // 32]
+        for i0, i1, i2, k in T.grid(batch_size, 1, 28672, 4096):
+            with T.block("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float16(0)
+                C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+
+    @T.prim_func(private=True)
+    def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "float16"), var_A: T.handle, var_C: T.handle):
+        T.func_attr({"tir.is_scheduled": 1})
+        batch_size = T.int32()
+        A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16")
+        C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16")
+        # with T.block("root"):
+        A_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer((1, 28672, 4096), "float16", scope="shared")
+        A_reindex_pad_shared_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 4096), "float16", scope="metal.simdgroup")
+        B_reindex_shared_metal_simdgroup = T.alloc_buffer((1, 4096, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_metal_simdgroup = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="metal.simdgroup")
+        C_reindex_pad_shared = T.alloc_buffer((1, (batch_size + 15) // 16 * 16, 28672), "float16", scope="shared")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0 in T.thread_binding((batch_size + 15) // 16, thread="blockIdx.x"):
+                for ax2_0 in T.thread_binding(448, thread="blockIdx.y"):
+                    for ax1_1 in T.thread_binding(1, thread="threadIdx.y"):
+                        for ax2_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_2_init, ax2_2_init, ax1_3_init_0, ax2_3_init_0 in T.grid(2, 2, 1, 1):
+                                with T.block("NT_matmul_init_o"):
+                                    v0_o = T.axis.spatial(1, ax0)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2_init + ax1_3_init_0)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2_init + ax2_3_init_0)
+                                    T.reads()
+                                    T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                    T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8)
+                            for ax3_0 in range(128):
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in T.vectorized(4):
+                                                    with T.block("A_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, ax0_1)
+                                                        v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(A[v1, 0, v2])
+                                                        T.writes(A_reindex_pad_shared[v0, v1, v2])
+                                                        A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < batch_size, A[v1, 0, v2], T.float16(0))
+                                for ax0_1, ax1_ax2_fused_0 in T.grid(1, 4):
+                                    for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                                        for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                            for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                                for ax1_ax2_fused_4 in T.vectorized(4):
+                                                    with T.block("B_reindex_shared"):
+                                                        v0 = T.axis.spatial(1, ax0_1)
+                                                        v1 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 32)
+                                                        v2 = T.axis.spatial(4096, ax3_0 * 32 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 32)
+                                                        T.reads(B0[v1, v2 // 8], B1[v1, v2 // 32])
+                                                        T.writes(B_reindex_shared[v0, v1, v2])
+                                                        B_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(B0[v1, v2 // 8], T.Cast("uint32", v2 % 8 * 4)), T.uint32(15))) - T.float16(7)) * B1[v1, v2 // 32]
+                                for ax3_1 in range(4):
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with T.block("A_reindex_pad_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1)
+                                            T.reads(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1)
+                                            C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False))
+                                    for ax0_0, ax1_0_1 in T.grid(2, 1):
+                                        with T.block("B_reindex_shared_metal.simdgroup_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(512, ax3_0 * 4 + ax3_1 + ax1_0_1)
+                                            T.reads(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8])
+                                            A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1)
+                                            C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True))
+                                    for ax1_2, ax2_2 in T.grid(2, 2):
+                                        with T.block("NT_matmul_update_o"):
+                                            v0_o = T.axis.spatial(1, ax0)
+                                            v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_1 * 2 + ax1_2)
+                                            v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_2)
+                                            v3_o = T.axis.reduce(512, ax3_0 * 4 + ax3_1)
+                                            T.reads(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                            A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            B = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1)
+                                            T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_of [...]
+                            for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2):
+                                with T.block("C_reindex_pad_metal.simdgroup_o"):
+                                    v0_o = T.axis.spatial(1, ax0_1)
+                                    v1_o = T.axis.spatial(2 * ((batch_size + 15) // 16), ax1_0 * 2 + ax1_0_1)
+                                    v2_o = T.axis.spatial(3584, ax2_0 * 8 + ax2_1 * 2 + ax2_0_1)
+                                    T.reads(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8])
+                                    A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1)
+                                    C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1)
+                                    T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False))
+                    for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2):
+                        for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"):
+                            for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"):
+                                for ax1_ax2_fused_3 in T.thread_binding(32, thread="threadIdx.x"):
+                                    for ax1_ax2_fused_4 in T.vectorized(4):
+                                        with T.block("C_reindex_pad_shared"):
+                                            v0 = T.axis.spatial(1, ax0_1)
+                                            v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
+                                            v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
+                                            T.reads(C_reindex_pad_shared[v0, v1, v2])
+                                            T.writes(C[v1, 0, v2])
+                                            if v1 < batch_size:
+                                                C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
+
+
if __name__ == "__main__":
tvm.testing.main()