This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8e8799d709 [Unity][Dlight] Enhance matmul tensorizer with Int8 support 
(#16084)
8e8799d709 is described below

commit 8e8799d709046676c267828786fa6c5fa3601fb1
Author: Ivan Sidorenko <[email protected]>
AuthorDate: Wed Nov 8 00:23:37 2023 +0300

    [Unity][Dlight] Enhance matmul tensorizer with Int8 support (#16084)
    
    The main goal of this commit is not to provide a highly optimized schedule,
    but to enable tensor core tensorization for i8i8i32 matmul. Currently,
    only fp16 matmul is supported.
    
    Nevertheless, a speedup is observed compared to fp16 (fp16fp16->fp16 vs i8i8->i32):
    
    N |        M, N, K*      |   Shape type   | fp16, us | int8, us | speedup |
    --|----------------------|----------------|----------|----------|---------|
    1 |     256, 256, 256    |  static shapes |   15.03  |   12.55  |  1.198x |
    2 |     512, 512, 4096   |  static shapes |  173.25  |  129.06  |  1.342x |
    3 | N (=260), 4096, 4096 | dynamic shapes |  350.44  |  298.29  |  1.175x |
    --------------------------------------------------------------------------|
    
    where K* - is a reduction axis.
    
    The implementation of the Matmul int8 tensorizer was derived from the Matmul
    tensorizer with minimal changes. It is possible to join these
    implementations (fp16 and int8), but I am not sure this is a good idea,
    since it could slow down development of the fp16 part.
---
 python/tvm/dlight/gpu/matmul.py                  | 220 +++++++++++++++++-
 tests/python/dlight/test_gpu_matmul_tensorize.py | 273 +++++++++++++++++++++++
 2 files changed, 492 insertions(+), 1 deletion(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 3549b4bc44..703f9c151f 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -296,6 +296,16 @@ def get_reduction_blocks(sch, blocks) -> bool:
     return reduction_blocks
 
 
+def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
+    """
+    Detect In/Out data types for the given block based on the analysis of 
read/write buffers.
+    """
+    assert len(block.reads) > 0 and len(block.writes) > 0
+    in_dtype = block.reads[0].buffer.dtype
+    out_dtype = block.writes[0].buffer.dtype
+    return (in_dtype, out_dtype)
+
+
 def check_sm_version(arch: str) -> int:
     sm_version = arch.replace("sm_", "")
     return int(sm_version) if sm_version.isdigit() else -1
@@ -520,6 +530,209 @@ class MatmulTensorization(ScheduleRule):
         return sch if tensorize_success else None
 
 
+class MatmulInt8Tensorization(ScheduleRule):
+    """
+    The schedule rule for int8 tensor core matmul computation.
+    func with attr 'dlight.do_not_tensorize' will not be tensorized.
+    """
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.cuda import (  # pylint: 
disable=import-outside-toplevel
+            get_wmma_intrin_group,
+        )
+
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        if func.attrs is not None and "dlight.do_not_tensorize" in 
func.attrs.keys():
+            return None
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
+
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
+            return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Start Schedule
+        # Step 0. Get schedule config.
+        # NOTE: we can analyze the config by the hardware spec in the future
+
+        # tensor core intrinsic size
+        micro_size_x = 16
+        micro_size_y = 16
+        micro_size_k = 16
+
+        warp_size = 32
+        vector_size = 4
+
+        i_factors, j_factors, k_factors = (
+            [None, 1, 4, 2],
+            [1, None, 4, 2],
+            [None, 1],
+        )
+
+        num_ty = i_factors[2] * j_factors[2]
+        x_pad_factor = i_factors[2] * i_factors[3]
+        y_pad_factor = j_factors[2] * j_factors[3]
+        k_pad_factor = k_factors[1]
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, 
J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                micro_size_x * x_pad_factor,
+                micro_size_y * y_pad_factor,
+                micro_size_k * k_pad_factor,
+            ],
+        )
+
+        # Step 3. Schedule matmul to use tensor core
+        block = main_block
+
+        batch, i, j, k = sch.get_loops(block)
+
+        # inner loops for tensor core computation
+        i, i_inner = sch.split(i, factors=[None, micro_size_x])
+        j, j_inner = sch.split(j, factors=[None, micro_size_y])
+        k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+        sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+        block_inner = block
+        block_outer = sch.blockize(i_inner)
+
+        i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+        j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+        k0, k1 = sch.split(k, k_factors)
+        sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6])
+        sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1])
+        sch.annotate(k1, "software_pipeline_order", [0, 1, 2])
+        sch.annotate(k1, "software_pipeline_stage", [0, 0, 1])
+
+        sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+        block_idx = sch.fuse(i0, j0)
+        block_idy = sch.fuse(i1, j1)
+        thread_idy = sch.fuse(j2, i2)
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(block_idx, "blockIdx.x")
+        sch.bind(block_idy, "blockIdx.y")
+        sch.bind(thread_idy, "threadIdx.y")
+
+        def fetch_to_shared(block, idx, ndim):
+            block_read = sch.cache_read(block, idx, "shared.dyn")
+            sch.compute_at(block_read, k0)
+            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+            _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, 
warp_size, vector_size])
+
+            sch.bind(f_2, "threadIdx.x")
+            sch.bind(f_1, "threadIdx.y")
+            sch.vectorize(f_3)
+
+            sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)
+            sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 
1)
+            sch.annotate(block_read, "double_buffer_scope", 0)
+            return block_read
+
+        a_g2s = fetch_to_shared(block_outer, 0, 2)
+        b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # create read cache to load matrix from shared memory to wmma fragments
+        A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+        B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+        sch.compute_at(A_mat, k1)
+        sch.compute_at(B_mat, k1)
+
+        # create write cache to store matrix from wmma fragments to shared 
memory and global memory
+        accumulator_shared_to_global = sch.cache_write(block_outer, 0, 
"shared.dyn")
+        sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+        store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+        sch.reverse_compute_at(store, thread_idy)
+        sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+        # split the store loop to match hardware intrinsic pattern
+        i, j = sch.get_loops(store)[-2:]
+        i0, i1 = sch.split(i, factors=[None, 16])
+        j0, j1 = sch.split(j, factors=[None, 16])
+        sch.reorder(i0, j0, i1, j1)
+
+        block_init_c = sch.decompose_reduction(block_outer, k0)
+        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+        # Tensorization by hardware intrinsics
+        intrin_group = get_wmma_intrin_group(
+            load_scope="shared.dyn",
+            store_scope="shared.dyn",
+            in_dtype="int8",
+            out_dtype="int32",
+            trans_b=True,
+        )
+
+        try:
+            i, j = sch.get_loops(A_mat)[-2:]
+            i0, i1 = sch.split(i, factors=[None, 16])
+            j0, j1 = sch.split(j, factors=[None, 16])
+            sch.reorder(i0, j0, i1, j1)
+            sch.unroll(i0)
+            sch.unroll(j0)
+            sch.tensorize(i1, intrin_group["load_a"])
+
+            i, j = sch.get_loops(B_mat)[-2:]
+            i0, i1 = sch.split(i, factors=[None, 16])
+            j0, j1 = sch.split(j, factors=[None, 16])
+            sch.reorder(i0, j0, i1, j1)
+            sch.unroll(i0)
+            sch.unroll(j0)
+            sch.tensorize(i1, intrin_group["load_b"])
+        except:  # pylint: disable=bare-except
+            return None
+
+        def tensorize_init_store_compute():
+            sch.tensorize(sch.get_loops(block_init_c_inner)[-2], 
intrin_group["init"])
+            sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
+            sch.tensorize(sch.get_loops(block_inner)[-3], 
intrin_group["compute"])
+
+        try:
+            tensorize_init_store_compute()
+        except:  # pylint: disable=bare-except
+            return None
+
+        auto_inline_consumer_chain(sch, accumulator_shared_to_global)
+
+        fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:])
+        _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size])
+        sch.bind(f1, "threadIdx.x")
+        sch.vectorize(f2)
+
+        return sch
+
+
 class Matmul(ScheduleRule):
     """The schedule rule for matmul-like computation"""
 
@@ -618,7 +831,12 @@ class Matmul(ScheduleRule):
                     if extent.value <= minimal_tensorize_threshold:
                         apply_tensorization = False
             if apply_tensorization:
-                tensorize_sch = MatmulTensorization().apply(func, target, _)
+                # Analyze read/write buffers and choose correct tensorizer: 
int8 or fp16.
+                in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
+                if in_dtype == "int8" and out_dtype == "int32":
+                    tensorize_sch = MatmulInt8Tensorization().apply(func, 
target, _)
+                else:
+                    tensorize_sch = MatmulTensorization().apply(func, target, 
_)
                 if tensorize_sch is not None:
                     return tensorize_sch
 
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py 
b/tests/python/dlight/test_gpu_matmul_tensorize.py
index c682c879e0..72ffb30719 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -425,5 +425,278 @@ class TestMatmulTensorizeEpilogue(BaseBeforeAfter):
     # fmt: on
 
 
+class TestMatmulInt8Tensorize(BaseBeforeAfter):
+    # fmt: off
+    @T.prim_func
+    def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), 
"int8"), compute: T.Buffer((256, 256), "int32")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i, j, r in T.grid(256, 256, 256):
+            with T.block("compute"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, r])
+                T.reads(X[v_i, v_k], W[v_j, v_k])
+                T.writes(compute[v_i, v_j])
+                with T.init():
+                    compute[v_i, v_j] = 0
+                compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("int32", X[v_i, 
v_k]) * T.Cast("int32", W[v_j, v_k])
+
+    @T.prim_func
+    def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), 
"int8"), compute: T.Buffer((256, 256), "int32")):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", 
scope="shared.dyn")
+        W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", 
scope="shared.dyn")
+        X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), 
"int8", scope="wmma.matrix_a")
+        W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), 
"int8", scope="wmma.matrix_b")
+        compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int32", 
scope="shared.dyn")
+        compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 
256), "int32", scope="wmma.accumulator")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, 
thread="blockIdx.x"):
+                for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, 
thread="blockIdx.y"):
+                    for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, 
thread="threadIdx.y"):
+                        for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+                            with T.block("compute_o_init"):
+                                v0_o = T.axis.spatial(1, ax0)
+                                v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
+                                v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
+                                T.reads()
+                                
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                with T.block("compute_init_o"):
+                                    v1_i_init_o = T.axis.spatial(1, 0)
+                                    v2_i_init_o = T.axis.spatial(1, 0)
+                                    T.reads()
+                                    
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                    C = 
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", 
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                    T.tvm_fill_fragment(C.data, 16, 16, 16, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, T.float32(0))
+                        for ax3_0_0 in T.serial(16, 
annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], 
"software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
+                            for ax0_ax1_fused_0 in range(1):
+                                for ax0_ax1_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with 
T.block("X_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(256, 
ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+                                                v2 = T.axis.spatial(256, 
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+                                                T.reads(X[v1, v2])
+                                                
T.writes(X_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, 
"tir.manifest_shared_memory_local_stage": 1})
+                                                X_reindex_shared_dyn[v0, v1, 
v2] = X[v1, v2]
+                            for ax0_ax1_fused_0 in range(1):
+                                for ax0_ax1_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with 
T.block("W_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+                                                v2 = T.axis.spatial(256, 
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+                                                T.reads(W[v1, v2])
+                                                
T.writes(W_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, 
"tir.manifest_shared_memory_local_stage": 1})
+                                                W_reindex_shared_dyn[v0, v1, 
v2] = W[v1, v2]
+                            for ax3_0_1 in T.serial(1, 
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": 
[0, 0, 1]}):
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("X_reindex_shared.dyn_wmma.matrix_a_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(16, ax3_0_0 
+ ax1_0)
+                                            T.reads(X_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), 
A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major")
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("W_reindex_shared.dyn_wmma.matrix_b_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(16, ax3_0_0 
+ ax1_0)
+                                            T.reads(W_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), 
A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major")
+                                for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+                                    with T.block("compute_o_update"):
+                                        v0_o = T.axis.spatial(1, ax0)
+                                        v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
+                                        v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
+                                        v3_o = T.axis.reduce(16, ax3_0_0 + 
ax3_0_1)
+                                        
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o 
* 16 + 16])
+                                        
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                        with T.block("compute_o"):
+                                            v1_i_o = T.axis.spatial(1, 0)
+                                            v2_i_o = T.axis.spatial(1, 0)
+                                            v3_i_o = T.axis.reduce(1, 0)
+                                            
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o 
* 16 + 16])
+                                            
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            B = 
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            C = 
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", 
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                            T.tvm_mma_sync(C.data, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] 
// 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // 
B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, 
C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16)
+                        for ax0_0, ax1_0 in T.grid(2, 2):
+                            with 
T.block("compute_reindex_shared.dyn_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(1, 0)
+                                v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+                                v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0)
+                                
T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                T.writes(compute_reindex_shared_dyn[v0_o, v1_o 
* 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                A = 
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o 
* 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", 
"A_s1"), scope="wmma.accumulator", offset_factor=16)
+                                C = 
T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o 
* 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), 
scope="shared.dyn", offset_factor=16)
+                                T.tvm_store_matrix_sync(A.data, 16, 16, 16, 
A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % 
A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, 
C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+                        for ax0_ax1_fused_0 in range(8):
+                            for ax0_ax1_fused_1 in T.thread_binding(32, 
thread="threadIdx.x"):
+                                for ax0_ax1_fused_2 in T.vectorized(4):
+                                    with T.block("compute_reindex_shared.dyn"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial(256, 
ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 
* 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
+                                        v2 = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+                                        T.reads(compute_reindex_shared_dyn[v0, 
v1, v2])
+                                        T.writes(compute[v1, v2])
+                                        T.block_attr({"buffer_dim_align": [[0, 
1, 16, 4]]})
+                                        compute[v1, v2] = 
compute_reindex_shared_dyn[v0, v1, v2]
+    # fmt: on
+
+
+class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter):
+    # fmt: off
+    @T.prim_func
+    def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), 
var_matmul: T.handle):
+        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+        m = T.int32()
+        A = T.match_buffer(var_A, (1, m, 22016), "int8")
+        matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32")
+        # with T.block("root"):
+        for i0, i1, i2, k in T.grid(1, m, 4096, 22016):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+                T.writes(matmul_1[v_i0, v_i1, v_i2])
+                with T.init():
+                    matmul_1[v_i0, v_i1, v_i2] = 0
+                matmul_1[v_i0, v_i1, v_i2] = matmul_1[v_i0, v_i1, v_i2] + 
T.Cast("int32", A[v_i0, v_i1, v_k]) * T.Cast("int32", B[v_i2, v_k])
+
+    @T.prim_func
+    def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), 
var_matmul: T.handle):
+        T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": 
T.bool(True)})
+        m = T.int32()
+        A = T.match_buffer(var_A, (1, m, 22016), "int8")
+        matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32")
+        # with T.block("root"):
+        A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 
22016), "int8", scope="shared.dyn")
+        B_reindex_shared_dyn = T.alloc_buffer((1, 4096, 22016), "int8", 
scope="shared.dyn")
+        A_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (m + 127) 
// 128 * 128, 22016), "int8", scope="wmma.matrix_a")
+        B_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 22016), 
"int8", scope="wmma.matrix_b")
+        matmul_1_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 
* 128, 4096), "int32", scope="shared.dyn")
+        matmul_1_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 
(m + 127) // 128 * 128, 4096), "int32", scope="wmma.accumulator")
+        for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+            for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128, 
thread="blockIdx.x"):
+                for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, 
thread="blockIdx.y"):
+                    for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, 
thread="threadIdx.y"):
+                        for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+                            with T.block("matmul_o_init"):
+                                v0_o = T.axis.spatial(1, ax0)
+                                v1_o = T.axis.spatial((m + 127) // 128 * 8, 
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
+                                v2_o = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
+                                T.reads()
+                                
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                with T.block("matmul_init_o"):
+                                    v1_i_init_o = T.axis.spatial(1, 0)
+                                    v2_i_init_o = T.axis.spatial(1, 0)
+                                    T.reads()
+                                    
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                    C = 
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", 
strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                    T.tvm_fill_fragment(C.data, 16, 16, 16, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, T.float32(0))
+                        for ax3_0_0 in T.serial(1376, 
annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], 
"software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
+                            for ax0_ax1_fused_0 in range(1):
+                                for ax0_ax1_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with 
T.block("A_reindex_pad_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial((m + 127) 
// 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + 
ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+                                                v2 = T.axis.spatial(22016, 
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+                                                T.reads(A[v0, v1, v2])
+                                                
T.writes(A_reindex_pad_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, 
"tir.manifest_shared_memory_local_stage": 1})
+                                                A_reindex_pad_shared_dyn[v0, 
v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0))
+                            for ax0_ax1_fused_0 in range(1):
+                                for ax0_ax1_fused_1 in T.thread_binding(16, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with 
T.block("B_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(4096, 
ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+                                                v2 = T.axis.spatial(22016, 
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+                                                T.reads(B[v1, v2])
+                                                
T.writes(B_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, 
"tir.manifest_shared_memory_local_stage": 1})
+                                                B_reindex_shared_dyn[v0, v1, 
v2] = B[v1, v2]
+                            for ax3_0_1 in T.serial(1, 
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": 
[0, 0, 1]}):
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("A_reindex_pad_shared.dyn_wmma.matrix_a_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(8 * ((m + 
127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(1376, 
ax3_0_0 + ax1_0)
+                                            
T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o 
* 16 + 16])
+                                            
T.writes(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A_1 = 
T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 
+ 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), 
A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "row_major")
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("B_reindex_shared.dyn_wmma.matrix_b_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(1376, 
ax3_0_0 + ax1_0)
+                                            T.reads(B_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A_1 = 
T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), 
A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "col_major")
+                                for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+                                    with T.block("matmul_o_update"):
+                                        v0_o = T.axis.spatial(1, ax0)
+                                        v1_o = T.axis.spatial((m + 127) // 128 
* 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
+                                        v2_o = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
+                                        v3_o = T.axis.reduce(1376, ax3_0_0 + 
ax3_0_1)
+                                        
T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 
+ 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o 
* 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o 
* 16 + 16])
+                                        
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                        with T.block("matmul_o"):
+                                            v1_i_o = T.axis.spatial(1, 0)
+                                            v2_i_o = T.axis.spatial(1, 0)
+                                            v3_i_o = T.axis.reduce(1, 0)
+                                            
T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 
+ 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o 
* 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o 
* 16 + 16])
+                                            
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            A_1 = 
T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 
16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            B_1 = 
T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            C = 
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", 
strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                            T.tvm_mma_sync(C.data, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, A_1.data, A_1.elem_offset // A_1.strides[0] // 16 * 
(A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, B_1.data, 
B_1.elem_offset // B_1.strides[0] // 16 * (B_1.strides[0] // 16) + 
B_1.elem_offset % B_1.strides[0] // 16, C.data, C.elem_offset // C.strides[0] 
// 16 * (C.strides[0] // 16) + C.elem_offset % C.strides [...]
+                        for ax0_0, ax1_0 in T.grid(2, 2):
+                            with 
T.block("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(1, 0)
+                                v1_o = T.axis.spatial(8 * ((m + 127) // 128), 
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+                                v2_o = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0)
+                                
T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                T.writes(matmul_1_reindex_pad_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                A_1 = 
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", 
strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16)
+                                C = 
T.match_buffer(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), 
scope="shared.dyn", offset_factor=16)
+                                T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, 
A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + 
A_1.elem_offset % A_1.strides[0] // 16, 
T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, 
C.strides[0] * 16, 2), C.strides[0], "row_major")
+                        for ax0_ax1_fused_0 in range(8):
+                            for ax0_ax1_fused_1 in T.thread_binding(32, 
thread="threadIdx.x"):
+                                for ax0_ax1_fused_2 in T.vectorized(4):
+                                    with 
T.block("matmul_1_reindex_pad_shared.dyn"):
+                                        v0 = T.axis.spatial(1, 0)
+                                        v1 = T.axis.spatial((m + 127) // 128 * 
128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
+                                        v2 = T.axis.spatial(4096, 
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + 
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+                                        
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
+                                        T.writes(matmul_1[0, v1, v2])
+                                        T.block_attr({"buffer_dim_align": [[0, 
1, 16, 4]]})
+                                        if v1 < m:
+                                            matmul_1[0, v1, v2] = 
matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
+    # fmt: on
+
+
if __name__ == "__main__":
    # Entry point when the test file is run directly: discover and run the
    # pytest-style test functions in this module via TVM's test harness.
    tvm.testing.main()


Reply via email to