This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new fd5a41f6f0 [Unity][Dlight] Add schedule rule for decode transpose 
(#15304)
fd5a41f6f0 is described below

commit fd5a41f6f081b4c28e55da02e08e59997dbf6be5
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Jul 14 07:05:03 2023 -0700

    [Unity][Dlight] Add schedule rule for decode transpose (#15304)
---
 python/tvm/dlight/base/__init__.py        |   8 +-
 python/tvm/dlight/base/analysis.py        |  46 +++++++-
 python/tvm/dlight/gpu/__init__.py         |   1 +
 python/tvm/dlight/gpu/decode_gemv.py      |  52 +-------
 python/tvm/dlight/gpu/transpose.py        | 129 ++++++++++++++++++++
 tests/python/dlight/test_gpu_transpose.py | 189 ++++++++++++++++++++++++++++++
 6 files changed, 377 insertions(+), 48 deletions(-)

diff --git a/python/tvm/dlight/base/__init__.py 
b/python/tvm/dlight/base/__init__.py
index b69c82fca0..d3a0598980 100644
--- a/python/tvm/dlight/base/__init__.py
+++ b/python/tvm/dlight/base/__init__.py
@@ -15,7 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """Base infra"""
-from .analysis import BlockInfo, IterInfo, normalize_prim_func
+from .analysis import (
+    BlockInfo,
+    IterInfo,
+    normalize_prim_func,
+    detect_dominant_read,
+    is_broadcast_epilogue,
+)
 from .common_schedules import try_inline, try_inline_contiguous_spatial
 from .schedule_rule import ScheduleRule
 from .transform import ApplyDefaultSchedule
diff --git a/python/tvm/dlight/base/analysis.py 
b/python/tvm/dlight/base/analysis.py
index 6e16239910..1ef257c530 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Analysis on TIR blocks, loops and functions."""
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Set
 
 from typing_extensions import Literal
 
@@ -206,3 +206,47 @@ def get_root_block(sch: Schedule, func_name: str = "main") 
-> BlockRV:
             f"{sch.mod[func_name].body}"
         )
     return sch.get_block(block.name_hint)
+
+
+def _collect_vars_used_in_access_region(region: List[ir.Range]) -> 
Set[tir.Var]:
+    tir_vars: Set[tir.Var] = set()
+
+    def _collect_tir_var(expr):
+        if isinstance(expr, tir.Var):
+            tir_vars.add(expr)
+
+    for expr in region:
+        assert expr.extent == 1
+        tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+    return tir_vars
+
+
+def detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
+    """Detect the dominant read indices in the block."""
+    dominant_read = None
+    num_read_iters = -1
+    for buffer_region in block.reads:
+        tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
+        if num_read_iters < len(tir_vars):
+            num_read_iters = len(tir_vars)
+            dominant_read = buffer_region
+    assert dominant_read is not None
+    (result,) = dominant_read.buffer.offset_of([e.min for e in 
dominant_read.region])
+    return result
+
+
+def is_broadcast_epilogue(
+    sch: tir.Schedule,
+    block: tir.schedule.BlockRV,
+    epilogue: tir.schedule.BlockRV,
+) -> bool:
+    """Check if the epilogue block is a broadcast pattern"""
+    write_buffers = {r.buffer for r in sch.get(block).writes}
+    epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom 
!= 1}
+    for buffer_region in sch.get(epilogue).reads:
+        if buffer_region.buffer not in write_buffers:
+            continue
+        tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
+        if len(tir_vars) < len(epilogue_iters):
+            return True
+    return False
diff --git a/python/tvm/dlight/gpu/__init__.py 
b/python/tvm/dlight/gpu/__init__.py
index 934928ffaf..ca1cc8d5f7 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -22,3 +22,4 @@ from .decode_gemv import DecodeGEMV
 from .fallback import Fallback
 from .matmul import Matmul
 from .reduction import Reduction
+from .transpose import Transpose
diff --git a/python/tvm/dlight/gpu/decode_gemv.py 
b/python/tvm/dlight/gpu/decode_gemv.py
index 6c7e31181b..afcfdb3020 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """A rule for DecodeGEMV."""
-from typing import List, Optional, Set, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 from tvm import arith, ir, tir
 from tvm.target import Target
@@ -25,6 +25,8 @@ from ..base import (
     ScheduleRule,
     normalize_prim_func,
     try_inline_contiguous_spatial,
+    detect_dominant_read,
+    is_broadcast_epilogue,
 )
 from . import utils
 
@@ -45,48 +47,6 @@ def _get_reduction_expr(block: tir.Block) -> 
Optional[tir.PrimExpr]:
     return buffer_store.value.b
 
 
-def _collect_vars_used_in_access_region(region: List[ir.Range]) -> 
Set[tir.Var]:
-    tir_vars: Set[tir.Var] = set()
-
-    def _collect_tir_var(expr):
-        if isinstance(expr, tir.Var):
-            tir_vars.add(expr)
-
-    for expr in region:
-        assert expr.extent == 1
-        tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
-    return tir_vars
-
-
-def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
-    dominant_read = None
-    num_read_iters = -1
-    for buffer_region in block.reads:
-        tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
-        if num_read_iters < len(tir_vars):
-            num_read_iters = len(tir_vars)
-            dominant_read = buffer_region
-    assert dominant_read is not None
-    (result,) = dominant_read.buffer.offset_of([e.min for e in 
dominant_read.region])
-    return result
-
-
-def _is_broadcast_epilogue(
-    sch: tir.Schedule,
-    block: tir.schedule.BlockRV,
-    epilogue: tir.schedule.BlockRV,
-) -> bool:
-    write_buffers = {r.buffer for r in sch.get(block).writes}
-    epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom 
!= 1}
-    for buffer_region in sch.get(epilogue).reads:
-        if buffer_region.buffer not in write_buffers:
-            continue
-        tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
-        if len(tir_vars) < len(epilogue_iters):
-            return True
-    return False
-
-
 class DecodeGEMV(ScheduleRule):
     """A rule for DecodeGEMV."""
 
@@ -128,7 +88,7 @@ class DecodeGEMV(ScheduleRule):
             sch,
             block_info,
             arith.normalize_to_iter_sum(
-                _detect_dominant_read(block_stmt),
+                detect_dominant_read(block_stmt),
                 input_iters={i.var: i.dom for i in block_stmt.iter_vars},
             ),
         )
@@ -223,7 +183,7 @@ class DecodeGEMV(ScheduleRule):
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
             sch.reverse_compute_at(epilogue, bx)
-            if _is_broadcast_epilogue(sch, block, epilogue):
+            if is_broadcast_epilogue(sch, block, epilogue):
                 sch.set_scope(block, 0, "shared")
                 _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
                 _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx])
@@ -268,7 +228,7 @@ class DecodeGEMV(ScheduleRule):
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
             sch.reverse_compute_at(epilogue, bx)
-            if _is_broadcast_epilogue(sch, block, epilogue):
+            if is_broadcast_epilogue(sch, block, epilogue):
                 sch.set_scope(block, 0, "shared")
                 _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
                 _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, 
len_ty])
diff --git a/python/tvm/dlight/gpu/transpose.py 
b/python/tvm/dlight/gpu/transpose.py
new file mode 100644
index 0000000000..0a5cebc89e
--- /dev/null
+++ b/python/tvm/dlight/gpu/transpose.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Schedule rule for transpose operators, optionally fused with a decode prologue"""
+from typing import List, Union
+
+from tvm import tir, arith
+from tvm.target import Target
+from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV
+
+
+from ..base import (
+    ScheduleRule,
+    normalize_prim_func,
+    try_inline_contiguous_spatial,
+    detect_dominant_read,
+)
+
+
+class Transpose(ScheduleRule):
+    """Schedule rule for transpose"""
+
+    def is_transpose(self, sch: Schedule, block_rv: BlockRV):
+        block = sch.get(block_rv)
+        if isinstance(block.body, tir.BufferStore):
+            rhs = block.body.value
+            if isinstance(rhs, tir.BufferLoad):
+                lhs_indices = block.body.indices
+                rhs_indices = rhs.indices
+                if list(lhs_indices) != list(rhs_indices) and set(lhs_indices) 
== set(rhs_indices):
+                    return True
+        return False
+
+    def apply(  # pylint: disable=too-many-locals
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        # pylint: disable=invalid-name
+        if target.kind.name == "cuda":
+            len_tx = 16
+            len_ty = 8
+            unroll_depth = 256
+        else:
+            len_tx = 8
+            len_ty = 4
+            unroll_depth = 64
+        len_vec = 4
+
+        sch = tir.Schedule(func)
+        blocks = normalize_prim_func(sch)
+        transpose_block_idx = -1
+        for idx, block in reversed(list(enumerate(blocks))):
+            if self.is_transpose(sch, block.block_rv):
+                transpose_block_idx = idx
+                break
+            if not block.is_injective():
+                return None
+        if transpose_block_idx == -1:
+            return None
+        transpose_block = blocks[transpose_block_idx].block_rv
+
+        prologue = None  # the optional decoding block
+        if transpose_block_idx > 0:
+            spatials = try_inline_contiguous_spatial(sch, blocks[: 
transpose_block_idx - 1])
+            assert len(spatials) == 0
+            prologue = blocks[transpose_block_idx - 1].block_rv
+
+        loops = sch.get_loops(transpose_block)
+        if len(loops) != 2:
+            # transpose with more than 2 axes is not supported
+            return None
+
+        c_factor = 1
+        if prologue is not None:
+            block_stmt = sch.get(prologue)
+            # NOTE(review): removed leftover debug print of detect_dominant_read here
+            result = arith.normalize_to_iter_sum(
+                detect_dominant_read(block_stmt),
+                input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+            )
+            if len(result.args) > 0:
+                c_factor = int(result.args[0].lower_factor)
+
+        i, j = loops
+        i, vi = sch.split(i, factors=[None, c_factor], 
preserve_unit_iters=True)
+        bi, ti = sch.split(i, factors=[None, len_ty], preserve_unit_iters=True)
+        bj, tj = sch.split(j, factors=[None, len_tx], preserve_unit_iters=True)
+        sch.reorder(bi, bj, ti, tj, vi)
+        sch.bind(bi, "blockIdx.y")
+        sch.bind(bj, "blockIdx.x")
+        sch.bind(ti, "threadIdx.y")
+        sch.bind(tj, "threadIdx.x")
+        len_vec = min(len_vec, c_factor)
+        _, vi = sch.split(vi, factors=[None, len_vec])
+        if len_vec > 1:
+            sch.vectorize(vi)
+
+        cache_read = sch.cache_read(transpose_block, read_buffer_index=0, 
storage_scope="shared")
+        sch.compute_at(cache_read, bj)
+        loops = sch.get_loops(cache_read)[2:]
+        fused = sch.fuse(*loops)
+        _, ty, tx, v = sch.split(fused, factors=[None, len_ty, len_tx, 
c_factor])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.unroll(v)
+        sch.storage_align(block=cache_read, buffer_index=0, axis=0, factor=32, 
offset=1)
+
+        sch.annotate(bi, ann_key="pragma_auto_unroll_max_step", 
ann_val=unroll_depth)
+        sch.annotate(bi, ann_key="pragma_unroll_explicit", ann_val=1)
+
+        if prologue is not None:
+            sch.compute_inline(prologue)
+        return sch
diff --git a/tests/python/dlight/test_gpu_transpose.py 
b/tests/python/dlight/test_gpu_transpose.py
new file mode 100644
index 0000000000..c4313fe6a2
--- /dev/null
+++ b/tests/python/dlight/test_gpu_transpose.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import dlight as dl
+from tvm.ir import IRModule, assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _check(mod_before: IRModule, mod_after: IRModule):
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+            dl.gpu.Transpose(),
+        )(mod_before)
+    assert_structural_equal(mod, mod_after)
+
+
+def test_transpose():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), 
"float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), 
"float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            rxplaceholder_shared = T.alloc_buffer((T.int64(512), 
T.int64(4096)), scope="shared")
+            for ax0_0_0 in T.thread_binding(T.int64(512), thread="blockIdx.y", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax1_0 in T.thread_binding(T.int64(32), 
thread="blockIdx.x"):
+                    for ax0_ax1_fused_0 in range(T.int64(1)):
+                        for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                            for ax0_ax1_fused_2 in 
T.thread_binding(T.int64(16), thread="threadIdx.x"):
+                                for ax0_ax1_fused_3 in T.unroll(T.int64(1)):
+                                    with T.block("rxplaceholder_shared"):
+                                        v0 = T.axis.spatial(T.int64(512), 
ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * 
T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) // T.int64(8))
+                                        v1 = T.axis.spatial(T.int64(4096), 
ax0_0_0 * T.int64(8) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 * 
T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) % T.int64(8))
+                                        T.reads(rxplaceholder[v0, v1])
+                                        T.writes(rxplaceholder_shared[v0, v1])
+                                        T.block_attr({"buffer_dim_align": [[0, 
0, 32, 1]]})
+                                        rxplaceholder_shared[v0, v1] = 
rxplaceholder[v0, v1]
+                    for ax0_0_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                        for ax1_1 in T.thread_binding(T.int64(16), 
thread="threadIdx.x"):
+                            for ax0_1_0 in range(T.int64(1)):
+                                for ax0_1_1 in range(T.int64(1)):
+                                    with T.block("T_transpose"):
+                                        v0 = T.axis.spatial(T.int64(4096), 
ax0_0_0 * T.int64(8) + ax0_0_1 + ax0_1_0 + ax0_1_1)
+                                        v1 = T.axis.spatial(T.int64(512), 
ax1_0 * T.int64(16) + ax1_1)
+                                        T.reads(rxplaceholder_shared[v1, v0])
+                                        T.writes(T_transpose[v0, v1])
+                                        T_transpose[v0, v1] = 
rxplaceholder_shared[v1, v0]
+    # fmt: on
+    _check(Before, After)
+
+
+def test_decode_transpose():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), 
"uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), 
T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            decode = T.alloc_buffer((T.int64(4096), T.int64(4096)))
+            for i, j in T.grid(T.int64(4096), T.int64(4096)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(rxplaceholder[v_i // T.int64(8), v_j], 
rxplaceholder_1[v_i // T.int64(32), v_j])
+                    T.writes(decode[v_i, v_j])
+                    decode[v_i, v_j] = T.Cast("float32", 
T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j], 
T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) * 
T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i // 
T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", 
T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32), 
v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+            for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(decode[v_ax1, v_ax0])
+                    T.writes(T_transpose[v_ax0, v_ax1])
+                    T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), 
"uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), 
T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            decode_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), 
scope="shared")
+            for ax0_0_0 in T.thread_binding(T.int64(64), thread="blockIdx.y", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax1_0 in T.thread_binding(T.int64(256), 
thread="blockIdx.x"):
+                    for ax0_ax1_fused_0 in range(T.int64(1)):
+                        for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                            for ax0_ax1_fused_2 in 
T.thread_binding(T.int64(16), thread="threadIdx.x"):
+                                for ax0_ax1_fused_3 in T.unroll(T.int64(8)):
+                                    with T.block("decode_shared"):
+                                        v0 = T.axis.spatial(T.int64(4096), 
ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * 
T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) // T.int64(64))
+                                        v1 = T.axis.spatial(T.int64(4096), 
ax0_0_0 * T.int64(64) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * 
T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) % T.int64(64))
+                                        T.reads(rxplaceholder[v0 // 
T.int64(8), v1], rxplaceholder_1[v0 // T.int64(32), v1])
+                                        T.writes(decode_shared[v0, v1])
+                                        T.block_attr({"buffer_dim_align": [[0, 
0, 32, 1]]})
+                                        decode_shared[v0, v1] = 
T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v0 // T.int64(8), 
v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) * 
T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v0 // 
T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", 
T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v0 // T.int64(32), 
v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+                    for ax0_0_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                        for ax1_1 in T.thread_binding(T.int64(16), 
thread="threadIdx.x"):
+                            for ax0_1_0 in range(T.int64(2)):
+                                for ax0_1_1 in T.vectorized(T.int64(4)):
+                                    with T.block("T_transpose"):
+                                        v0 = T.axis.spatial(T.int64(4096), 
ax0_0_0 * T.int64(64) + ax0_0_1 * T.int64(8) + ax0_1_0 * T.int64(4) + ax0_1_1)
+                                        v1 = T.axis.spatial(T.int64(4096), 
ax1_0 * T.int64(16) + ax1_1)
+                                        T.reads(decode_shared[v1, v0])
+                                        T.writes(T_transpose[v0, v1])
+                                        T_transpose[v0, v1] = 
decode_shared[v1, v0]
+    # fmt: on
+    _check(Before, After)
+
+
+def test_decode_int3_transpose():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: 
T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: 
T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), 
"float16")
+            for i, j in T.grid(T.int64(4096), T.int64(4096)):
+                with T.block("decode"):
+                    v_i, v_j = T.axis.remap("SS", [i, j])
+                    T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40), 
v_j])
+                    T.writes(decode_1[v_i, v_j])
+                    decode_1[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i % 
T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i // 
T.int64(40), v_j]
+            for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(decode_1[v_ax1, v_ax0])
+                    T.writes(T_transpose[v_ax0, v_ax1])
+                    T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0]
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: 
T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: 
T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            decode_1_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)), 
"float16", scope="shared")
+            for ax0_0_0 in T.thread_binding(T.int64(52), thread="blockIdx.y", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax1_0 in T.thread_binding(T.int64(256), 
thread="blockIdx.x"):
+                    for ax0_ax1_fused_0 in range(T.int64(2)):
+                        for ax0_ax1_fused_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                            for ax0_ax1_fused_2 in 
T.thread_binding(T.int64(16), thread="threadIdx.x"):
+                                for ax0_ax1_fused_3 in T.unroll(T.int64(10)):
+                                    with T.block("decode_1_shared"):
+                                        v0 = T.axis.spatial(T.int64(4096), 
ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 * 
T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) // T.int64(82))
+                                        v1 = T.axis.spatial(T.int64(4096), 
ax0_0_0 * T.int64(80) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 * 
T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) % T.int64(82))
+                                        T.where(ax0_0_0 * T.int64(80) + 
(((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) + 
ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3) % T.int64(82) < T.int64(4096) 
and ((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) + 
ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3 < T.int64(1312))
+                                        T.reads(A[v0 // T.int64(10), v1], B[v0 
// T.int64(40), v1])
+                                        T.writes(decode_1_shared[v0, v1])
+                                        T.block_attr({"buffer_dim_align": [[0, 
0, 32, 1]]})
+                                        decode_1_shared[v0, v1] = 
(T.Cast("float16", T.bitwise_and(T.shift_right(A[v0 // T.int64(10), v1], 
T.Cast("uint32", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) - 
T.float16(3)) * B[v0 // T.int64(40), v1]
+                    for ax0_0_1 in T.thread_binding(T.int64(8), 
thread="threadIdx.y"):
+                        for ax1_1 in T.thread_binding(T.int64(16), 
thread="threadIdx.x"):
+                            for ax0_1_0 in range(T.int64(3)):
+                                for ax0_1_1 in T.vectorized(T.int64(4)):
+                                    with T.block("T_transpose"):
+                                        v0 = T.axis.spatial(T.int64(4096), 
(ax0_0_0 * T.int64(8) + ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + 
ax0_1_1))
+                                        v1 = T.axis.spatial(T.int64(4096), 
ax1_0 * T.int64(16) + ax1_1)
+                                        T.where((ax0_0_0 * T.int64(8) + 
ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + ax0_1_1) < T.int64(4096) and 
ax0_0_0 * T.int64(8) + ax0_0_1 < T.int64(410) and ax0_1_0 * T.int64(4) + 
ax0_1_1 < T.int64(10))
+                                        T.reads(decode_1_shared[v1, v0])
+                                        T.writes(T_transpose[v0, v1])
+                                        T_transpose[v0, v1] = 
decode_1_shared[v1, v0]
+    # fmt: on
+    _check(Before, After)

Reply via email to