This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0274d8e1f1 [TIR] Support tensorization using ldmatrix + MMA (#11355)
0274d8e1f1 is described below

commit 0274d8e1f124cecc159abf3234251bf010784581
Author: Masahiro Masuda <[email protected]>
AuthorDate: Sat May 21 03:33:54 2022 +0900

    [TIR] Support tensorization using ldmatrix + MMA (#11355)
    
    * [TIR] Support tensorization using ldmatrix + MMA
    
    commit 3218facf100b0dfc55715acfd1cee156764129ba
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 14:04:56 2022 +0900
    
        some clean up
    
    commit 7a235b69dc2023b3098ed44d591edb63b20a8f4e
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 13:55:11 2022 +0900
    
        parameterize over storage scope in mma store intrin
    
    commit 827ea4c434c35607b241f8e0ae2efe3214ac2458
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 13:37:38 2022 +0900
    
        properly handle floordiv/mod in codegen
    
    commit 42d4c6f42182c9fd79566c0955f99cc82abd5144
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 09:53:57 2022 +0900
    
        update tuned factors for fp16
    
    commit 328d0aa36b2ea9ea1b051970d612bff82d2d20e6
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 08:43:30 2022 +0900
    
        all tests working
    
    commit 5e086cf5fd1404ac38f85c4bfbe692687b45a16c
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 07:48:43 2022 +0900
    
        add doc for mma_fill and mma_store intrin
    
    commit 4f945c4116b6d3bdc965ecb2be2229bb46dc11ab
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 18 06:39:01 2022 +0900
    
        remove tests
    
    commit df7708f7f67761d9c18f9564bc15abd50c12ac69
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 19:52:14 2022 +0900
    
        unified test
    
    commit 754c83eeb8510b31fb9652b089177f9b8e642ec0
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 19:36:24 2022 +0900
    
        clean up LowerWarpmemory
    
    commit 178c3dcee7bfa17d5d93fec02aa858dc62151670
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 19:15:04 2022 +0900
    
        Use IndexMap
    
    commit 07fb58910338c62847fd902b37801d09b8c673b0
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 17:51:44 2022 +0900
    
        remove 16x8x8 test
    
    commit 2b05b5a5470ac221d559f31a31a8e2ff753b2414
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 17:31:35 2022 +0900
    
        generate mma fill/store
    
    commit bf23fc50f0ffa99e875d9247ca66acec0c36677f
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 12:23:30 2022 +0900
    
        mma intrin generation with meta programming
    
    commit 5afb5f00afd642cb1e39872edc7965f476dcdcb7
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 17 05:26:14 2022 +0900
    
        ldmatrix intrin generation with meta programming
    
    commit fb62abb3424b88ec48c697e306e05889a3ac306f
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 20:30:49 2022 +0900
    
        minor
    
    commit 5a80adce24e84d3ec6bf931b60cb9c730d243394
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 19:55:57 2022 +0900
    
        revert some change
    
    commit e599a55078ee75f2480a721098341812db58cf6f
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 19:54:18 2022 +0900
    
        remove obsolete files
    
    commit 4b13b85ff91d0d592a7e0c01924e0b49b82f35a8
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 19:51:21 2022 +0900
    
        wip
    
    commit 848de63455539e25cd0d43e5a65fd048636ef0f7
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 19:44:29 2022 +0900
    
        wip
    
    commit b35bff97ed10c22559e2164eb7538db0f711ce7e
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 19:31:18 2022 +0900
    
        update parse error msg
    
    commit ad9b053ef865b1f91f03d7b15ed7aae3420ee213
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 19:26:51 2022 +0900
    
        fix for avoiding Buffer.vload(...) case
    
    commit 54c686443e370edbfae860d0809b1b6182d26414
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 18:59:55 2022 +0900
    
        wip
    
    commit 078060fe28d22f1db5f07b1c382dee438f02df60
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 18:57:34 2022 +0900
    
        wip
    
    commit 576f8415e65e0e8a8a7808885e219b3b53867950
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 18:52:15 2022 +0900
    
        wip
    
    commit 12a376ae2f44aa6660121e64e0358f2866624f7f
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 17:54:58 2022 +0900
    
        Squashed commit of the following:
    
        commit 48eef4981d1a55aaf3b0ac935f2a10347cb1ac2d
        Author: Masahiro Masuda <[email protected]>
        Date:   Mon May 16 17:40:48 2022 +0900
    
            more comment
    
        commit 8f67fc87038834e9f7e2c5cd3dfe61fabf442206
        Author: Masahiro Masuda <[email protected]>
        Date:   Mon May 16 17:11:27 2022 +0900
    
            update test
    
        commit ad85036621c005b733763e67ceffae39c356ec99
        Author: Masahiro Masuda <[email protected]>
        Date:   Mon May 16 16:54:01 2022 +0900
    
            add test
    
        commit 4a5dc3ffd5d0bb4a1700e57897c9e0f26e3d2a88
        Author: Masahiro Masuda <[email protected]>
        Date:   Mon May 16 16:40:47 2022 +0900
    
            [TVMScript] Support function call to help construct AST
    
    commit 76c1bcf0ade45d7433a0066236add8372b1cc547
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon May 16 16:30:07 2022 +0900
    
        simplify iterator in layout transform
    
    commit 936280324ea2c91429a6a85a1b8ee89c7b825928
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 11:31:39 2022 +0900
    
        remove obsolete files
    
    commit 2e119b422d72d726d5f2bd20fe48a1e62fcb0510
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 10:43:59 2022 +0900
    
        calculate mma store dst index using inverse affine map
    
    commit 9489434ee52b546e2abb2ab28173eefd51525ba4
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 10:01:12 2022 +0900
    
        simplify store
    
    commit 1adcb77b8bba8e5d91080fe6cbfc7add7f4365c2
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 09:43:40 2022 +0900
    
        simplified fill
    
    commit 7b13c736d23e0eac94137aa918101d788e60d4f3
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 09:22:17 2022 +0900
    
        simplify intrin desc using index map function
    
    commit bcf212dda0f94c51f55c48921f61d92fd3b83777
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 07:16:42 2022 +0900
    
        seems to work
    
    commit dd8ccf9ec2e48100158152e5d4590d141424e2e2
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat May 14 07:11:57 2022 +0900
    
        poking with the parser
    
    commit 596582cbfbd08ebe23ea71aaf7a447472415ccd1
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 20:04:59 2022 +0900
    
        16x8x32 4k trans working
    
    commit 273f89a8a6ac34f7c79147563922d34d44bffd08
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 19:52:13 2022 +0900
    
        add 16x8x16 fp16 trans
    
    commit 8e2066cc4c6e86616bc9751324e63ba81a3b02af
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 19:32:37 2022 +0900
    
        16x8x16 4k trans working
    
    commit c2d0744051733e94f840d4517bcee9ca5d444c75
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 19:25:52 2022 +0900
    
        16x8x16 trans working
    
    commit c2e314cdda1c3a931781e51a863901ea178dffec
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 16:19:32 2022 +0900
    
        tuned int8 4k, 91 TOPS
    
    commit 94d9d965f19ff1a2ebdd342079ef420fb537b16a
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 15:59:33 2022 +0900
    
        int8 4k tune working
    
    commit 3ca8ca02593aff7540c9655aa831348246171752
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 08:43:57 2022 +0900
    
        mma 16x8x32 int8 working with ldmatrix b workaround
    
    commit 54f1cb731d4b42a6cbc08baf144e74646400eef5
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 18:23:27 2022 +0900
    
        wip
    
    commit 9d2844db602dc65af4dbd06a73fdd815f486b8b9
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 16:38:53 2022 +0900
    
        test tensorize without layout transform
    
    commit 86ee6dabc801aeb8d6917bec6de97b42025dbdd1
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 15:15:34 2022 +0900
    
        int8 4k tensorize works
    
    commit 39f9e32c9a64222c91daba2c32969b27207a31d2
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri May 13 12:44:39 2022 +0900
    
        begin int8 4k tune
    
    commit 6fa91e55b5ab2ba0f901d0d35be1b2fb3ab092b0
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu May 12 18:53:20 2022 +0900
    
        try fix ldmatrix b for int8
    
    commit 7a962cddc4799fa3df0c0fdf3c056146d3f2cbdf
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu May 12 18:28:34 2022 +0900
    
        fixed warp_coeff
    
    commit a0afb5698f307382147a38819e004a2db7f554b1
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu May 12 12:20:01 2022 +0900
    
        wip
    
    commit f70ccd09b07d5325454ffdc39a7619ea84aa7e06
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu May 12 12:09:57 2022 +0900
    
        int8 tensorize working
    
    commit 20321fa4674dabc78fe55b5e0e2876c35b245d21
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu May 12 07:06:22 2022 +0900
    
        starting 16x8x32 int8
    
    commit 441fd193c59cdc436d87ab35896cbb8c779ddf35
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu May 12 05:50:46 2022 +0900
    
        adding fp16 accum case
    
    commit c9d40b69b1b57bfaddffba09ea07624ae90ee465
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 17:04:29 2022 +0900
    
        clean up
    
    commit 5b2d48635e762c77c824d1c259ac8bcbcc949421
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 16:38:19 2022 +0900
    
        16x8x16 4k tune working
    
    commit c3cb170d85600d03da5c3f4cda03552208ca0b8c
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 16:20:27 2022 +0900
    
        tensorize fixed
    
    commit 68039b081efcdd6aea1d132940b3745f50164974
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 15:55:25 2022 +0900
    
        begin 16x8x16 4k tune
    
    commit ced5d8d980cc267d4735957c25cb60d71ae977d2
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 15:50:11 2022 +0900
    
        16x8x16 worked
    
    commit 3d2c90d77c1bb2df2193e9af6cbaa2bd927a26d8
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 15:47:26 2022 +0900
    
        fix
    
    commit 403050b03ad6b4f0ee8d45088ffb324727bbae48
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 15:45:10 2022 +0900
    
        add 16x8x16 test
    
    commit 18e8d73661c99cd1c83021063b41a457afcb1638
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 06:50:32 2022 +0900
    
        fixed mma store codegen for 16x8x16
    
    commit ec81250561195705122bccb9a2372f71de68121f
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 04:25:25 2022 +0900
    
        add 16x8x16 mma store codegen
    
    commit e08df2a62a4809bcd39782949283c16e7703aa5c
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 03:47:47 2022 +0900
    
        tensorized C_warp init
    
    commit ae0678918929c1ceec73f2039467040c5bb7823b
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed May 11 03:06:06 2022 +0900
    
        mma store codegen working
    
    commit deb4d6646cc93d4cdb4f2560ce723bee4d86e144
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 10 19:22:57 2022 +0900
    
        update lower warp memory
    
    commit 71fe5fe465300705fa94f9544a2e1a5070de6e0d
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue May 10 09:01:42 2022 +0900
    
        tensorizing mma store
    
    commit e80a1f148c47f2a3fac2363a733d8d4e2a2631d0
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Apr 28 19:54:08 2022 +0900
    
        clean up
    
    commit a9640f4b7c3c9f22b87ca74a61003438dfd8f992
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Apr 28 19:40:55 2022 +0900
    
        add tunable 4k test, 36 TFLOPS
    
    commit b9f7eae7041d1a9b3e434c331c874e8347e89dc4
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Apr 28 18:01:08 2022 +0900
    
        fixed bug in LowerWarpMemory index splitting for ldmatrix
    
    commit 00df30823f874910ed1ec1f74718100311764234
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Apr 27 07:58:17 2022 +0900
    
        fixed missing reverse_compute_at
    
    commit 93f9fe7e5f7ad16c8d0e6240c16c0281a0e97dec
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Apr 27 06:55:12 2022 +0900
    
        add 4k test
    
    commit 3689ef712aa4b282a4818fa2fa2e7e349c3a5eec
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Apr 27 06:54:09 2022 +0900
    
        temp disable high dim base indices check in tensorize
    
    commit 0c859c4f385ba0b6f9477b569b80cee80b5b7282
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Apr 26 19:18:23 2022 +0900
    
        clean up
    
    commit f6aadbfcfbd73c1667a6de7aedc5894232b8e750
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Apr 26 19:13:09 2022 +0900
    
        Add 16x8x8 MMA + LDMatrix test
    
    commit 4cf6b20c6ca415e967ab58d80e4a77c701ad7255
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Apr 26 18:04:17 2022 +0900
    
        testing 16x8x8 ldmatrix tensorization
    
    * set measure_perf to False
    
    * add requires_gpu decorator in tests, always test build on non-ampere
    
    * skip cuda compile on old gpu
---
 include/tvm/tir/builtin.h                          |  27 ++
 python/tvm/tir/tensor_intrin/__init__.py           |   1 +
 python/tvm/tir/tensor_intrin/cuda.py               | 469 +++++++++++++++++++++
 src/target/source/codegen_cuda.cc                  |  76 +++-
 src/tir/op/builtin.cc                              |   6 +
 src/tir/transforms/lower_warp_memory.cc            |  45 +-
 .../test_tir_schedule_tensorize_ldmatrix_mma.py    | 422 ++++++++++++++++++
 7 files changed, 1042 insertions(+), 4 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index f33432645c..5fc42392c3 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -651,6 +651,33 @@ TVM_DLL const Op& ptx_cp_async();
 TVM_DLL const Op& ptx_commit_group();
 TVM_DLL const Op& ptx_wait_group();
 
+/*!
+ * \brief tvm intrinsic for storing the result of PTX MMA into a destination 
pointer.
+ *        For example, if each thread in a warp of size 32 has 4 elements from 
the result of
+ *        m16xn8xk16 MMA in its registers, this intrinsic can be used to store 
the result in a
+ *        16x8 region in shared or global memory.
+ *
+ *        There is no real PTX instruction that does that, but we want to hide 
details of
+ *        complex index manipulation behind this intrinsic to simplify TIR 
lowering passes (e.g.
+ *        LowerWarpMemory).
+ *
+ * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr 
src_offset, Var dst_stride);
+ */
+TVM_DLL const Op& mma_store();
+
+/*!
+ * \brief tvm intrinsic for zero-initializing an MMA accumulation register.
+ *        For example, if each thread in a warp of size 32 has 8 elements from the A matrix in
+ *        m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its
+ *        4 accumulation registers.
+ *
+ *        There is no real PTX instruction that does that, but we introduce this intrinsic for the
+ *        same reason as mma_store above.
+ *
+ * void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
+ */
+TVM_DLL const Op& mma_fill();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/python/tvm/tir/tensor_intrin/__init__.py 
b/python/tvm/tir/tensor_intrin/__init__.py
index 4115c3b900..a3b47ff6d5 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -20,3 +20,4 @@ from .x86 import *
 from .arm_cpu import *
 from .dot_product_common import *
 from .rocm import *
+from .cuda import *
diff --git a/python/tvm/tir/tensor_intrin/cuda.py 
b/python/tvm/tir/tensor_intrin/cuda.py
new file mode 100644
index 0000000000..853a377354
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -0,0 +1,469 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Intrinsics for tensorization on NVIDIA GPU."""
+from tvm.script import tir as T
+from .. import IntImm, Cast
+from ..._ffi import register_func
+from ...runtime import convert
+from .. import TensorIntrin
+
+
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
+def shared_16x32_to_ldmatrix_32x16_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 16) // 4
+    return thread_id, 8 * (j // 16) + (i // 8) * 4 + j % 4
+
+
+def shared_32x16_to_ldmatrix_32x16_layout(i, j):
+    thread_id = (i % 16) // 4 + 4 * (j % 8)  # inverse of the int8 trans load loop in codegen
+    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+
+
+@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
+def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
+    i, j = ind[0], ind[1]
+    thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
+    return convert([thread_id, local_id])
+
+
+lift = convert
+
+M_DIM = 16
+N_DIM = 16
+WARP_SIZE = 32
+HALF_WARP = WARP_SIZE // 2
+HALF_WARP_expr = lift(HALF_WARP)
+
+
+def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
+    local_size = (M_DIM * k_dim) // WARP_SIZE
+    shared_offset = None
+    index_map = None
+
+    if transposed:
+        assert is_b, "Transposed A matrix not supported"
+
+    ldmatrix_col_major = is_b and not transposed
+
+    if k_dim == 16:
+        assert dtype == "float16"
+
+        index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+        if transposed:
+            shared_offset = (
+                lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+                + stride * (tx % 8)
+                + 8 * ((tx % HALF_WARP_expr) // 8)
+            )
+        else:
+            shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) 
+ 8 * (
+                tx // HALF_WARP_expr
+            )
+    else:
+        assert (
+            k_dim == 32 and dtype == "int8"
+        ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
+
+        if ldmatrix_col_major:
+            index_map = shared_32x16_to_ldmatrix_32x16_layout
+            # A dummy offset, ldmatrix cannot be used for int8 + trans case.
+            # We still use the ldmatrix intrinsic, but lower it to a manual 
loop in the codegen.
+            # Only the stride information is required.
+            shared_offset = lambda _, stride: stride
+        elif is_b and transposed:
+            index_map = shared_16x32_to_ldmatrix_32x16_layout
+            shared_offset = (
+                lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+                + (tx % 8) * stride
+                + 16 * ((tx % HALF_WARP_expr) // 8)
+            )
+        else:
+            index_map = shared_16x32_to_ldmatrix_32x16_layout
+            shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx 
// 16)
+
+    assert index_map and shared_offset
+
+    if is_b and not transposed:
+        row_dim = k_dim
+        col_dim = M_DIM
+    else:
+        row_dim = M_DIM
+        col_dim = k_dim
+
+    shmem_shape = (row_dim, col_dim)
+
+    @T.prim_func
+    def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
+        shared = T.match_buffer(
+            shared_handle, shmem_shape, dtype, align=128, offset_factor=16, 
scope="shared"
+        )
+        warp = T.match_buffer(
+            warp_handle, (WARP_SIZE, local_size), dtype, align=128, 
offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(shared[0:row_dim, 0:col_dim])
+            T.writes(warp[0:WARP_SIZE, 0:local_size])
+
+            for ax0, ax1 in T.grid(row_dim, col_dim):
+                with T.block("shared_warp"):
+                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(shared[v0, v1])
+
+                    thread_id, local_id = index_map(v0, v1)
+                    T.writes(warp[thread_id, local_id])
+                    warp[thread_id, local_id] = shared[v0, v1]
+
+    @T.prim_func
+    def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
+        s0 = T.var("int32")
+        s1 = T.var("int32")
+        shared = T.match_buffer(
+            shared_handle,
+            shmem_shape,
+            dtype,
+            align=128,
+            offset_factor=16,
+            scope="shared",
+            strides=[s0, s1],
+        )
+        warp = T.match_buffer(
+            warp_handle, (WARP_SIZE, local_size), dtype, align=128, 
offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(shared[0:row_dim, 0:col_dim])
+            T.writes(warp[0:WARP_SIZE, 0:local_size])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(
+                T.ptx_ldmatrix(
+                    ldmatrix_col_major,
+                    4,  # Always load 4 matrices
+                    ".b16",
+                    warp.data,
+                    warp.elem_offset + lift(local_size) * tx,
+                    shared.access_ptr("r"),
+                    shared_offset(tx, s0),
+                    dtype=dtype,
+                )
+            )
+
+    return ldmatrix_desc, ldmatrix_impl
+
+
+def get_mma_intrin(k_dim, out_dtype, b_transposed):
+    local_size = (M_DIM * k_dim) // WARP_SIZE
+    local_size_out = (M_DIM * N_DIM) // 32
+
+    index_map_C = shared_16x16_to_ldmatrix_32x8_layout
+
+    if k_dim == 16:
+        index_map_A = shared_16x16_to_ldmatrix_32x8_layout
+        index_map_B = shared_16x16_to_ldmatrix_32x8_layout
+        mma_prefix = "m16n8k16"
+    elif k_dim == 32 and b_transposed:
+        index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout
+        mma_prefix = "m16n8k32"
+    elif k_dim == 32 and not b_transposed:
+        index_map_A = shared_16x32_to_ldmatrix_32x16_layout
+        index_map_B = shared_32x16_to_ldmatrix_32x16_layout
+        mma_prefix = "m16n8k32"
+    else:
+        assert False
+
+    out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": 
"int32"}[out_dtype]
+
+    if out_dtype in ["float16", "float32"]:
+        in_dtype = "float16"
+        in_dtype_abbrv = "fp16"
+    else:
+        in_dtype = "int8"
+        in_dtype_abbrv = "int8"
+
+    def maybe_cast(v):
+        if out_dtype in ["float32", "int32"]:
+            return Cast(out_dtype, v)
+        return v
+
+    def maybe_swap(i, j):
+        if b_transposed:
+            return j, i
+        return i, j
+
+    @T.prim_func
+    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, 
scope="warp"
+        )
+        B = T.match_buffer(
+            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, 
scope="warp"
+        )
+        C = T.match_buffer(
+            c, (WARP_SIZE, local_size_out), out_dtype, align=128, 
offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+
+            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
+                with T.block("C"):
+                    i, j, k = T.axis.remap("SSR", [i, j, k])
+                    b_row_ind, b_col_ind = maybe_swap(k, j)
+
+                    thread_id_C, local_id_C = index_map_C(i, j)
+                    thread_id_A, local_id_A = index_map_A(i, k)
+                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
+
+                    T.reads(
+                        C[thread_id_C, local_id_C],
+                        A[thread_id_A, local_id_A],
+                        B[thread_id_B, local_id_B],
+                    )
+                    T.writes(C[thread_id_C, local_id_C])
+
+                    C[thread_id_C, local_id_C] += maybe_cast(
+                        A[thread_id_A, local_id_A]
+                    ) * maybe_cast(B[thread_id_B, local_id_B])
+
+    @T.prim_func
+    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(
+            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, 
scope="warp"
+        )
+        B = T.match_buffer(
+            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, 
scope="warp"
+        )
+        C = T.match_buffer(
+            c, (WARP_SIZE, local_size_out), out_dtype, align=128, 
offset_factor=16, scope="warp"
+        )
+
+        with T.block("root"):
+            T.reads(
+                C[0:WARP_SIZE, 0:local_size_out],
+                A[0:WARP_SIZE, 0:local_size],
+                B[0:WARP_SIZE, 0:local_size],
+            )
+            T.writes(C[0:WARP_SIZE, 0:local_size_out])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(
+                T.ptx_mma(
+                    mma_prefix,
+                    "row",
+                    "col",
+                    in_dtype_abbrv,
+                    in_dtype_abbrv,
+                    out_dtype_abbrv,
+                    A.data,
+                    A.elem_offset + tx * lift(local_size),
+                    B.data,
+                    B.elem_offset + tx * lift(local_size),
+                    C.data,
+                    C.elem_offset + tx * lift(local_size_out),
+                    False,
+                    dtype=out_dtype,
+                )
+            )
+
+            T.evaluate(
+                T.ptx_mma(
+                    mma_prefix,
+                    "row",
+                    "col",
+                    in_dtype_abbrv,
+                    in_dtype_abbrv,
+                    out_dtype_abbrv,
+                    A.data,
+                    A.elem_offset + tx * lift(local_size),
+                    B.data,
+                    B.elem_offset + tx * lift(local_size) + lift(local_size) 
// 2,
+                    C.data,
+                    C.elem_offset + tx * lift(local_size_out) + 
lift(local_size_out) // 2,
+                    False,
+                    dtype=out_dtype,
+                )
+            )
+
+    return mma_sync_desc, mma_sync_impl
+
+
+def get_mma_fill_intrin(dtype, local_size):
+    zero = IntImm("int32", 0).astype(dtype)
+
+    # Assume M = N = 16
+    index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+    @T.prim_func
+    def mma_fill_desc(a: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, 
scope="warp")
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    i, j = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = index_map(i, j)
+                    T.reads()
+                    T.writes(C_warp[thread_id, local_id])
+                    C_warp[thread_id, local_id] = zero
+
+    @T.prim_func
+    def mma_fill_impl(a: T.handle) -> None:
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
+        )
+
+        with T.block("root"):
+            T.reads()
+            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, 
dtype=dtype))
+
+    return mma_fill_desc, mma_fill_impl
+
+
+def get_mma_store_intrin(dtype, local_size, scope="global"):
+    # Assume M = N = 16. `scope` is the storage scope of the destination C (desc and impl).
+    index_map = shared_16x16_to_ldmatrix_32x8_layout
+
+    @T.prim_func
+    def mma_store_desc(a: T.handle, c: T.handle) -> None:
+        C_warp = T.match_buffer(a, [WARP_SIZE, local_size], dtype=dtype, scope="warp")
+        C = T.match_buffer(c, [M_DIM, N_DIM], dtype=dtype, scope=scope)
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            for i0, i1 in T.grid(M_DIM, N_DIM):
+                with T.block("C_warp"):
+                    v0, v1 = T.axis.remap("SS", [i0, i1])
+                    thread_id, local_id = index_map(v0, v1)
+                    T.reads(C_warp[thread_id, local_id])
+                    T.writes(C[v0, v1])
+                    C[v0, v1] = C_warp[thread_id, local_id]
+
+    @T.prim_func
+    def mma_store_impl(a: T.handle, c: T.handle) -> None:
+        s0 = T.var("int32")
+        s1 = T.var("int32")
+
+        C_warp = T.match_buffer(
+            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+        )
+        C = T.match_buffer(
+            c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1]
+        )
+
+        with T.block("root"):
+            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+            T.writes(C[0:M_DIM, 0:N_DIM])
+            tx = T.env_thread("threadIdx.x")
+            T.launch_thread(tx, WARP_SIZE)
+
+            T.evaluate(
+                T.mma_store(
+                    M_DIM,
+                    N_DIM,
+                    C.access_ptr("w"),
+                    C_warp.data,
+                    C_warp.elem_offset,
+                    s0,
+                    dtype=dtype,
+                )
+            )
+
+    return mma_store_desc, mma_store_impl
+
+
+LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
+TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, 
"float16", False, False))
+
+LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
+TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, 
"float16", True, False))
+
+LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
+TensorIntrin.register(
+    LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, 
True)
+)
+
+LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a"
+TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, 
"int8", False, False))
+
+LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b"
+TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, 
"int8", True, False))
+
+LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
+TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, 
"int8", True, True))
+
+MMA_f16f16f32_INTRIN = "mma_f16f16f32"
+TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", 
False))
+
+MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
+TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, 
"float32", True))
+
+MMA_f16f16f16_INTRIN = "mma_f16f16f16"
+TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", 
False))
+
+MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
+TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, 
"float16", True))
+
+MMA_i8i8i32_INTRIN = "mma_i8i8i32"
+TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))
+
+MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
+TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", 
True))
+
+MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
+TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, 
*get_mma_fill_intrin("float32", 8))
+
+MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
+TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, 
*get_mma_fill_intrin("float16", 8))
+
+MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
+TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 
8))
+
+MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, 
"global")
+)
+
+MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, 
"global")
+)
+
+MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, 
"global")
+)
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index 7459d4c250..616e75f2e7 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -25,6 +25,7 @@
 
 #include <tvm/arith/analyzer.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/index_map.h>
 #include <tvm/tir/stmt_functor.h>
 
 #include <cmath>
@@ -818,9 +819,78 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, 
std::ostream& os) {
     std::string local_ptr = this->PrintExpr(op->args[3]);
     std::string local_elem_offset = this->PrintExpr(op->args[4]);
     std::string smem_ptr = this->PrintExpr(op->args[5]);
-    std::string smem_elem_offset = this->PrintExpr(op->args[6]);
-    this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, 
local_elem_offset,
-                                            smem_ptr, smem_elem_offset);
+    if (trans && op->dtype.bits() == 8) {
+      // Since ldmatrix assumes that a matrix element is 16 bit, it cannot 
properly transpose an
+      // int8 matrix.
+      std::string smem_stride = this->PrintExpr(op->args[6]);
+      ICHECK(num == 4);
+      os << "for (int i = 0; i < 16; ++i) {\n";
+      os << local_ptr << "[" + local_elem_offset + " + i] = " << smem_ptr
+         << "[(i % 8) / 4 * " + smem_stride + " * 16 + (threadIdx.x % 4) * 4 * 
" + smem_stride +
+                "+ (i % 4) * " + smem_stride + " + threadIdx.x / 4 +  (i / 8) 
* 8];\n";
+      os << "}\n";
+    } else {
+      std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+      this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, 
local_elem_offset,
+                                              smem_ptr, smem_elem_offset);
+    }
+  } else if (op->op.same_as(builtin::mma_store())) {
+    int m = Downcast<Integer>(op->args[0])->value;
+    int n = Downcast<Integer>(op->args[1])->value;
+    std::string dst = this->PrintExpr(op->args[2]);
+    std::string src = this->PrintExpr(op->args[3]);
+    std::string src_offset = this->PrintExpr(op->args[4]);
+    PrimExpr stride = op->args[5];
+
+    ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for 
now";
+
+    // Each thread in a warp holds a certain number of elements of an MMA 
output.
+    // For example, if we compute a 16x16 tile using MMA, each thread holds 8 
elements
+    // in its registers. So conceptually, a warp memory is organized as a 32x8 
block.
+    // A map from a 16x16 tile to a 32x8 block of memory is specified by the 
index map below.
+
+    // To store the 32x8 output back to a 16x16 tile in shared or global 
memory, we invert this map
+    // to determine the output location for each 8 element.
+
+    const auto* index_map_func =
+        
runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
+    ICHECK(index_map_func);
+
+    auto inverse_index_map =
+        IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, 
n)});
+    auto indices_16x16 = inverse_index_map->final_indices;
+
+    // "//" and "%" in the index map are translated to FloorDiv/Mod, but the 
plain Div/Mod are fine.
+    // FloorDiv/Mod are supposed to be lowered before they reach codegen, so 
manually replace them
+    // to the plain ones here.
+    // Rewrites FloorDiv/FloorMod into truncating Div/Mod so the index prints as plain C "/" and
+    // "%". The two only agree for non-negative operands; that is assumed to hold here because
+    // the inverse map's inputs are bound to threadIdx.x and the loop counter local_id.
+    // NOTE(review): confirm the inverse map never produces negative intermediate terms.
+    class LowerFloorDivMod : public ExprMutator {
+     public:
+      // `final` lets the compiler verify these really override ExprMutator's visitors.
+      PrimExpr VisitExpr_(const FloorDivNode* node) final {
+        return tir::Div(this->VisitExpr(node->a), this->VisitExpr(node->b));
+      }
+      PrimExpr VisitExpr_(const FloorModNode* node) final {
+        return tir::Mod(this->VisitExpr(node->a), this->VisitExpr(node->b));
+      }
+    };
+
+    auto dst_ind = LowerFloorDivMod()(indices_16x16[0] * stride + 
indices_16x16[1]);
+
+    var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x";
+    var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id";
+
+    os << "for (int local_id = 0; local_id < 8; ++local_id) {\n";
+    os << dst << "[" + this->PrintExpr(dst_ind) + "]"
+       << " = " << src << "[" << src_offset << " + local_id];\n";
+    os << "}\n";
+
+  } else if (op->op.same_as(builtin::mma_fill())) {
+    std::string num_elem = this->PrintExpr(op->args[0]);
+    std::string dst = this->PrintExpr(op->args[1]);
+    std::string dst_offset = this->PrintExpr(op->args[2]);
+
+    os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
+    os << dst << "[" << dst_offset << " + i] = 0.0;";
+    os << "}\n";
   } else if (op->op.same_as(builtin::ptx_cp_async())) {
     std::string dst = this->PrintExpr(op->args[0]);
     std::string dst_offset = this->PrintExpr(op->args[1]);
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 0415d1bbec..1871a3d7bf 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -256,6 +256,12 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
 TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+// mma_store and mma_fill have side effects: the CUDA codegen lowers each of them to a loop that
+// writes into a destination buffer. Hence they are kOpaque rather than kPure.
+TIR_DEFINE_BUILTIN_FUNC(mma_store)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(mma_fill)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
diff --git a/src/tir/transforms/lower_warp_memory.cc 
b/src/tir/transforms/lower_warp_memory.cc
index 40971114d4..d8250cd098 100644
--- a/src/tir/transforms/lower_warp_memory.cc
+++ b/src/tir/transforms/lower_warp_memory.cc
@@ -101,7 +101,7 @@ namespace tir {
 
 // Visitor to find m in pattern
 // store warp_mem[m * warp_index + (width * m) * y + x]
-class WarpStoreCoeffFinder : private StmtVisitor {
+class WarpStoreCoeffFinder : private StmtExprVisitor {
  public:
   // `buffer` is the warp-memory variable whose writes are analyzed, and `warp_index` is the
   // variable playing the warp_index role in m * warp_index + (width * m) * y + x.
   // `analyzer` is stored as a raw pointer, so it presumably must outlive this visitor.
   WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* 
analyzer)
       : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {}
@@ -113,6 +113,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
 
  private:
   /// Visitor implementation
+  // Two builtins write into the warp buffer through a pointer argument, so they also
+  // determine the coefficient m:
+  //  - ptx_ldmatrix: destination pointer args[3] with element offset args[4]. The offset is
+  //    handed to UpdatePattern, like a store index.
+  //  - mma_fill: destination pointer args[1]. Each thread writes args[0] elements, and that
+  //    constant is assigned to warp_coeff_ directly (not routed through UpdatePattern).
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() 
== buffer_) {
+      UpdatePattern(op->args[4]);
+    } else if (op->op.same_as(builtin::mma_fill()) && 
op->args[1].as<VarNode>() == buffer_) {
+      auto* local_size = op->args[0].as<IntImmNode>();
+      ICHECK(local_size) << "Integer expected for the first argument of 
mma_fill";
+      warp_coeff_ = local_size->value;
+    }
+
+    // Continue into the call's arguments.
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
   // StoreNode is deprecated; reaching one here is a hard error pointing users at BufferStoreNode.
   void VisitStmt_(const StoreNode*) final {
     LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
   }
@@ -245,6 +257,37 @@ class WarpAccessRewriter : protected StmtExprMutator {
   }
 
  protected:
+  // Rebuild `op` with warp-local offsets. Each entry p in `indices` names an argument position
+  // that may hold the warp buffer pointer. When it does, args[p + 1] (its element offset) is
+  // replaced by the first component returned by SplitIndexByGroup. All other arguments are
+  // copied unchanged and are not visited.
+  PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector<int>& indices) {
+    Array<PrimExpr> rewritten_args = op->args;
+    for (size_t pos_idx = 0; pos_idx < indices.size(); ++pos_idx) {
+      const int ptr_pos = indices[pos_idx];
+      if (op->args[ptr_pos].get() != buffer_) {
+        continue;
+      }
+      rewritten_args.Set(ptr_pos + 1, SplitIndexByGroup(op->args[ptr_pos + 1]).first);
+    }
+    return Call(op->dtype, op->op, rewritten_args);
+  }
+
+  // These builtins take the warp buffer as a raw pointer argument followed by an element offset.
+  // Only the offsets are rewritten, so the buffer var itself is never visited (the VarNode
+  // visitor rejects direct uses of warp memory). Any other call uses the default mutation.
+  PrimExpr VisitExpr_(const CallNode* op) override {
+    if (op->op.same_as(builtin::ptx_mma())) {
+      return RewriteIndicesAt(op, {6, 8, 10});
+    }
+    if (op->op.same_as(builtin::ptx_ldmatrix()) || op->op.same_as(builtin::mma_store())) {
+      // Both carry the warp-side pointer at args[3] and its offset at args[4].
+      return RewriteIndicesAt(op, {3});
+    }
+    if (op->op.same_as(builtin::mma_fill())) {
+      return RewriteIndicesAt(op, {1});
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
   PrimExpr VisitExpr_(const VarNode* op) override {
     ICHECK(op != buffer_) << "Cannot access address of warp memory directly";
     return StmtExprMutator::VisitExpr_(op);
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py 
b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
new file mode 100644
index 0000000000..67e8ae0ad8
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -0,0 +1,422 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import te
+from tvm.tir.tensor_intrin.cuda import (
+    LDMATRIX_16x16_A_INTRIN,
+    LDMATRIX_16x16_B_INTRIN,
+    LDMATRIX_16x16_B_TRANS_INTRIN,
+    LDMATRIX_16x32_A_INTRIN,
+    LDMATRIX_32x16_B_INTRIN,
+    LDMATRIX_16x32_B_TRANS_INTRIN,
+    MMA_f16f16f32_INTRIN,
+    MMA_f16f16f32_TRANS_INTRIN,
+    MMA_f16f16f16_INTRIN,
+    MMA_f16f16f16_TRANS_INTRIN,
+    MMA_i8i8i32_INTRIN,
+    MMA_i8i8i32_TRANS_INTRIN,
+    MMA_fill_16x16_f32_INTRIN,
+    MMA_fill_16x16_f16_INTRIN,
+    MMA_fill_16x16_i32_INTRIN,
+    MMA_store_16x16_f32_global_INTRIN,
+    MMA_store_16x16_f16_global_INTRIN,
+    MMA_store_16x16_i32_global_INTRIN,
+    shared_16x16_to_ldmatrix_32x8_layout,
+    shared_32x16_to_ldmatrix_32x16_layout,
+    shared_16x32_to_ldmatrix_32x16_layout,
+)
+import tvm.testing
+import numpy as np
+
+
+M = 4096
+N = 4096
+K = 4096
+measure_perf = False
+gflops = (N * M * K) * 2 / 1e9
+
+
+def matmul(m, n, k, in_dtype, out_dtype, b_transposed):
+    b_shape = (n, k) if b_transposed else (k, n)
+    a = te.placeholder((m, k), name="A", dtype=in_dtype)
+    b = te.placeholder(b_shape, name="B", dtype=in_dtype)
+    k = te.reduce_axis((0, k), name="k")
+
+    def maybe_cast(v):
+        if in_dtype != out_dtype:
+            return tvm.tir.Cast(out_dtype, v)
+        return v
+
+    def maybe_swap(i, j):
+        if b_transposed:
+            return j, i
+        return i, j
+
+    c = te.compute(
+        (m, n),
+        lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, 
j)]), axis=[k]),
+        name="C",
+    )
+    return (a, b, c)
+
+
+def is_ampere_or_newer():
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    return major >= 8
+
+
+def run_test(
+    k_inner,
+    in_dtype,
+    out_dtype,
+    b_transposed,
+    i_factors,
+    j_factors,
+    k_factors,
+    index_map_A,
+    index_map_B,
+    index_map_C,
+    ldmatrix_a_intrin,
+    ldmatrix_b_intrin,
+    mma_intrin,
+    mma_fill_intrin,
+    mma_store_intrin,
+):
+    workload = te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, 
b_transposed))
+    ir_module = tvm.IRModule({"main": workload})
+    sch = tvm.tir.Schedule(ir_module)
+
+    block = sch.get_block("C")
+    i, j, k = sch.get_loops(block)
+    i, i_tc = sch.split(i, factors=[None, 16])
+    j, j_tc = sch.split(j, factors=[None, 16])
+    k, k_tc = sch.split(k, factors=[None, k_inner])
+
+    sch.reorder(i, j, k, i_tc, j_tc, k_tc)
+
+    block_inner = sch.blockize(i_tc)
+    block_outer, block_inner = block_inner, block
+
+    num_ty = i_factors[2] * j_factors[2]
+
+    i0, i1, i2, i3, i4 = sch.split(i, factors=i_factors)
+    j0, j1, j2, j3, j4 = sch.split(j, factors=j_factors)
+    k0, k1, k2 = sch.split(k, k_factors)
+
+    sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3, k2, i4, j4)
+
+    block_idx = sch.fuse(i0, j0)
+    block_idy = sch.fuse(i1, j1)
+    thread_idy = sch.fuse(j2, i2)
+    sch.bind(block_idx, "blockIdx.x")
+    sch.bind(block_idy, "blockIdx.y")
+    sch.bind(thread_idy, "threadIdx.y")
+
+    def fetch_to_shared(block, idx, ndim):
+        block_read = sch.cache_read(block, idx, "shared")
+        sch.compute_at(block_read, k0)
+        vector_size = 16 if in_dtype == "int8" else 8
+        warp_size = 32
+        fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+        _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, 
vector_size])
+        sch.bind(f_2, "threadIdx.x")
+        sch.bind(f_1, "threadIdx.y")
+        sch.vectorize(f_3)
+        offset = 8 if in_dtype == "float16" else 16
+        sch.storage_align(block_read, 0, axis=-2, factor=32, offset=offset)
+
+        return block_read
+
+    fetch_to_shared(block_outer, 0, 2)
+    fetch_to_shared(block_outer, 1, 2)
+
+    A_warp = sch.cache_read(block_outer, 0, "warp")
+    B_warp = sch.cache_read(block_outer, 1, "warp")
+
+    sch.compute_at(A_warp, k1)
+    sch.compute_at(B_warp, k1)
+
+    C_warp = sch.cache_write(block_outer, 0, "warp")
+    sch.reverse_compute_at(C_warp, thread_idy)
+
+    ii, jj = sch.get_loops(C_warp)[-2:]
+    io, ii = sch.split(ii, factors=[None, 16])
+    jo, ji = sch.split(jj, factors=[None, 16])
+    sch.reorder(io, jo, ii, ji)
+
+    sch.decompose_reduction(block_outer, sch.get_loops(block_outer)[3])
+    block_init_c = sch.get_block("C_init")
+
+    def tile_wmma_fragment(block_read, height, width):
+        i, j = sch.get_loops(block_read)[-2:]
+        i0, i1 = sch.split(i, factors=[None, height])
+        j0, j1 = sch.split(j, factors=[None, width])
+        sch.reorder(i0, j0, i1, j1)
+        return i1
+
+    loop_a = tile_wmma_fragment(A_warp, 16, k_inner)
+
+    if b_transposed:
+        loop_b = tile_wmma_fragment(B_warp, 16, k_inner)
+    else:
+        loop_b = tile_wmma_fragment(B_warp, k_inner, 16)
+
+    sch.transform_layout(A_warp, 0, "write", index_map_A)
+    sch.transform_layout(B_warp, 0, "write", index_map_B)
+    sch.transform_layout(C_warp, 0, "read", index_map_C)
+
+    sch.tensorize(loop_a, ldmatrix_a_intrin)
+    sch.tensorize(loop_b, ldmatrix_b_intrin)
+    sch.tensorize(sch.get_loops(block_inner)[-3], mma_intrin)
+    sch.tensorize(sch.get_loops(block_init_c)[-2], mma_fill_intrin)
+    sch.tensorize(sch.get_loops(C_warp)[-2], mma_store_intrin)
+
+    if not is_ampere_or_newer():
+        return None
+
+    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+    dev = tvm.device("cuda", 0)
+
+    if in_dtype == "float16":
+        a_np = np.random.uniform(size=(M, K)).astype("float16")
+
+        if b_transposed:
+            b_np = np.random.uniform(size=(N, K)).astype("float16")
+            c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32").transpose()).astype(
+                out_dtype
+            )
+        else:
+            b_np = np.random.uniform(size=(K, N)).astype("float16")
+            c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32")).astype(out_dtype)
+    else:
+        a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
+
+        if b_transposed:
+            b_np = np.random.randint(-128, 128, (N, K)).astype("int8")
+            c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32").transpose()).astype(
+                "int32"
+            )
+        else:
+            b_np = np.random.randint(-128, 128, (K, N)).astype("int8")
+            c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32")).astype("int32")
+
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev)
+
+    f(a, b, c)
+
+    if out_dtype != "float16":
+        # The numpy reference is computed with fp32 precision (otherwise too 
slow).
+        # So there is non-trivial accuracy difference if TVM result is 
computed with fp16 accumulation.
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+
+    return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
+
+
[email protected]_cuda
+def test_f16f16f32_m16n16k16():
+    def index_map(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    k_inner = 16
+    in_dtype = "float16"
+    out_dtype = "float32"
+    i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 
2, 1]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        False,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map,
+        index_map,
+        index_map,
+        LDMATRIX_16x16_A_INTRIN,
+        LDMATRIX_16x16_B_INTRIN,
+        MMA_f16f16f32_INTRIN,
+        MMA_fill_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        True,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map,
+        index_map,
+        index_map,
+        LDMATRIX_16x16_A_INTRIN,
+        LDMATRIX_16x16_B_TRANS_INTRIN,
+        MMA_f16f16f32_TRANS_INTRIN,
+        MMA_fill_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / 
(timer().mean)))
+
+
[email protected]_cuda
+def test_f16f16f16_m16n16k16():
+    def index_map(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    k_inner = 16
+    in_dtype = "float16"
+    out_dtype = "float16"
+    i_factors, j_factors, k_factors = [16, 2, 1, 4, 2], [16, 2, 2, 1, 4], 
[128, 2, 1]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        False,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map,
+        index_map,
+        index_map,
+        LDMATRIX_16x16_A_INTRIN,
+        LDMATRIX_16x16_B_INTRIN,
+        MMA_f16f16f16_INTRIN,
+        MMA_fill_16x16_f16_INTRIN,
+        MMA_store_16x16_f16_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean)))
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        True,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map,
+        index_map,
+        index_map,
+        LDMATRIX_16x16_A_INTRIN,
+        LDMATRIX_16x16_B_TRANS_INTRIN,
+        MMA_f16f16f16_TRANS_INTRIN,
+        MMA_fill_16x16_f16_INTRIN,
+        MMA_store_16x16_f16_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / 
(timer().mean)))
+
+
[email protected]_cuda
+def test_i8i8i32_m16n16k32():
+    def index_map_A(i, j):
+        return (
+            i // 16,
+            j // 32,
+            *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
+        )
+
+    def index_map_B(i, j):
+        return (
+            i // 32,
+            j // 16,
+            *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16),
+        )
+
+    def index_map_C(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    k_inner = 32
+    in_dtype = "int8"
+    out_dtype = "int32"
+    i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 
2, 2]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        False,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_B,
+        index_map_C,
+        LDMATRIX_16x32_A_INTRIN,
+        LDMATRIX_32x16_B_INTRIN,
+        MMA_i8i8i32_INTRIN,
+        MMA_fill_16x16_i32_INTRIN,
+        MMA_store_16x16_i32_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean)))
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        True,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_A,
+        index_map_C,
+        LDMATRIX_16x32_A_INTRIN,
+        LDMATRIX_16x32_B_TRANS_INTRIN,
+        MMA_i8i8i32_TRANS_INTRIN,
+        MMA_fill_16x16_i32_INTRIN,
+        MMA_store_16x16_i32_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
+
+
+if __name__ == "__main__":
+    test_f16f16f32_m16n16k16()
+    test_f16f16f16_m16n16k16()
+    test_i8i8i32_m16n16k32()

Reply via email to