This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9bbc2c0ab7 [MetaSchedule] Use `shared.dyn` for Tensor Core Schedule Rules (#13891)
9bbc2c0ab7 is described below

commit 9bbc2c0ab7e0e30e1814ade0963c8419dd5b7d3d
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Feb 1 23:35:48 2023 +0800

    [MetaSchedule] Use `shared.dyn` for Tensor Core Schedule Rules (#13891)
    
    This PR adds Tensor Core intrinsics with `shared.dyn` scope and changes the
    default rules to use `shared.dyn`.
    
    Here is the performance improvement for a 1024x1024x1024 GEMM on my device
    (RTX 3080):
    
    | Precision   | Use `shared`      | Use `shared.dyn`  | Speedup |
    | ----------- | ----------------- | ----------------- | ------- |
    | fp 16-16-16 | 66399.8766 GFLOPs | 71778.3808 GFLOPs | 8.1%    |
    | fp 16-16-32 | 44292.5893 GFLOPs | 49070.2514 GFLOPs | 10.8%   |
    
    cc @vinx13 @junrushao @masahi
---
 python/tvm/tir/tensor_intrin/cuda.py               | 166 ++++++++++++++-------
 src/meta_schedule/schedule_rule/schedule_rule.cc   |  40 ++---
 .../test_meta_schedule_schedule_rule_mlt_tc.py     |  91 ++++++-----
 .../unittest/test_meta_schedule_trace_apply.py     |  16 +-
 .../unittest/test_tir_schedule_compute_inline.py   |   4 +-
 5 files changed, 198 insertions(+), 119 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py 
b/python/tvm/tir/tensor_intrin/cuda.py
index 0cde7f2464..0703811ea7 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -18,6 +18,8 @@
 """Intrinsics for tensorization on NVIDIA GPU."""
 from typing import Dict, Tuple
 
+from typing_extensions import Literal
+
 from tvm.script import tir as T
 from tvm.tir.function import PrimFunc
 
@@ -815,54 +817,101 @@ TensorIntrin.register(
     *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
 )
 
-WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a"
+WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_A_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, False),
 )
 
-WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b"
+WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN = "wmma_load_16x16x16_f16_a_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_F16_A_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, False),
+)
+
+WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_B_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
 )
 
-WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
+WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN = "wmma_load_16x16x16_f16_b_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_F16_B_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, False),
+)
+
+WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
 )
 
-WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
+WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN = 
"wmma_load_16x16x16_f16_a_trans_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_F16_A_TRANS_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", False, True),
+)
+
+WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
 )
 
-WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a"
+WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN = 
"wmma_load_16x16x16_f16_b_trans_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_F16_B_TRANS_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "float16", "shared.dyn", True, True),
+)
+
+WMMA_LOAD_16x16x16_S8_A_INTRIN = "wmma_load_16x16x16_s8_a_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_S8_A_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, False),
 )
 
-WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b"
+WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN = "wmma_load_16x16x16_s8_a_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_A_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, False),
+)
+
+WMMA_LOAD_16x16x16_S8_B_INTRIN = "wmma_load_16x16x16_s8_b_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_S8_B_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, False),
 )
 
-WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans"
+WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN = "wmma_load_16x16x16_s8_b_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_B_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, False),
+)
+
+WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN = "wmma_load_16x16x16_s8_a_trans_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_S8_A_TRANS_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "int8", "shared", False, True),
 )
 
-WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans"
+WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN = 
"wmma_load_16x16x16_s8_a_trans_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_A_TRANS_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", False, True),
+)
+
+WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN = "wmma_load_16x16x16_s8_b_trans_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN,
     *get_wmma_load_intrin(16, 16, 16, "int8", "shared", True, True),
 )
 
+WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN = 
"wmma_load_16x16x16_s8_b_trans_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_16x16x16_S8_B_TRANS_DYN_INTRIN,
+    *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True),
+)
 
 WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
 TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 
16, 16, "float32"))
@@ -878,16 +927,34 @@ TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, 
"float32", "shared")
 )
 
+WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN = 
"wmma_store_16x16x16_f32_shared_dyn"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_F32_SHARED_DYN_INTRIN,
+    *get_wmma_store_intrin(16, 16, 16, "float32", "shared.dyn"),
+)
+
 WMMA_STORE_16x16x16_F16_SHARED_INTRIN = "wmma_store_16x16x16_f16_shared"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F16_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, 
"float16", "shared")
 )
 
+WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN = 
"wmma_store_16x16x16_f16_shared_dyn"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_F16_SHARED_DYN_INTRIN,
+    *get_wmma_store_intrin(16, 16, 16, "float16", "shared.dyn"),
+)
+
 WMMA_STORE_16x16x16_S32_SHARED_INTRIN = "wmma_store_16x16x16_s32_shared"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_S32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, 
"int32", "shared")
 )
 
+WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN = 
"wmma_store_16x16x16_s32_shared_dyn"
+TensorIntrin.register(
+    WMMA_STORE_16x16x16_S32_SHARED_DYN_INTRIN,
+    *get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"),
+)
+
 WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, 
"float32", "global")
@@ -905,14 +972,21 @@ TensorIntrin.register(
 
 
 def get_wmma_intrin_group(
-    store_scope: str, in_dtype: str, out_dtype: str, trans_b: bool
+    load_scope: Literal["shared", "shared.dyn"],
+    store_scope: Literal["global", "shared", "shared.dyn"],
+    in_dtype: str,
+    out_dtype: str,
+    trans_b: bool,
 ) -> Dict[str, str]:
     """Get a group of intrinsics for wmma tensor core with the given 
configurations
 
     Parameters
     ----------
-    store_scope : str
-        Must be one of ["global", "shared"]. The memory scope of the result 
buffer.
+    load_scope : Literal["shared", "shared.dyn"]
+        The memory scope of the input buffer.
+
+    store_scope : Literal["global", "shared", "shared.dyn"]
+        The memory scope of the result buffer.
 
     in_dtype : str
         The input data type.
@@ -928,51 +1002,35 @@ def get_wmma_intrin_group(
     ret : Dict[str, str]
         A group of tensor intrinsics.
     """
-    assert store_scope in ["global", "shared"]
+    assert load_scope in ["shared", "shared.dyn"]
+    assert store_scope in ["global", "shared", "shared.dyn"]
     assert in_dtype in ["float16", "int8"]
     assert out_dtype in ["float16", "float32", "int32"]
 
-    load_a_intrins = {
-        "float16": WMMA_LOAD_16x16x16_F16_A_INTRIN,
-        "int8": WMMA_LOAD_16x16x16_S8_A_INTRIN,
-    }
-    load_b_intrins = {
-        "float16": WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN
-        if trans_b
-        else WMMA_LOAD_16x16x16_F16_B_INTRIN,
-        "int8": WMMA_LOAD_16x16x16_S8_B_TRANS_INTRIN if trans_b else 
WMMA_LOAD_16x16x16_S8_B_INTRIN,
-    }
-    compute_intrins = {
-        "float16": WMMA_SYNC_16x16x16_f16f16f16_TRANS_INTRIN
-        if trans_b
-        else WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
-        "float32": WMMA_SYNC_16x16x16_f16f16f32_TRANS_INTRIN
-        if trans_b
-        else WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
-        "int32": WMMA_SYNC_16x16x16_s8s8s32_TRANS_INTRIN
-        if trans_b
-        else WMMA_SYNC_16x16x16_s8s8s32_INTRIN,
-    }
-    init_intrins = {
-        "float16": WMMA_FILL_16x16x16_F16_INTRIN,
-        "float32": WMMA_FILL_16x16x16_F32_INTRIN,
-        "int32": WMMA_FILL_16x16x16_S32_INTRIN,
-    }
-    store_intrins = {
-        "float16": WMMA_STORE_16x16x16_F16_SHARED_INTRIN
-        if store_scope == "shared"
-        else WMMA_STORE_16x16x16_F16_GLOBAL_INTRIN,
-        "float32": WMMA_STORE_16x16x16_F32_SHARED_INTRIN
-        if store_scope == "shared"
-        else WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
-        "int32": WMMA_STORE_16x16x16_S32_SHARED_INTRIN
-        if store_scope == "shared"
-        else WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN,
-    }
+    shape = "16x16x16"
+    in_dtype = "f16" if in_dtype == "float16" else "s8"
+    out_dtype = "f16" if out_dtype == "float16" else "f32" if out_dtype == 
"float32" else "s32"
+    # convert "shared.dyn" to "shared_dyn"
+    load_scope = load_scope.replace(".", "_")
+    store_scope = store_scope.replace(".", "_")
+    trans_a = ""
+    trans_b = "_trans" if trans_b else ""
+
+    # e.g. wmma_load_16x16x16_f16_a_shared
+    load_a_intrin = f"wmma_load_{shape}_{in_dtype}_a{trans_a}_{load_scope}"
+    # e.g. wmma_load_16x16x16_f16_b_trans_shared_dyn
+    load_b_intrin = f"wmma_load_{shape}_{in_dtype}_b{trans_b}_{load_scope}"
+    # e.g. wmma_sync_16x16x16_f16f16f32_trans
+    compute_intrin = 
f"wmma_sync_{shape}_{in_dtype}{in_dtype}{out_dtype}{trans_b}"
+    # e.g. wmma_fill_16x16x16_f16
+    init_intrin = f"wmma_fill_{shape}_{out_dtype}"
+    # e.g. wmma_store_16x16x16_f16_shared_dyn
+    store_intrin = f"wmma_store_{shape}_{out_dtype}_{store_scope}"
+
     return {
-        "init": init_intrins[out_dtype],
-        "load_a": load_a_intrins[in_dtype],
-        "load_b": load_b_intrins[in_dtype],
-        "compute": compute_intrins[out_dtype],
-        "store": store_intrins[out_dtype],
+        "init": init_intrin,
+        "load_a": load_a_intrin,
+        "load_b": load_b_intrin,
+        "compute": compute_intrin,
+        "store": store_intrin,
     }
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc 
b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 938d39377f..49a7c9911c 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -175,47 +175,47 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() 
{
       // Tensor Cores f32 += f16 * f16
       {
           {"init", "wmma_fill_16x16x16_f32"},
-          {"load_a", "wmma_load_16x16x16_f16_a"},
-          {"load_b", "wmma_load_16x16x16_f16_b"},
+          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
+          {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
           {"compute", "wmma_sync_16x16x16_f16f16f32"},
-          {"store", "wmma_store_16x16x16_f32_shared"},
+          {"store", "wmma_store_16x16x16_f32_shared_dyn"},
       },
       {
           {"init", "wmma_fill_16x16x16_f32"},
-          {"load_a", "wmma_load_16x16x16_f16_a"},
-          {"load_b", "wmma_load_16x16x16_f16_b_trans"},
+          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
+          {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
           {"compute", "wmma_sync_16x16x16_f16f16f32_trans"},
-          {"store", "wmma_store_16x16x16_f32_shared"},
+          {"store", "wmma_store_16x16x16_f32_shared_dyn"},
       },
       // Tensor Cores f16 += f16 * f16
       {
           {"init", "wmma_fill_16x16x16_f16"},
-          {"load_a", "wmma_load_16x16x16_f16_a"},
-          {"load_b", "wmma_load_16x16x16_f16_b"},
+          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
+          {"load_b", "wmma_load_16x16x16_f16_b_shared_dyn"},
           {"compute", "wmma_sync_16x16x16_f16f16f16"},
-          {"store", "wmma_store_16x16x16_f16_shared"},
+          {"store", "wmma_store_16x16x16_f16_shared_dyn"},
       },
       {
           {"init", "wmma_fill_16x16x16_f16"},
-          {"load_a", "wmma_load_16x16x16_f16_a"},
-          {"load_b", "wmma_load_16x16x16_f16_b_trans"},
+          {"load_a", "wmma_load_16x16x16_f16_a_shared_dyn"},
+          {"load_b", "wmma_load_16x16x16_f16_b_trans_shared_dyn"},
           {"compute", "wmma_sync_16x16x16_f16f16f16_trans"},
-          {"store", "wmma_store_16x16x16_f16_shared"},
+          {"store", "wmma_store_16x16x16_f16_shared_dyn"},
       },
       // Tensor Cores s32 += s8 * s8
       {
           {"init", "wmma_fill_16x16x16_s32"},
-          {"load_a", "wmma_load_16x16x16_s8_a"},
-          {"load_b", "wmma_load_16x16x16_s8_b"},
+          {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
+          {"load_b", "wmma_load_16x16x16_s8_b_shared_dyn"},
           {"compute", "wmma_sync_16x16x16_s8s8s32"},
-          {"store", "wmma_store_16x16x16_s32_shared"},
+          {"store", "wmma_store_16x16x16_s32_shared_dyn"},
       },
       {
           {"init", "wmma_fill_16x16x16_s32"},
-          {"load_a", "wmma_load_16x16x16_s8_a"},
-          {"load_b", "wmma_load_16x16x16_s8_b_trans"},
+          {"load_a", "wmma_load_16x16x16_s8_a_shared_dyn"},
+          {"load_b", "wmma_load_16x16x16_s8_b_trans_shared_dyn"},
           {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
-          {"store", "wmma_store_16x16x16_s32_shared"},
+          {"store", "wmma_store_16x16x16_s32_shared_dyn"},
       },
   };
   Array<ScheduleRule> results{
@@ -229,11 +229,11 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() 
{
           /*reuse_read=*/
           Map<String, ObjectRef>{{"req", String("must")},
                                  {"levels", Array<Integer>{4}},  //
-                                 {"scope", String("shared")}},
+                                 {"scope", String("shared.dyn")}},
           /*reuse_write=*/
           Map<String, ObjectRef>{{"req", String("must")},
                                  {"levels", Array<Integer>{2}},  //
-                                 {"scope", String("shared")}},
+                                 {"scope", String("shared.dyn")}},
           /*use_software_pipeline=*/false)  //
   };
   Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py 
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index 73b2c990f0..0647699159 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -15,6 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: 
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import meta_schedule as ms
@@ -31,13 +34,15 @@ from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group
 
 def multi_level_tiling_tensor_core(
     *,
+    read_reuse_scope="shared",
     write_reuse_scope="shared",
     in_dtype="float16",
     out_dtype="float32",
     trans_b=False,
     use_software_pipeline=False,
 ) -> ms.schedule_rule.ScheduleRule:
-    assert write_reuse_scope in ["shared", "global"]
+    assert read_reuse_scope in ["shared", "shared.dyn"]
+    assert write_reuse_scope in ["shared", "shared.dyn", "global"]
     if not isinstance(in_dtype, list):
         in_dtype = [in_dtype]
     if not isinstance(out_dtype, list):
@@ -46,7 +51,9 @@ def multi_level_tiling_tensor_core(
         trans_b = [trans_b]
     return ms.schedule_rule.MultiLevelTilingTensorCore(
         intrin_groups=[
-            get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, 
_trans_b)
+            get_wmma_intrin_group(
+                read_reuse_scope, write_reuse_scope, _in_dtype, _out_dtype, 
_trans_b
+            )
             for _in_dtype in in_dtype
             for _out_dtype in out_dtype
             for _trans_b in trans_b
@@ -58,10 +65,10 @@ def multi_level_tiling_tensor_core(
         reuse_read=ms.schedule_rule.ReuseType(
             req="must",
             levels=[4],
-            scope="shared",
+            scope=read_reuse_scope,
         ),
         reuse_write=ms.schedule_rule.ReuseType(
-            req="must" if write_reuse_scope == "shared" else "no",
+            req="must" if write_reuse_scope.startswith("shared") else "no",
             levels=[2],
             scope=write_reuse_scope,
         ),
@@ -69,15 +76,17 @@ def multi_level_tiling_tensor_core(
     )
 
 
-def test_matmul_relu():
[email protected]("shared_scope", ["shared", "shared.dyn"])
+def test_matmul_relu(shared_scope):
+    intrin_suffix = shared_scope.replace(".", "_")
     # fmt: off
     @T.prim_func
     def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 
128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope="shared")
+        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope=shared_scope)
         C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
         A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
         B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
@@ -107,7 +116,7 @@ def test_matmul_relu():
                                     v1_o = T.axis.spatial(8, ax2_0_1 * 2 + 
ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -120,7 +129,7 @@ def test_matmul_relu():
                                     v1_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_0_2_ax1_0_2_fused + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -155,7 +164,7 @@ def test_matmul_relu():
                             v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 
* 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0)
                             T.reads(C_reindex_shared_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
                             for ax0_1, ax1_1 in T.grid(16, 16):
                                 with 
T.block("C_reindex_shared_wmma.accumulator"):
                                     v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
@@ -196,7 +205,9 @@ def test_matmul_relu():
         target=tvm.target.Target("cuda"),
         types=None,
         sch_rules=[
-            multi_level_tiling_tensor_core(),
+            multi_level_tiling_tensor_core(
+                read_reuse_scope=shared_scope, write_reuse_scope=shared_scope
+            ),
         ]
         + get_rules(kind="cuda", types=ms.schedule_rule.AutoInline),
     )
@@ -249,7 +260,7 @@ def test_matmul_relu_with_fallback():
                                     v1_o = T.axis.spatial(8, ax2_0_0 * 4 + 
ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -262,7 +273,7 @@ def test_matmul_relu_with_fallback():
                                     v1_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused * 4 + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -355,16 +366,18 @@ def test_matmul_relu_with_fallback():
     )
 
 
-def test_conv2d():
[email protected]("shared_scope", ["shared", "shared.dyn"])
+def test_conv2d(shared_scope):
+    intrin_suffix = shared_scope.replace(".", "_")
     # fmt: off
     @T.prim_func
     def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: 
T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), 
"float32"]) -> None:
         T.func_attr({"global_symbol": "main", "tir.noalias": True})
         PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16")
-        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], 
dtype="float32", scope="shared")
+        conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], 
dtype="float32", scope=shared_scope)
         conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 
32], dtype="float32", scope="wmma.accumulator")
-        PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", 
scope="shared")
-        weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", 
scope="shared")
+        PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", 
scope=shared_scope)
+        weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", 
scope=shared_scope)
         PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], 
dtype="float16", scope="wmma.matrix_a")
         weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], 
dtype="float16", scope="wmma.matrix_b")
         for i0, i1, i2, i3 in T.grid(1, 18, 18, 32):
@@ -400,7 +413,7 @@ def test_conv2d():
                                     v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0)
                                     T.reads(PadInput_reindex_shared[v0_o * 16 
: v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o 
* 16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("PadInput_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -413,7 +426,7 @@ def test_conv2d():
                                     v1_o = T.axis.spatial(2, 
ax0_0_0_ax1_0_0_fused + ax1_0)
                                     T.reads(weight_reindex_shared[v0_o * 16 : 
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 
16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("weight_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -448,7 +461,7 @@ def test_conv2d():
                             v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + 
ax1_0)
                             
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, 
v1_o * 16 : v1_o * 16 + 16])
                             T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : 
v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
                             for ax0_1, ax1_1 in T.grid(16, 16):
                                 with 
T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"):
                                     v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
@@ -492,7 +505,9 @@ def test_conv2d():
         target=tvm.target.Target("cuda"),
         types=None,
         sch_rules=[
-            multi_level_tiling_tensor_core(),
+            multi_level_tiling_tensor_core(
+                read_reuse_scope=shared_scope, write_reuse_scope=shared_scope
+            ),
         ],
     )
     check_sketches(
@@ -511,6 +526,8 @@ def test_conv2d():
         types=None,
         sch_rules=[
             multi_level_tiling_tensor_core(
+                read_reuse_scope=shared_scope,
+                write_reuse_scope=shared_scope,
                 in_dtype="float16",
                 out_dtype=["float16", "float32"],
             ),
@@ -524,7 +541,9 @@ def test_conv2d():
     )
 
 
-def test_matmul_relu_pipeline():
[email protected]("shared_scope", ["shared", "shared.dyn"])
+def test_matmul_relu_pipeline(shared_scope):
+    intrin_suffix = shared_scope.replace(".", "_")
     # fmt: off
     @T.prim_func
     def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: 
T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> 
None:
@@ -533,10 +552,10 @@ def test_matmul_relu_pipeline():
         # body
         # with T.block("root")
         C = T.alloc_buffer([128, 128], dtype="float32")
-        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope="shared")
+        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", 
scope=shared_scope)
         C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], 
dtype="float32", scope="wmma.accumulator")
-        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
-        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope="shared")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", 
scope=shared_scope)
         A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_a")
         B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], 
dtype="float16", scope="wmma.matrix_b")
         for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"):
@@ -566,7 +585,7 @@ def test_matmul_relu_pipeline():
                                     v1_o = T.axis.spatial(8, ax2_0_0 * 2 + 
ax2_0_1 + ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_a_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -579,7 +598,7 @@ def test_matmul_relu_pipeline():
                                     v1_o = T.axis.spatial(8, 
ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_load_16x16x16_f16_b_{intrin_suffix}"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -614,7 +633,7 @@ def test_matmul_relu_pipeline():
                             v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 
* 2 + ax1_0)
                             T.reads(C_reindex_shared_wmma_accumulator[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                             T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 
16, v1_o * 16 : v1_o * 16 + 16])
-                            
T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            T.block_attr({"meta_schedule.auto_tensorize": 
f"wmma_store_16x16x16_f32_{intrin_suffix}"})
                             for ax0_1, ax1_1 in T.grid(16, 16):
                                 with 
T.block("C_reindex_shared_wmma.accumulator"):
                                     v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1])
@@ -660,6 +679,8 @@ def test_matmul_relu_pipeline():
         types=None,
         sch_rules=[
             multi_level_tiling_tensor_core(
+                read_reuse_scope=shared_scope,
+                write_reuse_scope=shared_scope,
                 use_software_pipeline=True,
             ),
         ],
@@ -713,7 +734,7 @@ def test_matmul_relu_global():
                                     v1_o = T.axis.spatial(8, ax2_0_0 * 4 + 
ax2_0_1 * 2 + ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -726,7 +747,7 @@ def test_matmul_relu_global():
                                     v1_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -868,7 +889,7 @@ def test_padded_matmul_relu():
                                     v1_o = T.axis.spatial(8, ax2_0_1 * 2 + 
ax1_0)
                                     T.reads(A_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("A_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -881,7 +902,7 @@ def test_padded_matmul_relu():
                                     v1_o = T.axis.spatial(8, 
ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + 
ax0_0_2_ax1_0_2_fused + ax1_0)
                                     T.reads(B_reindex_shared[v0_o * 16 : v0_o 
* 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
                                     for ax0_1, ax1_1 in T.grid(16, 16):
                                         with 
T.block("B_reindex_shared_wmma.matrix_b"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1, ax1_1])
@@ -1008,7 +1029,7 @@ def test_conv_1x1():
                                     v1_o = T.axis.spatial(4, ax1_0_1)
                                     T.reads(PadInput_reindex_shared[v0_o * 16 
: v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                     
T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o 
* 16 : v1_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"})
                                     for ax0_1_1, ax1_1_1 in T.grid(16, 16):
                                         with 
T.block("PadInput_reindex_shared_wmma.matrix_a"):
                                             v0_i, v1_i = T.axis.remap("SS", 
[ax0_1_1, ax1_1_1])
@@ -1021,7 +1042,7 @@ def test_conv_1x1():
                                     v3_o = T.axis.spatial(4, 
ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0)
                                     T.reads(weight_reindex_shared[v0, v1, v2_o 
* 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                     
T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 
16, v3_o * 16 : v3_o * 16 + 16])
-                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"})
                                     for ax2_1, ax3_1 in T.grid(16, 16):
                                         with 
T.block("weight_reindex_shared_wmma.matrix_b"):
                                             v2_i, v3_i = T.axis.remap("SS", 
[ax2_1, ax3_1])
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py 
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index aadc530a9b..c242f63b98 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -1743,7 +1743,7 @@ class Conv2dInt8_with_predicate_scheduled:
                                         v1_o = T.axis.spatial(4, ax4_0_0 * 2 + 
ax4_0_1 + ax1_0_1)
                                         T.reads(pad_temp_reindex_shared[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                         
T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o 
* 16 : v1_o * 16 + 16])
-                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"})
+                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"})
                                         for ax0_1_1, ax1_1_1 in T.grid(16, 16):
                                             with 
T.block("pad_temp_reindex_shared_wmma.matrix_a"):
                                                 v0_i, v1_i = 
T.axis.remap("SS", [ax0_1_1, ax1_1_1])
@@ -1757,7 +1757,7 @@ class Conv2dInt8_with_predicate_scheduled:
                                         v3_o = T.axis.spatial(4, ax4_0_0 * 2 + 
ax4_0_1 + ax3_0)
                                         T.reads(p1_reindex_shared[v0, v1, v2_o 
* 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                         
T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, 
v3_o * 16 : v3_o * 16 + 16])
-                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"})
+                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"})
                                         for ax2_1, ax3_1 in T.grid(16, 16):
                                             with 
T.block("p1_reindex_shared_wmma.matrix_b"):
                                                 v2_i, v3_i = 
T.axis.remap("SS", [ax2_1, ax3_1])
@@ -2312,7 +2312,7 @@ def test_conv2d_int8_tensorcore():
         sch.annotate(
             block_or_loop=b158,
             ann_key="meta_schedule.auto_tensorize",
-            ann_val="wmma_load_16x16x16_s8_a",
+            ann_val="wmma_load_16x16x16_s8_a_shared",
         )
         b159 = sch.cache_read(block=b38, read_buffer_index=1, 
storage_scope="wmma.matrix_b")
         sch.compute_at(block=b159, loop=l80, preserve_unit_loops=True, 
index=-1)
@@ -2355,7 +2355,7 @@ def test_conv2d_int8_tensorcore():
         sch.annotate(
             block_or_loop=b192,
             ann_key="meta_schedule.auto_tensorize",
-            ann_val="wmma_load_16x16x16_s8_b_trans",
+            ann_val="wmma_load_16x16x16_s8_b_trans_shared",
         )
         sch.compute_inline(block=b17)
         sch.compute_inline(block=b18)
@@ -2490,10 +2490,10 @@ def test_conv2d_int8_tensorcore():
         sch.tensorize(block_or_loop=b314, 
tensor_intrin="wmma_fill_16x16x16_s32")
         b315 = sch.get_block(name="pad_temp_reindex_shared_wmma.matrix_a_o", 
func_name="main")
         sch.unannotate(block_or_loop=b315, 
ann_key="meta_schedule.auto_tensorize")
-        sch.tensorize(block_or_loop=b315, 
tensor_intrin="wmma_load_16x16x16_s8_a")
+        sch.tensorize(block_or_loop=b315, 
tensor_intrin="wmma_load_16x16x16_s8_a_shared")
         b316 = sch.get_block(name="p1_reindex_shared_wmma.matrix_b_o", 
func_name="main")
         sch.unannotate(block_or_loop=b316, 
ann_key="meta_schedule.auto_tensorize")
-        sch.tensorize(block_or_loop=b316, 
tensor_intrin="wmma_load_16x16x16_s8_b_trans")
+        sch.tensorize(block_or_loop=b316, 
tensor_intrin="wmma_load_16x16x16_s8_b_trans_shared")
         b317 = sch.get_block(name="conv2d_nhwc_o_update", func_name="main")
         sch.unannotate(block_or_loop=b317, 
ann_key="meta_schedule.auto_tensorize")
         sch.tensorize(block_or_loop=b317, 
tensor_intrin="wmma_sync_16x16x16_s8s8s32_trans")
@@ -3281,7 +3281,7 @@ def test_inline_order():
         sch.annotate(
             block_or_loop=b152,
             ann_key="meta_schedule.auto_tensorize",
-            ann_val="wmma_load_16x16x16_s8_a",
+            ann_val="wmma_load_16x16x16_s8_a_shared",
         )
         b153 = sch.cache_read(block=b32, read_buffer_index=1, 
storage_scope="wmma.matrix_b")
         sch.compute_at(block=b153, loop=l74, preserve_unit_loops=True, 
index=-1)
@@ -3324,7 +3324,7 @@ def test_inline_order():
         sch.annotate(
             block_or_loop=b186,
             ann_key="meta_schedule.auto_tensorize",
-            ann_val="wmma_load_16x16x16_s8_b_trans",
+            ann_val="wmma_load_16x16x16_s8_b_trans_shared",
         )
         sch.compute_inline(block=b11)
         sch.compute_inline(block=b12)
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py 
b/tests/python/unittest/test_tir_schedule_compute_inline.py
index f9c5e22e97..bd46e10efa 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -703,7 +703,7 @@ class Conv2dInt8_TensorCore_with_predicate:
                                         v1_o = T.axis.spatial(4, ax4_0_0 * 2 + 
ax4_0_1)
                                         T.reads(pad_temp_reindex_shared[v0_o * 
16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                         
T.writes(pad_temp_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o 
* 16 : v1_o * 16 + 16])
-                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a"})
+                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_a_shared"})
                                         for ax0_1_1, ax1_1_1 in T.grid(16, 16):
                                             with 
T.block("pad_temp_reindex_shared_wmma.matrix_a"):
                                                 v0_i, v1_i = 
T.axis.remap("SS", [ax0_1_1, ax1_1_1])
@@ -718,7 +718,7 @@ class Conv2dInt8_TensorCore_with_predicate:
                                         v3_o = T.axis.spatial(4, ax4_0_0 * 2 + 
ax4_0_1)
                                         T.reads(p1_reindex_shared[v0, v1, v2_o 
* 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16])
                                         
T.writes(p1_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, 
v3_o * 16 : v3_o * 16 + 16])
-                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans"})
+                                        
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_s8_b_trans_shared"})
                                         for ax2_1, ax3_1 in T.grid(16, 16):
                                             with 
T.block("p1_reindex_shared_wmma.matrix_b"):
                                                 v2_i, v3_i = 
T.axis.remap("SS", [ax2_1, ax3_1])


Reply via email to