This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0274d8e1f1 [TIR] Support tensorization using ldmatrix + MMA (#11355)
0274d8e1f1 is described below
commit 0274d8e1f124cecc159abf3234251bf010784581
Author: Masahiro Masuda <[email protected]>
AuthorDate: Sat May 21 03:33:54 2022 +0900
[TIR] Support tensorization using ldmatrix + MMA (#11355)
* [TIR] Support tensorization using ldmatrix + MMA
commit 3218facf100b0dfc55715acfd1cee156764129ba
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 14:04:56 2022 +0900
some clean up
commit 7a235b69dc2023b3098ed44d591edb63b20a8f4e
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 13:55:11 2022 +0900
parameterize over storage scope in mma store intrin
commit 827ea4c434c35607b241f8e0ae2efe3214ac2458
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 13:37:38 2022 +0900
properly handle floordiv/mod in codegen
commit 42d4c6f42182c9fd79566c0955f99cc82abd5144
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 09:53:57 2022 +0900
update tuned factors for fp16
commit 328d0aa36b2ea9ea1b051970d612bff82d2d20e6
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 08:43:30 2022 +0900
all tests working
commit 5e086cf5fd1404ac38f85c4bfbe692687b45a16c
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 07:48:43 2022 +0900
add doc for mma_fill and mma_store intrin
commit 4f945c4116b6d3bdc965ecb2be2229bb46dc11ab
Author: Masahiro Masuda <[email protected]>
Date: Wed May 18 06:39:01 2022 +0900
remove tests
commit df7708f7f67761d9c18f9564bc15abd50c12ac69
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 19:52:14 2022 +0900
unified test
commit 754c83eeb8510b31fb9652b089177f9b8e642ec0
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 19:36:24 2022 +0900
clean up LowerWarpmemory
commit 178c3dcee7bfa17d5d93fec02aa858dc62151670
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 19:15:04 2022 +0900
Use IndexMap
commit 07fb58910338c62847fd902b37801d09b8c673b0
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 17:51:44 2022 +0900
remove 16x8x8 test
commit 2b05b5a5470ac221d559f31a31a8e2ff753b2414
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 17:31:35 2022 +0900
generate mma fill/store
commit bf23fc50f0ffa99e875d9247ca66acec0c36677f
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 12:23:30 2022 +0900
mma intrin generation with meta programming
commit 5afb5f00afd642cb1e39872edc7965f476dcdcb7
Author: Masahiro Masuda <[email protected]>
Date: Tue May 17 05:26:14 2022 +0900
ldmatrix intrin generation with meta programming
commit fb62abb3424b88ec48c697e306e05889a3ac306f
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 20:30:49 2022 +0900
minor
commit 5a80adce24e84d3ec6bf931b60cb9c730d243394
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 19:55:57 2022 +0900
revert some change
commit e599a55078ee75f2480a721098341812db58cf6f
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 19:54:18 2022 +0900
remove obsolete files
commit 4b13b85ff91d0d592a7e0c01924e0b49b82f35a8
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 19:51:21 2022 +0900
wip
commit 848de63455539e25cd0d43e5a65fd048636ef0f7
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 19:44:29 2022 +0900
wip
commit b35bff97ed10c22559e2164eb7538db0f711ce7e
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 19:31:18 2022 +0900
update parse error msg
commit ad9b053ef865b1f91f03d7b15ed7aae3420ee213
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 19:26:51 2022 +0900
fix for avoiding Buffer.vload(...) case
commit 54c686443e370edbfae860d0809b1b6182d26414
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 18:59:55 2022 +0900
wip
commit 078060fe28d22f1db5f07b1c382dee438f02df60
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 18:57:34 2022 +0900
wip
commit 576f8415e65e0e8a8a7808885e219b3b53867950
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 18:52:15 2022 +0900
wip
commit 12a376ae2f44aa6660121e64e0358f2866624f7f
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 17:54:58 2022 +0900
Squashed commit of the following:
commit 48eef4981d1a55aaf3b0ac935f2a10347cb1ac2d
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 17:40:48 2022 +0900
more comment
commit 8f67fc87038834e9f7e2c5cd3dfe61fabf442206
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 17:11:27 2022 +0900
update test
commit ad85036621c005b733763e67ceffae39c356ec99
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 16:54:01 2022 +0900
add test
commit 4a5dc3ffd5d0bb4a1700e57897c9e0f26e3d2a88
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 16:40:47 2022 +0900
[TVMScript] Support function call to help construct AST
commit 76c1bcf0ade45d7433a0066236add8372b1cc547
Author: Masahiro Masuda <[email protected]>
Date: Mon May 16 16:30:07 2022 +0900
simplify iterator in layout transform
commit 936280324ea2c91429a6a85a1b8ee89c7b825928
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 11:31:39 2022 +0900
remove obsolete files
commit 2e119b422d72d726d5f2bd20fe48a1e62fcb0510
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 10:43:59 2022 +0900
calculate mma store dst index using inverse affine map
commit 9489434ee52b546e2abb2ab28173eefd51525ba4
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 10:01:12 2022 +0900
simplify store
commit 1adcb77b8bba8e5d91080fe6cbfc7add7f4365c2
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 09:43:40 2022 +0900
simplified fill
commit 7b13c736d23e0eac94137aa918101d788e60d4f3
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 09:22:17 2022 +0900
simplify intrin desc using index map function
commit bcf212dda0f94c51f55c48921f61d92fd3b83777
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 07:16:42 2022 +0900
seems to work
commit dd8ccf9ec2e48100158152e5d4590d141424e2e2
Author: Masahiro Masuda <[email protected]>
Date: Sat May 14 07:11:57 2022 +0900
poking with the parser
commit 596582cbfbd08ebe23ea71aaf7a447472415ccd1
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 20:04:59 2022 +0900
16x8x32 4k trans working
commit 273f89a8a6ac34f7c79147563922d34d44bffd08
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 19:52:13 2022 +0900
add 16x8x16 fp16 trans
commit 8e2066cc4c6e86616bc9751324e63ba81a3b02af
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 19:32:37 2022 +0900
16x8x16 4k trans working
commit c2d0744051733e94f840d4517bcee9ca5d444c75
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 19:25:52 2022 +0900
16x8x16 trans working
commit c2e314cdda1c3a931781e51a863901ea178dffec
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 16:19:32 2022 +0900
tuned int8 4k, 91 TOPS
commit 94d9d965f19ff1a2ebdd342079ef420fb537b16a
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 15:59:33 2022 +0900
int8 4k tune working
commit 3ca8ca02593aff7540c9655aa831348246171752
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 08:43:57 2022 +0900
mma 16x8x32 int8 working with ldmatrix b workaround
commit 54f1cb731d4b42a6cbc08baf144e74646400eef5
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 18:23:27 2022 +0900
wip
commit 9d2844db602dc65af4dbd06a73fdd815f486b8b9
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 16:38:53 2022 +0900
test tensorize without layout transform
commit 86ee6dabc801aeb8d6917bec6de97b42025dbdd1
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 15:15:34 2022 +0900
int8 4k tensorize works
commit 39f9e32c9a64222c91daba2c32969b27207a31d2
Author: Masahiro Masuda <[email protected]>
Date: Fri May 13 12:44:39 2022 +0900
begin int8 4k tune
commit 6fa91e55b5ab2ba0f901d0d35be1b2fb3ab092b0
Author: Masahiro Masuda <[email protected]>
Date: Thu May 12 18:53:20 2022 +0900
try fix ldmatrix b for int8
commit 7a962cddc4799fa3df0c0fdf3c056146d3f2cbdf
Author: Masahiro Masuda <[email protected]>
Date: Thu May 12 18:28:34 2022 +0900
fixed warp_coeff
commit a0afb5698f307382147a38819e004a2db7f554b1
Author: Masahiro Masuda <[email protected]>
Date: Thu May 12 12:20:01 2022 +0900
wip
commit f70ccd09b07d5325454ffdc39a7619ea84aa7e06
Author: Masahiro Masuda <[email protected]>
Date: Thu May 12 12:09:57 2022 +0900
int8 tensorize working
commit 20321fa4674dabc78fe55b5e0e2876c35b245d21
Author: Masahiro Masuda <[email protected]>
Date: Thu May 12 07:06:22 2022 +0900
starting 16x8x32 int8
commit 441fd193c59cdc436d87ab35896cbb8c779ddf35
Author: Masahiro Masuda <[email protected]>
Date: Thu May 12 05:50:46 2022 +0900
adding fp16 accum case
commit c9d40b69b1b57bfaddffba09ea07624ae90ee465
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 17:04:29 2022 +0900
clean up
commit 5b2d48635e762c77c824d1c259ac8bcbcc949421
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 16:38:19 2022 +0900
16x8x16 4k tune working
commit c3cb170d85600d03da5c3f4cda03552208ca0b8c
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 16:20:27 2022 +0900
tensorize fixed
commit 68039b081efcdd6aea1d132940b3745f50164974
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 15:55:25 2022 +0900
begin 16x8x16 4k tune
commit ced5d8d980cc267d4735957c25cb60d71ae977d2
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 15:50:11 2022 +0900
16x8x16 worked
commit 3d2c90d77c1bb2df2193e9af6cbaa2bd927a26d8
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 15:47:26 2022 +0900
fix
commit 403050b03ad6b4f0ee8d45088ffb324727bbae48
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 15:45:10 2022 +0900
add 16x8x16 test
commit 18e8d73661c99cd1c83021063b41a457afcb1638
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 06:50:32 2022 +0900
fixed mma store codegen for 16x8x16
commit ec81250561195705122bccb9a2372f71de68121f
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 04:25:25 2022 +0900
add 16x8x16 mma store codegen
commit e08df2a62a4809bcd39782949283c16e7703aa5c
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 03:47:47 2022 +0900
tensorized C_warp init
commit ae0678918929c1ceec73f2039467040c5bb7823b
Author: Masahiro Masuda <[email protected]>
Date: Wed May 11 03:06:06 2022 +0900
mma store codegen working
commit deb4d6646cc93d4cdb4f2560ce723bee4d86e144
Author: Masahiro Masuda <[email protected]>
Date: Tue May 10 19:22:57 2022 +0900
update lower warp memory
commit 71fe5fe465300705fa94f9544a2e1a5070de6e0d
Author: Masahiro Masuda <[email protected]>
Date: Tue May 10 09:01:42 2022 +0900
tensorizing mma store
commit e80a1f148c47f2a3fac2363a733d8d4e2a2631d0
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 28 19:54:08 2022 +0900
clean up
commit a9640f4b7c3c9f22b87ca74a61003438dfd8f992
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 28 19:40:55 2022 +0900
add tunable 4k test, 36 TFLOPS
commit b9f7eae7041d1a9b3e434c331c874e8347e89dc4
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 28 18:01:08 2022 +0900
fixed bug in LowerWarpMemory index splitting for ldmatrix
commit 00df30823f874910ed1ec1f74718100311764234
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 27 07:58:17 2022 +0900
fixed missing reverse_compute_at
commit 93f9fe7e5f7ad16c8d0e6240c16c0281a0e97dec
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 27 06:55:12 2022 +0900
add 4k test
commit 3689ef712aa4b282a4818fa2fa2e7e349c3a5eec
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 27 06:54:09 2022 +0900
temp disable high dim base indices check in tensorize
commit 0c859c4f385ba0b6f9477b569b80cee80b5b7282
Author: Masahiro Masuda <[email protected]>
Date: Tue Apr 26 19:18:23 2022 +0900
clean up
commit f6aadbfcfbd73c1667a6de7aedc5894232b8e750
Author: Masahiro Masuda <[email protected]>
Date: Tue Apr 26 19:13:09 2022 +0900
Add 16x8x8 MMA + LDMatrix test
commit 4cf6b20c6ca415e967ab58d80e4a77c701ad7255
Author: Masahiro Masuda <[email protected]>
Date: Tue Apr 26 18:04:17 2022 +0900
testing 16x8x8 ldmatrix tensorization
* set measure_perf to False
* add requires_gpu decorator in tests, always test build on non-ampere
* skip cuda compile on old gpu
---
include/tvm/tir/builtin.h | 27 ++
python/tvm/tir/tensor_intrin/__init__.py | 1 +
python/tvm/tir/tensor_intrin/cuda.py | 469 +++++++++++++++++++++
src/target/source/codegen_cuda.cc | 76 +++-
src/tir/op/builtin.cc | 6 +
src/tir/transforms/lower_warp_memory.cc | 45 +-
.../test_tir_schedule_tensorize_ldmatrix_mma.py | 422 ++++++++++++++++++
7 files changed, 1042 insertions(+), 4 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index f33432645c..5fc42392c3 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -651,6 +651,33 @@ TVM_DLL const Op& ptx_cp_async();
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();
+/*!
+ * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
+ * For example, if each thread in a warp of size 32 has 4 elements from the result of
+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a
+ * 16x8 region in shared or global memory.
+ *
+ * There is no real PTX instruction that does that, but we want to hide details of
+ * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
+ * LowerWarpMemory).
+ *
+ * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride);
+ */
+TVM_DLL const Op& mma_store();
+
+/*!
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
+ * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
+ * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
+ * 4 accumulation registers.
+ *
+ * There is no real PTX instruction that does that, but we introduce this intrinsic for the
+ * same reason as mma_store above.
+ *
+ * void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
+ */
+TVM_DLL const Op& mma_fill();
+
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 4115c3b900..a3b47ff6d5 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -20,3 +20,4 @@ from .x86 import *
from .arm_cpu import *
from .dot_product_common import *
from .rocm import *
+from .cuda import *
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
new file mode 100644
index 0000000000..853a377354
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -0,0 +1,469 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Intrinsics for tensorization on NVIDIA GPU."""
+from tvm.script import tir as T
+from .. import IntImm, Cast
+from ..._ffi import register_func
+from ...runtime import convert
+from .. import TensorIntrin
+
+
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+ thread_id = 4 * (i % 8) + (j % 8) // 2
+ return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
+def shared_16x32_to_ldmatrix_32x16_layout(i, j):
+ thread_id = 4 * (i % 8) + (j % 16) // 4
+ return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
+
+
+def shared_32x16_to_ldmatrix_32x16_layout(i, j):
+ thread_id = (i % 16) // 4 + 4 * (j % 8)
+ return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+
+
+@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
+def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
+ i, j = ind[0], ind[1]
+ thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+ return convert([thread_id, local_id])
+
+
+lift = convert
+
+M_DIM = 16
+N_DIM = 16
+WARP_SIZE = 32
+HALF_WARP = WARP_SIZE // 2
+HALF_WARP_expr = lift(HALF_WARP)
+
+
+def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
+ local_size = (M_DIM * k_dim) // WARP_SIZE
+ shared_offset = None
+ index_map = None
+
+ if transposed:
+ assert is_b, "Transposed A matrix not supported"
+
+ ldmatrix_col_major = is_b and not transposed
+
+ if k_dim == 16:
+ assert dtype == "float16"
+
+ index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+ if transposed:
+ shared_offset = (
+ lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+ + stride * (tx % 8)
+ + 8 * ((tx % HALF_WARP_expr) // 8)
+ )
+ else:
+ shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * (
+ tx // HALF_WARP_expr
+ )
+ else:
+ assert (
+ k_dim == 32 and dtype == "int8"
+ ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
+
+ if ldmatrix_col_major:
+ index_map = shared_32x16_to_ldmatrix_32x16_layout
+ # A dummy offset: ldmatrix cannot be used for the int8 + trans case.
+ # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen.
+ # Only the stride information is required.
+ shared_offset = lambda _, stride: stride
+ elif is_b and transposed:
+ index_map = shared_16x32_to_ldmatrix_32x16_layout
+ shared_offset = (
+ lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+ + (tx % 8) * stride
+ + 16 * ((tx % HALF_WARP_expr) // 8)
+ )
+ else:
+ index_map = shared_16x32_to_ldmatrix_32x16_layout
+ shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16)
+
+ assert index_map and shared_offset
+
+ if is_b and not transposed:
+ row_dim = k_dim
+ col_dim = M_DIM
+ else:
+ row_dim = M_DIM
+ col_dim = k_dim
+
+ shmem_shape = (row_dim, col_dim)
+
+ @T.prim_func
+ def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
+ shared = T.match_buffer(
+ shared_handle, shmem_shape, dtype, align=128, offset_factor=16, scope="shared"
+ )
+ warp = T.match_buffer(
+ warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
+ )
+
+ with T.block("root"):
+ T.reads(shared[0:row_dim, 0:col_dim])
+ T.writes(warp[0:WARP_SIZE, 0:local_size])
+
+ for ax0, ax1 in T.grid(row_dim, col_dim):
+ with T.block("shared_warp"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(shared[v0, v1])
+
+ thread_id, local_id = index_map(v0, v1)
+ T.writes(warp[thread_id, local_id])
+ warp[thread_id, local_id] = shared[v0, v1]
+
+ @T.prim_func
+ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
+ s0 = T.var("int32")
+ s1 = T.var("int32")
+ shared = T.match_buffer(
+ shared_handle,
+ shmem_shape,
+ dtype,
+ align=128,
+ offset_factor=16,
+ scope="shared",
+ strides=[s0, s1],
+ )
+ warp = T.match_buffer(
+ warp_handle, (WARP_SIZE, local_size), dtype, align=128, offset_factor=16, scope="warp"
+ )
+
+ with T.block("root"):
+ T.reads(shared[0:row_dim, 0:col_dim])
+ T.writes(warp[0:WARP_SIZE, 0:local_size])
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, WARP_SIZE)
+
+ T.evaluate(
+ T.ptx_ldmatrix(
+ ldmatrix_col_major,
+ 4, # Always load 4 matrices
+ ".b16",
+ warp.data,
+ warp.elem_offset + lift(local_size) * tx,
+ shared.access_ptr("r"),
+ shared_offset(tx, s0),
+ dtype=dtype,
+ )
+ )
+
+ return ldmatrix_desc, ldmatrix_impl
+
+
+def get_mma_intrin(k_dim, out_dtype, b_transposed):
+ local_size = (M_DIM * k_dim) // WARP_SIZE
+ local_size_out = (M_DIM * N_DIM) // 32
+
+ index_map_C = shared_16x16_to_ldmatrix_32x8_layout
+
+ if k_dim == 16:
+ index_map_A = shared_16x16_to_ldmatrix_32x8_layout
+ index_map_B = shared_16x16_to_ldmatrix_32x8_layout
+ mma_prefix = "m16n8k16"
+ elif k_dim == 32 and b_transposed:
+ index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout
+ mma_prefix = "m16n8k32"
+ elif k_dim == 32 and not b_transposed:
+ index_map_A = shared_16x32_to_ldmatrix_32x16_layout
+ index_map_B = shared_32x16_to_ldmatrix_32x16_layout
+ mma_prefix = "m16n8k32"
+ else:
+ assert False
+
+ out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]
+
+ if out_dtype in ["float16", "float32"]:
+ in_dtype = "float16"
+ in_dtype_abbrv = "fp16"
+ else:
+ in_dtype = "int8"
+ in_dtype_abbrv = "int8"
+
+ def maybe_cast(v):
+ if out_dtype in ["float32", "int32"]:
+ return Cast(out_dtype, v)
+ return v
+
+ def maybe_swap(i, j):
+ if b_transposed:
+ return j, i
+ return i, j
+
+ @T.prim_func
+ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(
+ a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+ )
+ B = T.match_buffer(
+ b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+ )
+ C = T.match_buffer(
+ c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+ )
+
+ with T.block("root"):
+ T.reads(
+ C[0:WARP_SIZE, 0:local_size_out],
+ A[0:WARP_SIZE, 0:local_size],
+ B[0:WARP_SIZE, 0:local_size],
+ )
+ T.writes(C[0:WARP_SIZE, 0:local_size_out])
+
+ for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
+ with T.block("C"):
+ i, j, k = T.axis.remap("SSR", [i, j, k])
+ b_row_ind, b_col_ind = maybe_swap(k, j)
+
+ thread_id_C, local_id_C = index_map_C(i, j)
+ thread_id_A, local_id_A = index_map_A(i, k)
+ thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+
+ T.reads(
+ C[thread_id_C, local_id_C],
+ A[thread_id_A, local_id_A],
+ B[thread_id_B, local_id_B],
+ )
+ T.writes(C[thread_id_C, local_id_C])
+
+ C[thread_id_C, local_id_C] += maybe_cast(
+ A[thread_id_A, local_id_A]
+ ) * maybe_cast(B[thread_id_B, local_id_B])
+
+ @T.prim_func
+ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(
+ a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+ )
+ B = T.match_buffer(
+ b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
+ )
+ C = T.match_buffer(
+ c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
+ )
+
+ with T.block("root"):
+ T.reads(
+ C[0:WARP_SIZE, 0:local_size_out],
+ A[0:WARP_SIZE, 0:local_size],
+ B[0:WARP_SIZE, 0:local_size],
+ )
+ T.writes(C[0:WARP_SIZE, 0:local_size_out])
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, WARP_SIZE)
+
+ T.evaluate(
+ T.ptx_mma(
+ mma_prefix,
+ "row",
+ "col",
+ in_dtype_abbrv,
+ in_dtype_abbrv,
+ out_dtype_abbrv,
+ A.data,
+ A.elem_offset + tx * lift(local_size),
+ B.data,
+ B.elem_offset + tx * lift(local_size),
+ C.data,
+ C.elem_offset + tx * lift(local_size_out),
+ False,
+ dtype=out_dtype,
+ )
+ )
+
+ T.evaluate(
+ T.ptx_mma(
+ mma_prefix,
+ "row",
+ "col",
+ in_dtype_abbrv,
+ in_dtype_abbrv,
+ out_dtype_abbrv,
+ A.data,
+ A.elem_offset + tx * lift(local_size),
+ B.data,
+ B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
+ C.data,
+ C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
+ False,
+ dtype=out_dtype,
+ )
+ )
+
+ return mma_sync_desc, mma_sync_impl
+
+
+def get_mma_fill_intrin(dtype, local_size):
+ zero = IntImm("int32", 0).astype(dtype)
+
+ # Assume M = N = 16
+ index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+ @T.prim_func
+ def mma_fill_desc(a: T.handle) -> None:
+ C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+
+ with T.block("root"):
+ T.reads()
+ T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+ for i0, i1 in T.grid(M_DIM, N_DIM):
+ with T.block("C_warp"):
+ i, j = T.axis.remap("SS", [i0, i1])
+ thread_id, local_id = index_map(i, j)
+ T.reads()
+ T.writes(C_warp[thread_id, local_id])
+ C_warp[thread_id, local_id] = zero
+
+ @T.prim_func
+ def mma_fill_impl(a: T.handle) -> None:
+ C_warp = T.match_buffer(
+ a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+ )
+
+ with T.block("root"):
+ T.reads()
+ T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, WARP_SIZE)
+
+ T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))
+
+ return mma_fill_desc, mma_fill_impl
+
+
+def get_mma_store_intrin(dtype, local_size, scope="global"):
+ # Assume M = N = 16
+ index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+ @T.prim_func
+ def mma_store_desc(a: T.handle, c: T.handle) -> None:
+ C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+ C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope)
+
+ with T.block("root"):
+ T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+ T.writes(C[0:M_DIM, 0:N_DIM])
+ for i0, i1 in T.grid(M_DIM, N_DIM):
+ with T.block("C_warp"):
+ v0, v1 = T.axis.remap("SS", [i0, i1])
+ thread_id, local_id = index_map(v0, v1)
+ T.reads(C_warp[thread_id, local_id])
+ T.writes(C[v0, v1])
+ C[v0, v1] = C_warp[thread_id, local_id]
+
+ @T.prim_func
+ def mma_store_impl(a: T.handle, c: T.handle) -> None:
+ s0 = T.var("int32")
+ s1 = T.var("int32")
+
+ C_warp = T.match_buffer(
+ a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+ )
+ C = T.match_buffer(
+ c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1]
+ )
+
+ with T.block("root"):
+ T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+ T.writes(C[0:M_DIM, 0:N_DIM])
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(tx, WARP_SIZE)
+
+ T.evaluate(
+ T.mma_store(
+ M_DIM,
+ N_DIM,
+ C.access_ptr("w"),
+ C_warp.data,
+ C_warp.elem_offset,
+ s0,
+ dtype=dtype,
+ )
+ )
+
+ return mma_store_desc, mma_store_impl
+
+
+LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
+TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))
+
+LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
+TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))
+
+LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
+TensorIntrin.register(
+ LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True)
+)
+
+LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a"
+TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False))
+
+LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b"
+TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False))
+
+LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
+TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))
+
+MMA_f16f16f32_INTRIN = "mma_f16f16f32"
+TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False))
+
+MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
+TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True))
+
+MMA_f16f16f16_INTRIN = "mma_f16f16f16"
+TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False))
+
+MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
+TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True))
+
+MMA_i8i8i32_INTRIN = "mma_i8i8i32"
+TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))
+
+MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
+TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))
+
+MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
+TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8))
+
+MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
+TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8))
+
+MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
+TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
+
+MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
+TensorIntrin.register(
+ MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global")
+)
+
+MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
+TensorIntrin.register(
+ MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global")
+)
+
+MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
+TensorIntrin.register(
+ MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global")
+)
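For the non-transposed float16 A operand, `get_ldmatrix_intrin` in the `cuda.py` diff pairs the per-thread row address `stride * (tx % 16) + 8 * (tx // 16)` with `shared_16x16_to_ldmatrix_32x8_layout`. A minimal sketch that checks the two agree, assuming the usual `ldmatrix.x4` behaviour without `.trans`: threads 8m..8m+7 supply the row addresses of 8x8 matrix m, and register m of thread t receives elements (t // 4, 2 * (t % 4)) and (t // 4, 2 * (t % 4) + 1) of that matrix, stored as local ids 2m and 2m + 1. The simulator `simulate_ldmatrix_x4` is hypothetical and models each 16-bit element as a Python tuple of its (row, col).

```python
WARP_SIZE = 32
HALF_WARP = WARP_SIZE // 2


def layout_16x16_to_32x8(i, j):
    # Copied from shared_16x16_to_ldmatrix_32x8_layout in cuda.py.
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)


def shared_offset_a(tx, stride):
    # Row address supplied by thread tx for the non-transposed fp16 A tile.
    return stride * (tx % HALF_WARP) + 8 * (tx // HALF_WARP)


def simulate_ldmatrix_x4(smem, stride):
    """Hypothetical model of ldmatrix .x4 .b16 without .trans."""
    regs = [[None] * 8 for _ in range(WARP_SIZE)]
    for m in range(4):  # four 8x8 matrices, one 32-bit register each
        row_addr = [shared_offset_a(8 * m + r, stride) for r in range(8)]
        for t in range(WARP_SIZE):
            for e in range(2):  # each 32-bit register packs two 16-bit values
                regs[t][2 * m + e] = smem[row_addr[t // 4] + 2 * (t % 4) + e]
    return regs


stride = 40  # padded row pitch of the shared-memory tile
smem = [None] * (16 * stride)
for i in range(16):
    for j in range(16):
        smem[i * stride + j] = (i, j)

regs = simulate_ldmatrix_x4(smem, stride)
mismatches = []
for i in range(16):
    for j in range(16):
        tid, lid = layout_16x16_to_32x8(i, j)
        if regs[tid][lid] != (i, j):
            mismatches.append((i, j))
```

Under these assumptions the four matrices are rows 0-7 and 8-15 of columns 0-7, then the same rows of columns 8-15, which is exactly the order the index map encodes in `local_id`.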
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 7459d4c250..616e75f2e7 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -25,6 +25,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
+#include <tvm/tir/index_map.h>
#include <tvm/tir/stmt_functor.h>
#include <cmath>
@@ -818,9 +819,78 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string local_ptr = this->PrintExpr(op->args[3]);
std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]);
- std::string smem_elem_offset = this->PrintExpr(op->args[6]);
- this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
- smem_ptr, smem_elem_offset);
+ if (trans && op->dtype.bits() == 8) {
+ // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an
+ // int8 matrix.
+ std::string smem_stride = this->PrintExpr(op->args[6]);
+ ICHECK(num == 4);
+ os << "for (int i = 0; i < 16; ++i) {\n";
+ os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
+ << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * " + smem_stride +
+ "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 + (i / 8) * 8];\n";
+ os << "}\n";
+ } else {
+ std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+ this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
+ smem_ptr, smem_elem_offset);
+ }
+ } else if (op->op.same_as(builtin::mma_store())) {
+ int m = Downcast<Integer>(op->args[0])->value;
+ int n = Downcast<Integer>(op->args[1])->value;
+ std::string dst = this->PrintExpr(op->args[2]);
+ std::string src = this->PrintExpr(op->args[3]);
+ std::string src_offset = this->PrintExpr(op->args[4]);
+ PrimExpr stride = op->args[5];
+
+ ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now";
+
+ // Each thread in a warp holds a certain number of elements of an MMA output.
+ // For example, if we compute a 16x16 tile using MMA, each thread holds 8 elements
+ // in its registers. So conceptually, the warp memory is organized as a 32x8 block.
+ // A map from a 16x16 tile to a 32x8 block of memory is specified by the index map below.
+
+ // To store the 32x8 output back to a 16x16 tile in shared or global memory, we invert this map
+ // to determine the output location for each of the 8 elements.
+ // to determine the output location for each 8 element.
+
+ const auto* index_map_func =
+ runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
+ ICHECK(index_map_func);
+
+ auto inverse_index_map =
+ IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)});
+ auto indices_16x16 = inverse_index_map->final_indices;
+
+ // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine
+ // here because every index involved is non-negative.
+ // FloorDiv/Mod are supposed to be lowered before they reach codegen, so manually replace them
+ // with the plain ones here.
+ class LowerFloorDivMod : public ExprMutator {
+ public:
+ PrimExpr VisitExpr_(const FloorDivNode* op) {
+ return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b));
+ }
+ PrimExpr VisitExpr_(const FloorModNode* op) {
+ return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b));
+ }
+ };
+
+ auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + indices_16x16[1]);
+
+ var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
+ var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
+
+ os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
+ os << dst << "[" + this->PrintExpr(dst_ind) + "]"
+ << " = " << src << "[" << src_offset << " + local_id];\n";
+ os << "}\n";
+
+ } else if (op->op.same_as(builtin::mma_fill())) {
+ std::string num_elem = this->PrintExpr(op->args[0]);
+ std::string dst = this->PrintExpr(op->args[1]);
+ std::string dst_offset = this->PrintExpr(op->args[2]);
+
+ os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
+ os << dst << "[" << dst_offset << " + i] = 0.0;";
+ os << "}\n";
} else if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
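The `mma_store` branch of the CUDA codegen above inverts the 16x16 to 32x8 index map and emits one loop per thread that writes each of its 8 registers back to `dst[i * stride + j]`. The pure-Python sketch below illustrates that round trip. It assumes the registered map follows the PTX `mma.m16n8k16` fragment layout (thread = `4 * (i % 8) + (j % 8) // 2`, register = `4 * (j // 8) + 2 * (i // 8) + j % 2`); the helper names are hypothetical, and the brute-force inverse only stands in for the symbolic `IndexMap::Inverse` call.

```python
# Hypothetical stand-in for the "tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"
# packed function, assuming the PTX m16n8k16 fragment layout.
def shared_16x16_to_32x8(i, j):
    """Map element (i, j) of a 16x16 tile to (thread_id, local_id)."""
    thread_id = 4 * (i % 8) + (j % 8) // 2
    local_id = 4 * (j // 8) + 2 * (i // 8) + (j % 2)
    return thread_id, local_id


def invert_layout():
    """Brute-force inverse of the map: (thread_id, local_id) -> (i, j)."""
    inverse = {}
    for i in range(16):
        for j in range(16):
            inverse[shared_16x16_to_32x8(i, j)] = (i, j)
    return inverse


def mma_store(warp_regs, dst, stride):
    """Mimic the emitted loop: each thread writes its 8 registers to
    dst[i * stride + j], where (i, j) comes from the inverted map."""
    inverse = invert_layout()
    for thread_id in range(32):  # plays the role of threadIdx.x
        for local_id in range(8):
            i, j = inverse[(thread_id, local_id)]
            dst[i * stride + j] = warp_regs[thread_id][local_id]


# Scatter a 16x16 tile into a 32x8 register block with the forward map,
# then gather it back into a row-major buffer with the inverse map.
tile = [[16 * i + j for j in range(16)] for i in range(16)]
warp_regs = [[None] * 8 for _ in range(32)]
for i in range(16):
    for j in range(16):
        t, l = shared_16x16_to_32x8(i, j)
        warp_regs[t][l] = tile[i][j]

out = [None] * (16 * 16)
mma_store(warp_regs, out, stride=16)
print(out == [v for row in tile for v in row])  # True
```

Because the map is a bijection between 256 tile elements and 32 x 8 register slots, every destination element is written exactly once.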
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0415d1bbec..1871a3d7bf 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -256,6 +256,12 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(mma_store).set_attr<TCallEffectKind>("TCallEffectKind",
+ Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(mma_fill).set_attr<TCallEffectKind>("TCallEffectKind",
+ Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
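The `LowerFloorDivMod` mutator in the CUDA codegen hunk above swaps TIR `FloorDiv`/`FloorMod` for plain `Div`/`Mod`, which CUDA evaluates with truncating `/` and `%`. A quick Python check of why that is safe for these indices: Python's `//` and `%` are floor division and modulo, and `c_div`/`c_mod` are hypothetical helpers that emulate C's truncating semantics. The two conventions agree whenever the dividend is non-negative and the divisor positive, as with expressions built from `threadIdx.x` and `local_id`.

```python
def c_div(a, b):
    """Integer division as C/CUDA '/' does it: truncate toward zero."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q


def c_mod(a, b):
    """Remainder as C/CUDA '%' does it: the sign follows the dividend."""
    return a - b * c_div(a, b)


# Python's // and % behave like TIR FloorDiv/FloorMod. For a non-negative
# dividend and a positive divisor both conventions give the same result.
for a in range(64):
    for b in (2, 4, 8, 16):
        assert a // b == c_div(a, b) and a % b == c_mod(a, b)

# They diverge once the dividend is negative.
print(-7 // 2, c_div(-7, 2))  # -4 -3
print(-7 % 2, c_mod(-7, 2))   # 1 -1
```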
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc
index 40971114d4..d8250cd098 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -101,7 +101,7 @@ namespace tir {
// Visitor to find m in pattern
// store warp_mem[m * warp_index + (width * m) * y + x]
-class WarpStoreCoeffFinder : private StmtVisitor {
+class WarpStoreCoeffFinder : private StmtExprVisitor {
public:
WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer)
: buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {}
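`WarpStoreCoeffFinder` looks for the coefficient m in the access pattern `warp_mem[m * warp_index + (width * m) * y + x]`. Once m is known, each flat warp-memory index splits into a per-thread local index `m * y + x` and a lane `warp_index`. The numeric sketch below illustrates that decomposition, assuming `0 <= x < m` and `0 <= warp_index < width`. `split_index_by_group` is a hypothetical helper; the pass itself works symbolically on `PrimExpr`s.

```python
def split_index_by_group(flat, m, width=32):
    """Split flat = m * warp_index + (width * m) * y + x, with 0 <= x < m and
    0 <= warp_index < width, into (local_index, warp_index).

    Hypothetical numeric helper: local_index = m * y + x is the position in
    the per-thread local buffer that replaces the warp buffer."""
    x = flat % m
    warp_index = (flat // m) % width
    y = flat // (width * m)
    return m * y + x, warp_index


# A warp buffer holding n = 2 fragments of m = 8 elements per thread.
m, width, n = 8, 32, 2
for y in range(n):
    for lane in range(width):
        for x in range(m):
            flat = m * lane + (width * m) * y + x
            assert split_index_by_group(flat, m, width) == (m * y + x, lane)

print(split_index_by_group(8 * 5 + 256 * 1 + 3, m, width))  # (11, 5)
```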
@@ -113,6 +113,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
private:
/// Visitor implementation
+ void VisitExpr_(const CallNode* op) final {
+ if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
+ UpdatePattern(op->args[4]);
+ } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
+ auto* local_size = op->args[0].as<IntImmNode>();
+ ICHECK(local_size) << "Integer expected for the first argument of mma_fill";
+ warp_coeff_ = local_size->value;
+ }
+
+ StmtExprVisitor::VisitExpr_(op);
+ }
+
void VisitStmt_(const StoreNode* op) final {
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}
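For `mma_fill`, the finder takes m directly from the first argument, the per-thread element count, instead of inferring it from an index expression. Assuming that count is the fragment size divided by the warp size (8 for a 16x16 tile, matching the codegen comment and the 32x8 layout name), the sketch below shows that arithmetic and how m shrinks a warp allocation `warp_mem[n * width * m]` into `local_mem[n * m]`. Both helpers are hypothetical illustrations, not TVM functions.

```python
WARP_SIZE = 32


def elements_per_thread(rows, cols, warp_size=WARP_SIZE):
    """Per-thread element count m for a rows x cols fragment spread over a warp."""
    total = rows * cols
    if total % warp_size:
        raise ValueError("fragment does not divide evenly across the warp")
    return total // warp_size


def local_alloc_size(warp_alloc_size, m, width=WARP_SIZE):
    """warp_mem[n * width * m] becomes local_mem[n * m] after lowering."""
    if warp_alloc_size % (width * m):
        raise ValueError("warp allocation is not a multiple of width * m")
    n = warp_alloc_size // (width * m)
    return n * m


print(elements_per_thread(16, 16))  # 8, matching the 32x8 layout
print(elements_per_thread(16, 32))  # 16, matching the 32x16 int8 layouts
print(local_alloc_size(4 * 16 * 16, elements_per_thread(16, 16)))  # 32
```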
@@ -245,6 +257,37 @@ class WarpAccessRewriter : protected StmtExprMutator {
}
protected:
+ PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector<int>& indices) {
+ Array<PrimExpr> new_args = op->args;
+ for (int i : indices) {
+ if (op->args[i].get() == buffer_) {
+ PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first;
+ new_args.Set(i + 1, local_index);
+ }
+ }
+ return Call(op->dtype, op->op, new_args);
+ }
+
+ PrimExpr VisitExpr_(const CallNode* op) override {
+ if (op->op.same_as(builtin::ptx_mma())) {
+ return RewriteIndicesAt(op, {6, 8, 10});
+ }
+
+ if (op->op.same_as(builtin::ptx_ldmatrix())) {
+ return RewriteIndicesAt(op, {3});
+ }
+
+ if (op->op.same_as(builtin::mma_store())) {
+ return RewriteIndicesAt(op, {3});
+ }
+
+ if (op->op.same_as(builtin::mma_fill())) {
+ return RewriteIndicesAt(op, {1});
+ }
+
+ return StmtExprMutator::VisitExpr_(op);
+ }
+
PrimExpr VisitExpr_(const VarNode* op) override {
ICHECK(op != buffer_) << "Cannot access address of warp memory directly";
return StmtExprMutator::VisitExpr_(op);
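`RewriteIndicesAt` relies on a calling convention shared by these builtins: a buffer variable at position `i` is followed by its element offset at `i + 1` (for `ptx_ldmatrix`, the local pointer at 3 and its offset at 4, consistent with the codegen reading the shared offset from `args[6]`). The Python sketch below mirrors that loop on a plain list. `rewrite_indices_at` and `to_local` are hypothetical, the argument values are illustrative only, and `to_local` keeps the `m * y + x` part of the documented warp access pattern with m = 8 and width = 32.

```python
def rewrite_indices_at(args, positions, warp_buffer, to_local):
    """Copy of args where, for each position i that holds the warp buffer,
    the offset stored at i + 1 is replaced by its per-thread local index."""
    new_args = list(args)
    for i in positions:
        if args[i] is warp_buffer:  # identity check, like comparing VarNode pointers
            new_args[i + 1] = to_local(args[i + 1])
    return new_args


M_COEFF, WIDTH = 8, 32  # elements per thread, threads per warp


def to_local(offset):
    """Keep the y and x parts of m * warp_index + (width * m) * y + x."""
    return M_COEFF * (offset // (WIDTH * M_COEFF)) + offset % M_COEFF


warp_buf, shared_buf = object(), object()  # stand-ins for buffer Vars
# ldmatrix-style argument list: local pointer at 3, its offset at 4,
# shared pointer at 5, shared offset at 6 (illustrative values only).
args = [False, 4, ".b16", warp_buf, 256 * 3 + 8 * 5, shared_buf, 128]
print(rewrite_indices_at(args, [3], warp_buf, to_local)[4])  # 24
```

Only offsets paired with the warp buffer change; the shared-memory offset at position 6 is left alone because position 5 does not hold the warp buffer.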
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
new file mode 100644
index 0000000000..67e8ae0ad8
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -0,0 +1,422 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import te
+from tvm.tir.tensor_intrin.cuda import (
+ LDMATRIX_16x16_A_INTRIN,
+ LDMATRIX_16x16_B_INTRIN,
+ LDMATRIX_16x16_B_TRANS_INTRIN,
+ LDMATRIX_16x32_A_INTRIN,
+ LDMATRIX_32x16_B_INTRIN,
+ LDMATRIX_16x32_B_TRANS_INTRIN,
+ MMA_f16f16f32_INTRIN,
+ MMA_f16f16f32_TRANS_INTRIN,
+ MMA_f16f16f16_INTRIN,
+ MMA_f16f16f16_TRANS_INTRIN,
+ MMA_i8i8i32_INTRIN,
+ MMA_i8i8i32_TRANS_INTRIN,
+ MMA_fill_16x16_f32_INTRIN,
+ MMA_fill_16x16_f16_INTRIN,
+ MMA_fill_16x16_i32_INTRIN,
+ MMA_store_16x16_f32_global_INTRIN,
+ MMA_store_16x16_f16_global_INTRIN,
+ MMA_store_16x16_i32_global_INTRIN,
+ shared_16x16_to_ldmatrix_32x8_layout,
+ shared_32x16_to_ldmatrix_32x16_layout,
+ shared_16x32_to_ldmatrix_32x16_layout,
+)
+import tvm.testing
+import numpy as np
+
+
+M = 4096
+N = 4096
+K = 4096
+measure_perf = False
+gflops = (N * M * K) * 2 / 1e9
+
+
+def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
+ b_shape = (n, k) if b_transposed else (k, n)
+ a = te.placeholder((m, k), name="A", dtype=in_dtype)
+ b = te.placeholder(b_shape, name="B", dtype=in_dtype)
+ k = te.reduce_axis((0, k), name="k")
+
+ def maybe_cast(v):
+ if in_dtype != out_dtype:
+ return tvm.tir.Cast(out_dtype, v)
+ return v
+
+ def maybe_swap(i, j):
+ if b_transposed:
+ return j, i
+ return i, j
+
+ c = te.compute(
+ (m, n),
+ lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]),
+ name="C",
+ )
+ return (a, b, c)
+
+
+def is_ampere_or_newer():
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+ return major >= 8
+
+
+def run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ b_transposed,
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ ldmatrix_a_intrin,
+ ldmatrix_b_intrin,
+ mma_intrin,
+ mma_fill_intrin,
+ mma_store_intrin,
+):
+ workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed))
+ ir_module = tvm.IRModule({"main": workload})
+ sch = tvm.tir.Schedule(ir_module)
+
+ block = sch.get_block("C")
+ i, j, k = sch.get_loops(block)
+ i, i_tc = sch.split(i, factors=[None, 16])
+ j, j_tc = sch.split(j, factors=[None, 16])
+ k, k_tc = sch.split(k, factors=[None, k_inner])
+
+ sch.reorder(i, j, k, i_tc, j_tc, k_tc)
+
+ block_inner = sch.blockize(i_tc)
+ block_outer, block_inner = block_inner, block
+
+ num_ty = i_factors[2] * j_factors[2]
+
+ i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
+ j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
+ k0, k1, k2 = sch.split(k, k_factors)
+
+ sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)
+
+ block_idx = sch.fuse(i0, j0)
+ block_idy = sch.fuse(i1, j1)
+ thread_idy = sch.fuse(j2, i2)
+ sch.bind(block_idx, "blockIdx.x")
+ sch.bind(block_idy, "blockIdx.y")
+ sch.bind(thread_idy, "threadIdx.y")
+
+ def fetch_to_shared(block, idx, ndim):
+ block_read = sch.cache_read(block, idx, "shared")
+ sch.compute_at(block_read, k0)
+ vector_size = 16 if in_dtype == "int8" else 8
+ warp_size = 32
+ fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+ _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size])
+ sch.bind(f_2, "threadIdx.x")
+ sch.bind(f_1, "threadIdx.y")
+ sch.vectorize(f_3)
+ offset = 8 if in_dtype == "float16" else 16
+ sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)
+
+ return block_read
+
+ fetch_to_shared(block_outer, 0, 2)
+ fetch_to_shared(block_outer, 1, 2)
+
+ A_warp = sch.cache_read(block_outer, 0, "warp")
+ B_warp = sch.cache_read(block_outer, 1, "warp")
+
+ sch.compute_at(A_warp, k1)
+ sch.compute_at(B_warp, k1)
+
+ C_warp = sch.cache_write(block_outer, 0, "warp")
+ sch.reverse_compute_at(C_warp, thread_idy)
+
+ ii, jj = sch.get_loops(C_warp)[-2:]
+ io, ii = sch.split(ii, factors=[None, 16])
+ jo, ji = sch.split(jj, factors=[None, 16])
+ sch.reorder(io, jo, ii, ji)
+
+ sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+ block_init_c = sch.get_block("C_init")
+
+ def tile_wmma_fragment(block_read, height, width):
+ i, j = sch.get_loops(block_read)[-2:]
+ i0, i1 = sch.split(i, factors=[None, height])
+ j0, j1 = sch.split(j, factors=[None, width])
+ sch.reorder(i0, j0, i1, j1)
+ return i1
+
+ loop_a = tile_wmma_fragment(A_warp, 16, k_inner)
+
+ if b_transposed:
+ loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
+ else:
+ loop_b = tile_wmma_fragment(B_warp, k_inner, 16)
+
+ sch.transform_layout(A_warp, 0, "write", index_map_A)
+ sch.transform_layout(B_warp, 0, "write", index_map_B)
+ sch.transform_layout(C_warp, 0, "read", index_map_C)
+
+ sch.tensorize(loop_a, ldmatrix_a_intrin)
+ sch.tensorize(loop_b, ldmatrix_b_intrin)
+ sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin)
+ sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
+ sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
+
+ if not is_ampere_or_newer():
+ return None
+
+ f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+ dev = tvm.device("cuda", 0)
+
+ if in_dtype == "float16":
+ a_np = np.random.uniform(size=(M, K)).astype("float16")
+
+ if b_transposed:
+ b_np = np.random.uniform(size=(N, K)).astype("float16")
+ c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+ out_dtype
+ )
+ else:
+ b_np = np.random.uniform(size=(K, N)).astype("float16")
+ c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
+ else:
+ a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+
+ if b_transposed:
+ b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+ c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
+ "int32"
+ )
+ else:
+ b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+ c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32")
+
+ a = tvm.nd.array(a_np, dev)
+ b = tvm.nd.array(b_np, dev)
+ c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)
+
+ f(a, b, c)
+
+ if out_dtype != "float16":
+ # The numpy reference is computed with fp32 precision (otherwise it is too slow),
+ # so there is a non-trivial accuracy difference if the TVM result is computed with fp16 accumulation.
+ tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+
+ return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
+
+
+@tvm.testing.requires_cuda
+def test_f16f16f32_m16n16k16():
+ def index_map(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+ )
+
+ k_inner = 16
+ in_dtype = "float16"
+ out_dtype = "float32"
+ i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1]
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ False, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map,
+ index_map,
+ index_map,
+ LDMATRIX_16x16_A_INTRIN,
+ LDMATRIX_16x16_B_INTRIN,
+ MMA_f16f16f32_INTRIN,
+ MMA_fill_16x16_f32_INTRIN,
+ MMA_store_16x16_f32_global_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ True, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map,
+ index_map,
+ index_map,
+ LDMATRIX_16x16_A_INTRIN,
+ LDMATRIX_16x16_B_TRANS_INTRIN,
+ MMA_f16f16f32_TRANS_INTRIN,
+ MMA_fill_16x16_f32_INTRIN,
+ MMA_store_16x16_f32_global_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
+
+
+@tvm.testing.requires_cuda
+def test_f16f16f16_m16n16k16():
+ def index_map(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+ )
+
+ k_inner = 16
+ in_dtype = "float16"
+ out_dtype = "float16"
+ i_factors, j_factors, k_factors = [16, 2, 1, 4, 2], [16, 2, 2, 1, 4], [128, 2, 1]
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ False, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map,
+ index_map,
+ index_map,
+ LDMATRIX_16x16_A_INTRIN,
+ LDMATRIX_16x16_B_INTRIN,
+ MMA_f16f16f16_INTRIN,
+ MMA_fill_16x16_f16_INTRIN,
+ MMA_store_16x16_f16_global_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ True, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map,
+ index_map,
+ index_map,
+ LDMATRIX_16x16_A_INTRIN,
+ LDMATRIX_16x16_B_TRANS_INTRIN,
+ MMA_f16f16f16_TRANS_INTRIN,
+ MMA_fill_16x16_f16_INTRIN,
+ MMA_store_16x16_f16_global_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean)))
+
+
+@tvm.testing.requires_cuda
+def test_i8i8i32_m16n16k32():
+ def index_map_A(i, j):
+ return (
+ i // 16,
+ j // 32,
+ *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
+ )
+
+ def index_map_B(i, j):
+ return (
+ i // 32,
+ j // 16,
+ *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16),
+ )
+
+ def index_map_C(i, j):
+ return (
+ i // 16,
+ j // 16,
+ *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+ )
+
+ k_inner = 32
+ in_dtype = "int8"
+ out_dtype = "int32"
+ i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2]
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ False, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_B,
+ index_map_C,
+ LDMATRIX_16x32_A_INTRIN,
+ LDMATRIX_32x16_B_INTRIN,
+ MMA_i8i8i32_INTRIN,
+ MMA_fill_16x16_i32_INTRIN,
+ MMA_store_16x16_i32_global_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))
+
+ timer = run_test(
+ k_inner,
+ in_dtype,
+ out_dtype,
+ True, # b_transposed
+ i_factors,
+ j_factors,
+ k_factors,
+ index_map_A,
+ index_map_A,
+ index_map_C,
+ LDMATRIX_16x32_A_INTRIN,
+ LDMATRIX_16x32_B_TRANS_INTRIN,
+ MMA_i8i8i32_TRANS_INTRIN,
+ MMA_fill_16x16_i32_INTRIN,
+ MMA_store_16x16_i32_global_INTRIN,
+ )
+
+ if measure_perf and timer:
+ print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
+
+
+if __name__ == "__main__":
+ test_f16f16f32_m16n16k16()
+ test_f16f16f16_m16n16k16()
+ test_i8i8i32_m16n16k32()
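The factor lists in these tests must multiply out to the problem size, and the fused loops fix the launch configuration. The checker below is a reading of the `sch.fuse`/`sch.bind` calls in `run_test`: `fuse(i0, j0)` goes to `blockIdx.x`, `fuse(i1, j1)` goes to `blockIdx.y`, `fuse(j2, i2)` goes to `threadIdx.y`, and `threadIdx.x` spans one 32-thread warp as in `fetch_to_shared`. `launch_shape` is a hypothetical helper, not part of the test file.

```python
from math import prod

M = N = K = 4096
WARP_SIZE = 32


def launch_shape(i_factors, j_factors, k_factors, k_inner, tile=16):
    """Check that the split factors cover the whole problem and return
    (grid_x, grid_y, threads_per_block) implied by the bindings in run_test."""
    if prod(i_factors) * tile != M or prod(j_factors) * tile != N:
        raise ValueError("i/j factors must multiply to M / 16 and N / 16")
    if prod(k_factors) * k_inner != K:
        raise ValueError("k factors must multiply to K / k_inner")
    grid_x = i_factors[0] * j_factors[0]  # fuse(i0, j0) -> blockIdx.x
    grid_y = i_factors[1] * j_factors[1]  # fuse(i1, j1) -> blockIdx.y
    num_ty = i_factors[2] * j_factors[2]  # fuse(j2, i2) -> threadIdx.y
    return grid_x, grid_y, num_ty * WARP_SIZE  # threadIdx.x spans one warp


configs = {
    "f16f16f32": ([4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1], 16),
    "f16f16f16": ([16, 2, 1, 4, 2], [16, 2, 2, 1, 4], [128, 2, 1], 16),
    "i8i8i32": ([1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2], 32),
}
for name, (i_f, j_f, k_f, k_inner) in configs.items():
    print(name, launch_shape(i_f, j_f, k_f, k_inner))
# f16f16f32 (4, 512, 128)
# f16f16f16 (256, 4, 64)
# i8i8i32 (8, 128, 128)
```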