This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new fd5a41f6f0 [Unity][Dlight] Add schedule rule for decode transpose
(#15304)
fd5a41f6f0 is described below
commit fd5a41f6f081b4c28e55da02e08e59997dbf6be5
Author: Wuwei Lin <[email protected]>
AuthorDate: Fri Jul 14 07:05:03 2023 -0700
[Unity][Dlight] Add schedule rule for decode transpose (#15304)
---
python/tvm/dlight/base/__init__.py | 8 +-
python/tvm/dlight/base/analysis.py | 46 +++++++-
python/tvm/dlight/gpu/__init__.py | 1 +
python/tvm/dlight/gpu/decode_gemv.py | 52 +-------
python/tvm/dlight/gpu/transpose.py | 129 ++++++++++++++++++++
tests/python/dlight/test_gpu_transpose.py | 189 ++++++++++++++++++++++++++++++
6 files changed, 377 insertions(+), 48 deletions(-)
diff --git a/python/tvm/dlight/base/__init__.py
b/python/tvm/dlight/base/__init__.py
index b69c82fca0..d3a0598980 100644
--- a/python/tvm/dlight/base/__init__.py
+++ b/python/tvm/dlight/base/__init__.py
@@ -15,7 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Base infra"""
-from .analysis import BlockInfo, IterInfo, normalize_prim_func
+from .analysis import (
+ BlockInfo,
+ IterInfo,
+ normalize_prim_func,
+ detect_dominant_read,
+ is_broadcast_epilogue,
+)
from .common_schedules import try_inline, try_inline_contiguous_spatial
from .schedule_rule import ScheduleRule
from .transform import ApplyDefaultSchedule
diff --git a/python/tvm/dlight/base/analysis.py
b/python/tvm/dlight/base/analysis.py
index 6e16239910..1ef257c530 100644
--- a/python/tvm/dlight/base/analysis.py
+++ b/python/tvm/dlight/base/analysis.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Analysis on TIR blocks, loops and functions."""
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Set
from typing_extensions import Literal
@@ -206,3 +206,47 @@ def get_root_block(sch: Schedule, func_name: str = "main")
-> BlockRV:
f"{sch.mod[func_name].body}"
)
return sch.get_block(block.name_hint)
+
+
def _collect_vars_used_in_access_region(region: List[ir.Range]) -> Set[tir.Var]:
    """Collect every TIR variable appearing in a point access region.

    Each range in *region* must be a single point (extent 1), so the
    accessed index of a dimension is fully described by ``rng.min``.
    """
    used_vars: Set[tir.Var] = set()

    def _visit(node):
        if isinstance(node, tir.Var):
            used_vars.add(node)

    for rng in region:
        assert rng.extent == 1
        tir.stmt_functor.post_order_visit(rng.min, _visit)
    return used_vars
+
+
def detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
    """Detect the dominant read indices in the block.

    The dominant read is the read region that touches the largest number
    of distinct iteration variables; its flattened buffer offset is
    returned as a single index expression.
    """
    best_region = None
    best_var_count = -1
    for read in block.reads:
        used = _collect_vars_used_in_access_region(read.region)
        # Strict comparison keeps the first region on ties.
        if len(used) > best_var_count:
            best_var_count = len(used)
            best_region = read
    assert best_region is not None
    (index,) = best_region.buffer.offset_of([r.min for r in best_region.region])
    return index
+
+
def is_broadcast_epilogue(
    sch: tir.Schedule,
    block: tir.schedule.BlockRV,
    epilogue: tir.schedule.BlockRV,
) -> bool:
    """Check if the epilogue block is a broadcast pattern.

    An epilogue is considered a broadcast when it reads one of *block*'s
    output buffers using fewer variables than it has non-trivial
    iteration variables.
    """
    write_buffers = {region.buffer for region in sch.get(block).writes}
    epilogue_iters = {iv.var: iv for iv in sch.get(epilogue).iter_vars if iv.dom != 1}
    for read in sch.get(epilogue).reads:
        if read.buffer in write_buffers:
            used = _collect_vars_used_in_access_region(read.region)
            if len(used) < len(epilogue_iters):
                return True
    return False
diff --git a/python/tvm/dlight/gpu/__init__.py
b/python/tvm/dlight/gpu/__init__.py
index 934928ffaf..ca1cc8d5f7 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -22,3 +22,4 @@ from .decode_gemv import DecodeGEMV
from .fallback import Fallback
from .matmul import Matmul
from .reduction import Reduction
+from .transpose import Transpose
diff --git a/python/tvm/dlight/gpu/decode_gemv.py
b/python/tvm/dlight/gpu/decode_gemv.py
index 6c7e31181b..afcfdb3020 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""A rule for DecodeGEMV."""
-from typing import List, Optional, Set, Tuple, Union
+from typing import List, Optional, Tuple, Union
from tvm import arith, ir, tir
from tvm.target import Target
@@ -25,6 +25,8 @@ from ..base import (
ScheduleRule,
normalize_prim_func,
try_inline_contiguous_spatial,
+ detect_dominant_read,
+ is_broadcast_epilogue,
)
from . import utils
@@ -45,48 +47,6 @@ def _get_reduction_expr(block: tir.Block) ->
Optional[tir.PrimExpr]:
return buffer_store.value.b
-def _collect_vars_used_in_access_region(region: List[ir.Range]) ->
Set[tir.Var]:
- tir_vars: Set[tir.Var] = set()
-
- def _collect_tir_var(expr):
- if isinstance(expr, tir.Var):
- tir_vars.add(expr)
-
- for expr in region:
- assert expr.extent == 1
- tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
- return tir_vars
-
-
-def _detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
- dominant_read = None
- num_read_iters = -1
- for buffer_region in block.reads:
- tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
- if num_read_iters < len(tir_vars):
- num_read_iters = len(tir_vars)
- dominant_read = buffer_region
- assert dominant_read is not None
- (result,) = dominant_read.buffer.offset_of([e.min for e in
dominant_read.region])
- return result
-
-
-def _is_broadcast_epilogue(
- sch: tir.Schedule,
- block: tir.schedule.BlockRV,
- epilogue: tir.schedule.BlockRV,
-) -> bool:
- write_buffers = {r.buffer for r in sch.get(block).writes}
- epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom
!= 1}
- for buffer_region in sch.get(epilogue).reads:
- if buffer_region.buffer not in write_buffers:
- continue
- tir_vars = _collect_vars_used_in_access_region(buffer_region.region)
- if len(tir_vars) < len(epilogue_iters):
- return True
- return False
-
-
class DecodeGEMV(ScheduleRule):
"""A rule for DecodeGEMV."""
@@ -128,7 +88,7 @@ class DecodeGEMV(ScheduleRule):
sch,
block_info,
arith.normalize_to_iter_sum(
- _detect_dominant_read(block_stmt),
+ detect_dominant_read(block_stmt),
input_iters={i.var: i.dom for i in block_stmt.iter_vars},
),
)
@@ -223,7 +183,7 @@ class DecodeGEMV(ScheduleRule):
if epilogue_info is not None:
epilogue = epilogue_info.block_rv
sch.reverse_compute_at(epilogue, bx)
- if _is_broadcast_epilogue(sch, block, epilogue):
+ if is_broadcast_epilogue(sch, block, epilogue):
sch.set_scope(block, 0, "shared")
_, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
_, tx = sch.split(sch.fuse(*s), factors=[None, len_tx])
@@ -268,7 +228,7 @@ class DecodeGEMV(ScheduleRule):
if epilogue_info is not None:
epilogue = epilogue_info.block_rv
sch.reverse_compute_at(epilogue, bx)
- if _is_broadcast_epilogue(sch, block, epilogue):
+ if is_broadcast_epilogue(sch, block, epilogue):
sch.set_scope(block, 0, "shared")
_, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name
_, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx,
len_ty])
diff --git a/python/tvm/dlight/gpu/transpose.py
b/python/tvm/dlight/gpu/transpose.py
new file mode 100644
index 0000000000..0a5cebc89e
--- /dev/null
+++ b/python/tvm/dlight/gpu/transpose.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Schedule rule for transpose operators, optionally fused with a decode prologue"""
+from typing import List, Union
+
+from tvm import tir, arith
+from tvm.target import Target
+from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV
+
+
+from ..base import (
+ ScheduleRule,
+ normalize_prim_func,
+ try_inline_contiguous_spatial,
+ detect_dominant_read,
+)
+
+
class Transpose(ScheduleRule):
    """Schedule rule for transpose (optionally preceded by a decode block)."""

    def is_transpose(self, sch: Schedule, block_rv: BlockRV):
        """Return True if the block is a pure transpose.

        A transpose block is a single buffer store whose value is a buffer
        load using the same set of index variables in a different order.
        """
        block = sch.get(block_rv)
        if isinstance(block.body, tir.BufferStore):
            rhs = block.body.value
            if isinstance(rhs, tir.BufferLoad):
                lhs_indices = block.body.indices
                rhs_indices = rhs.indices
                if list(lhs_indices) != list(rhs_indices) and set(lhs_indices) == set(rhs_indices):
                    return True
        return False

    def apply(  # pylint: disable=too-many-locals
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
        """Schedule a 2-D transpose through shared memory.

        Returns the scheduled ``tir.Schedule``, or ``None`` when the
        function has no transpose block or the transpose has more than
        two axes.
        """
        # pylint: disable=invalid-name
        # Thread extents / unroll depth tuned per backend.
        if target.kind.name == "cuda":
            len_tx = 16
            len_ty = 8
            unroll_depth = 256
        else:
            len_tx = 8
            len_ty = 4
            unroll_depth = 64
        len_vec = 4

        sch = tir.Schedule(func)
        blocks = normalize_prim_func(sch)
        # Search from the back so epilogue-less transposes are found first;
        # bail out on any non-injective block after the transpose.
        transpose_block_idx = -1
        for idx, block in reversed(list(enumerate(blocks))):
            if self.is_transpose(sch, block.block_rv):
                transpose_block_idx = idx
                break
            if not block.is_injective():
                return None
        if transpose_block_idx == -1:
            return None
        transpose_block = blocks[transpose_block_idx].block_rv

        prologue = None  # the optional decoding block
        if transpose_block_idx > 0:
            spatials = try_inline_contiguous_spatial(sch, blocks[: transpose_block_idx - 1])
            assert len(spatials) == 0
            prologue = blocks[transpose_block_idx - 1].block_rv

        loops = sch.get_loops(transpose_block)
        if len(loops) != 2:
            # transpose with more than 2 axes is not supported
            return None

        # When a decode prologue exists, derive the packing factor from its
        # dominant read so the inner axis matches the decoded element count.
        c_factor = 1
        if prologue is not None:
            block_stmt = sch.get(prologue)
            # NOTE: the original patch left a debug `print(...)` here; removed.
            result = arith.normalize_to_iter_sum(
                detect_dominant_read(block_stmt),
                input_iters={i.var: i.dom for i in block_stmt.iter_vars},
            )
            if len(result.args) > 0:
                c_factor = int(result.args[0].lower_factor)

        i, j = loops
        i, vi = sch.split(i, factors=[None, c_factor], preserve_unit_iters=True)
        bi, ti = sch.split(i, factors=[None, len_ty], preserve_unit_iters=True)
        bj, tj = sch.split(j, factors=[None, len_tx], preserve_unit_iters=True)
        sch.reorder(bi, bj, ti, tj, vi)
        sch.bind(bi, "blockIdx.y")
        sch.bind(bj, "blockIdx.x")
        sch.bind(ti, "threadIdx.y")
        sch.bind(tj, "threadIdx.x")
        len_vec = min(len_vec, c_factor)
        _, vi = sch.split(vi, factors=[None, len_vec])
        if len_vec > 1:
            sch.vectorize(vi)

        # Stage the input through shared memory; pad rows (factor 32,
        # offset 1) to avoid bank conflicts on the transposed access.
        cache_read = sch.cache_read(transpose_block, read_buffer_index=0, storage_scope="shared")
        sch.compute_at(cache_read, bj)
        loops = sch.get_loops(cache_read)[2:]
        fused = sch.fuse(*loops)
        _, ty, tx, v = sch.split(fused, factors=[None, len_ty, len_tx, c_factor])
        sch.bind(ty, "threadIdx.y")
        sch.bind(tx, "threadIdx.x")
        sch.unroll(v)
        sch.storage_align(block=cache_read, buffer_index=0, axis=0, factor=32, offset=1)

        sch.annotate(bi, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
        sch.annotate(bi, ann_key="pragma_unroll_explicit", ann_val=1)

        if prologue is not None:
            sch.compute_inline(prologue)
        return sch
diff --git a/tests/python/dlight/test_gpu_transpose.py
b/tests/python/dlight/test_gpu_transpose.py
new file mode 100644
index 0000000000..c4313fe6a2
--- /dev/null
+++ b/tests/python/dlight/test_gpu_transpose.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm
+from tvm import dlight as dl
+from tvm.ir import IRModule, assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
def _check(mod_before: IRModule, mod_after: IRModule):
    """Apply the Transpose rule under a CUDA target and compare IR."""
    target = Target("nvidia/geforce-rtx-3090-ti")
    with target:
        scheduled = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
            dl.gpu.Transpose(),
        )(mod_before)
    assert_structural_equal(scheduled, mod_after)
+
+
+def test_transpose():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)),
"float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)),
"float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ rxplaceholder_shared = T.alloc_buffer((T.int64(512),
T.int64(4096)), scope="shared")
+ for ax0_0_0 in T.thread_binding(T.int64(512), thread="blockIdx.y",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax1_0 in T.thread_binding(T.int64(32),
thread="blockIdx.x"):
+ for ax0_ax1_fused_0 in range(T.int64(1)):
+ for ax0_ax1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(T.int64(16), thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.unroll(T.int64(1)):
+ with T.block("rxplaceholder_shared"):
+ v0 = T.axis.spatial(T.int64(512),
ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 *
T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) // T.int64(8))
+ v1 = T.axis.spatial(T.int64(4096),
ax0_0_0 * T.int64(8) + (ax0_ax1_fused_0 * T.int64(128) + ax0_ax1_fused_1 *
T.int64(16) + ax0_ax1_fused_2 + ax0_ax1_fused_3) % T.int64(8))
+ T.reads(rxplaceholder[v0, v1])
+ T.writes(rxplaceholder_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0,
0, 32, 1]]})
+ rxplaceholder_shared[v0, v1] =
rxplaceholder[v0, v1]
+ for ax0_0_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for ax1_1 in T.thread_binding(T.int64(16),
thread="threadIdx.x"):
+ for ax0_1_0 in range(T.int64(1)):
+ for ax0_1_1 in range(T.int64(1)):
+ with T.block("T_transpose"):
+ v0 = T.axis.spatial(T.int64(4096),
ax0_0_0 * T.int64(8) + ax0_0_1 + ax0_1_0 + ax0_1_1)
+ v1 = T.axis.spatial(T.int64(512),
ax1_0 * T.int64(16) + ax1_1)
+ T.reads(rxplaceholder_shared[v1, v0])
+ T.writes(T_transpose[v0, v1])
+ T_transpose[v0, v1] =
rxplaceholder_shared[v1, v0]
+ # fmt: on
+ _check(Before, After)
+
+
+def test_decode_transpose():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)),
"uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"),
T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ decode = T.alloc_buffer((T.int64(4096), T.int64(4096)))
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(rxplaceholder[v_i // T.int64(8), v_j],
rxplaceholder_1[v_i // T.int64(32), v_j])
+ T.writes(decode[v_i, v_j])
+ decode[v_i, v_j] = T.Cast("float32",
T.bitwise_and(T.shift_right(rxplaceholder[v_i // T.int64(8), v_j],
T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) *
T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v_i //
T.int64(32), v_j], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32",
T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v_i // T.int64(32),
v_j], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+ for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(decode[v_ax1, v_ax0])
+ T.writes(T_transpose[v_ax0, v_ax1])
+ T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)),
"uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"),
T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ decode_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)),
scope="shared")
+ for ax0_0_0 in T.thread_binding(T.int64(64), thread="blockIdx.y",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax1_0 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
+ for ax0_ax1_fused_0 in range(T.int64(1)):
+ for ax0_ax1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(T.int64(16), thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.unroll(T.int64(8)):
+ with T.block("decode_shared"):
+ v0 = T.axis.spatial(T.int64(4096),
ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 *
T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) // T.int64(64))
+ v1 = T.axis.spatial(T.int64(4096),
ax0_0_0 * T.int64(64) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 *
T.int64(128) + ax0_ax1_fused_2 * T.int64(8) + ax0_ax1_fused_3) % T.int64(64))
+ T.reads(rxplaceholder[v0 //
T.int64(8), v1], rxplaceholder_1[v0 // T.int64(32), v1])
+ T.writes(decode_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0,
0, 32, 1]]})
+ decode_shared[v0, v1] =
T.Cast("float32", T.bitwise_and(T.shift_right(rxplaceholder[v0 // T.int64(8),
v1], T.Cast("uint32", v0 % T.int64(8) * T.int64(4))), T.uint32(15))) *
T.reinterpret("float32", T.shift_left(T.bitwise_and(rxplaceholder_1[v0 //
T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32",
T.shift_left(T.bitwise_and(T.shift_right(rxplaceholder_1[v0 // T.int64(32),
v1], T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+ for ax0_0_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for ax1_1 in T.thread_binding(T.int64(16),
thread="threadIdx.x"):
+ for ax0_1_0 in range(T.int64(2)):
+ for ax0_1_1 in T.vectorized(T.int64(4)):
+ with T.block("T_transpose"):
+ v0 = T.axis.spatial(T.int64(4096),
ax0_0_0 * T.int64(64) + ax0_0_1 * T.int64(8) + ax0_1_0 * T.int64(4) + ax0_1_1)
+ v1 = T.axis.spatial(T.int64(4096),
ax1_0 * T.int64(16) + ax1_1)
+ T.reads(decode_shared[v1, v0])
+ T.writes(T_transpose[v0, v1])
+ T_transpose[v0, v1] =
decode_shared[v1, v0]
+ # fmt: on
+ _check(Before, After)
+
+
+def test_decode_int3_transpose():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B:
T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose:
T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ decode_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)),
"float16")
+ for i, j in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(A[v_i // T.int64(10), v_j], B[v_i // T.int64(40),
v_j])
+ T.writes(decode_1[v_i, v_j])
+ decode_1[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(A[v_i // T.int64(10), v_j], T.Cast("uint32", v_i %
T.int64(10)) * T.uint32(3)), T.uint32(7))) - T.float16(3)) * B[v_i //
T.int64(40), v_j]
+ for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(decode_1[v_ax1, v_ax0])
+ T.writes(T_transpose[v_ax0, v_ax1])
+ T_transpose[v_ax0, v_ax1] = decode_1[v_ax1, v_ax0]
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B:
T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose:
T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ decode_1_shared = T.alloc_buffer((T.int64(4096), T.int64(4096)),
"float16", scope="shared")
+ for ax0_0_0 in T.thread_binding(T.int64(52), thread="blockIdx.y",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax1_0 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
+ for ax0_ax1_fused_0 in range(T.int64(2)):
+ for ax0_ax1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(T.int64(16), thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.unroll(T.int64(10)):
+ with T.block("decode_1_shared"):
+ v0 = T.axis.spatial(T.int64(4096),
ax1_0 * T.int64(16) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 *
T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) // T.int64(82))
+ v1 = T.axis.spatial(T.int64(4096),
ax0_0_0 * T.int64(80) + (ax0_ax1_fused_0 * T.int64(1280) + ax0_ax1_fused_1 *
T.int64(160) + ax0_ax1_fused_2 * T.int64(10) + ax0_ax1_fused_3) % T.int64(82))
+ T.where(ax0_0_0 * T.int64(80) +
(((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) +
ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3) % T.int64(82) < T.int64(4096)
and ((ax0_ax1_fused_0 * T.int64(8) + ax0_ax1_fused_1) * T.int64(16) +
ax0_ax1_fused_2) * T.int64(10) + ax0_ax1_fused_3 < T.int64(1312))
+ T.reads(A[v0 // T.int64(10), v1], B[v0
// T.int64(40), v1])
+ T.writes(decode_1_shared[v0, v1])
+ T.block_attr({"buffer_dim_align": [[0,
0, 32, 1]]})
+ decode_1_shared[v0, v1] =
(T.Cast("float16", T.bitwise_and(T.shift_right(A[v0 // T.int64(10), v1],
T.Cast("uint32", v0 % T.int64(10)) * T.uint32(3)), T.uint32(7))) -
T.float16(3)) * B[v0 // T.int64(40), v1]
+ for ax0_0_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for ax1_1 in T.thread_binding(T.int64(16),
thread="threadIdx.x"):
+ for ax0_1_0 in range(T.int64(3)):
+ for ax0_1_1 in T.vectorized(T.int64(4)):
+ with T.block("T_transpose"):
+ v0 = T.axis.spatial(T.int64(4096),
(ax0_0_0 * T.int64(8) + ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) +
ax0_1_1))
+ v1 = T.axis.spatial(T.int64(4096),
ax1_0 * T.int64(16) + ax1_1)
+ T.where((ax0_0_0 * T.int64(8) +
ax0_0_1) * T.int64(10) + (ax0_1_0 * T.int64(4) + ax0_1_1) < T.int64(4096) and
ax0_0_0 * T.int64(8) + ax0_0_1 < T.int64(410) and ax0_1_0 * T.int64(4) +
ax0_1_1 < T.int64(10))
+ T.reads(decode_1_shared[v1, v0])
+ T.writes(T_transpose[v0, v1])
+ T_transpose[v0, v1] =
decode_1_shared[v1, v0]
+ # fmt: on
+ _check(Before, After)