This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7ef7ce7a61 [Unity][Dlight] Add reduction rules (#15156)
7ef7ce7a61 is described below

commit 7ef7ce7a61bdb865eb54741ad9333d978986f3b9
Author: Junru Shao <[email protected]>
AuthorDate: Mon Jun 26 10:30:57 2023 -0700

    [Unity][Dlight] Add reduction rules (#15156)
    
    This PR adds a dlight rule that effectively allows fusing all reduction
    and spatial blocks together and executing them within a single GPU kernel.
    This is meant to schedule softmax, layer norm, RMS norm, etc.
---
 python/tvm/dlight/base/__init__.py                 |   2 +
 python/tvm/dlight/base/analysis.py                 |  75 ++++++
 python/tvm/dlight/base/common_schedules.py         |  60 +++++
 python/tvm/dlight/base/transform.py                |  14 +-
 python/tvm/dlight/gpu/__init__.py                  |   1 +
 python/tvm/dlight/gpu/fallback.py                  |  44 ++--
 python/tvm/dlight/gpu/reduction.py                 |  92 +++++++
 ...{test_schedule_rule.py => test_gpu_fallback.py} |   0
 tests/python/dlight/test_gpu_reduction.py          | 282 +++++++++++++++++++++
 9 files changed, 539 insertions(+), 31 deletions(-)

diff --git a/python/tvm/dlight/base/__init__.py 
b/python/tvm/dlight/base/__init__.py
index 6088add37e..d14db6c4a7 100644
--- a/python/tvm/dlight/base/__init__.py
+++ b/python/tvm/dlight/base/__init__.py
@@ -15,5 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Base infra"""
+from .analysis import BlockInfo
+from .common_schedules import try_inline
 from .schedule_rule import ScheduleRule
 from .transform import ApplyDefaultSchedule
diff --git a/python/tvm/dlight/base/analysis.py 
b/python/tvm/dlight/base/analysis.py
new file mode 100644
index 0000000000..a2508c87ba
--- /dev/null
+++ b/python/tvm/dlight/base/analysis.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Analysis on TIR blocks, loops and functions."""
+from typing import List, Union
+
+from tvm import tir
+
+
+class BlockInfo:
+    """Information about a TIR block."""
+
+    block: tir.schedule.BlockRV
+    """The TIR block the current schedule refers to"""
+    name: str
+    """The name of the block"""
+    iters: List[tir.IterVar]
+    """The iteration domains of the current block"""
+
+    def __init__(
+        self,
+        sch: tir.Schedule,
+        block: tir.schedule.BlockRV,
+    ):
+        """Construct a BlockInfo object via TIR schedule."""
+        tir_block = sch.get(block)
+        self.block = block
+        self.name = tir_block.name_hint
+        self.iters = list(tir_block.iter_vars)
+
+    def dom(self) -> List[Union[int, tir.PrimExpr]]:
+        """The iteration domain of the block."""
+
+        def _iter_dom(i: tir.IterVar) -> Union[int, tir.PrimExpr]:
+            result = i.dom.extent
+            if isinstance(result, tir.IntImm):
+                result = int(result)
+            return result
+
+        result = [_iter_dom(i) for i in self.iters]
+        return result
+
+    def dom_kind(self) -> str:
+        """The iteration domain kind of the block, for example, SSSS, SSSR."""
+
+        def _iter_kind(i: tir.IterVar) -> str:
+            return {
+                tir.IterVar.DataPar: "S",
+                tir.IterVar.CommReduce: "R",
+            }.get(i.iter_type, "O")
+
+        return "".join(_iter_kind(i) for i in self.iters)
+
+    def is_spatial(self) -> bool:
+        """Whether the block is spatial, i.e. all its iteration domains are 
spatial."""
+        return all(k == "S" for k in self.dom_kind())
+
+    def __str__(self) -> str:
+        return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})'
+
+    def __repr__(self) -> str:
+        return str(self)
diff --git a/python/tvm/dlight/base/common_schedules.py 
b/python/tvm/dlight/base/common_schedules.py
new file mode 100644
index 0000000000..6568f9e5b5
--- /dev/null
+++ b/python/tvm/dlight/base/common_schedules.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Common schedule strategies for TIR."""
+from typing import Callable, List
+
+from tvm import tir
+
+from .analysis import BlockInfo
+
+
+def try_inline(
+    sch: tir.Schedule,
+    blocks: List[BlockInfo],
+) -> List[BlockInfo]:
+    """Try to inline as many blocks as possible, and return the remaining 
blocks.
+
+    Parameters
+    ----------
+    sch : tir.Schedule
+        The TIR schedule used to inline blocks.
+    blocks : List[BlockInfo]
+        The blocks to be inlined.
+
+    Returns
+    -------
+    remaining : List[BlockInfo]
+        The remaining blocks that cannot be inlined.
+    """
+
+    def _trial(func: Callable):
+        for i, block in enumerate(blocks):
+            try:
+                func(block.block)
+            except:  # pylint: disable=bare-except
+                continue
+            return i
+        return None
+
+    while True:
+        i = _trial(sch.compute_inline)
+        if i is None:
+            i = _trial(sch.reverse_compute_inline)
+        if i is None:
+            break
+        blocks.pop(i)
+    return blocks
diff --git a/python/tvm/dlight/base/transform.py 
b/python/tvm/dlight/base/transform.py
index 9d536adfec..c11a4ae060 100644
--- a/python/tvm/dlight/base/transform.py
+++ b/python/tvm/dlight/base/transform.py
@@ -28,6 +28,16 @@ from tvm.target import Target
 from .schedule_rule import ScheduleRule
 
 
+def _is_scheduled(func: tir.PrimFunc) -> bool:
+    if not isinstance(func, tir.PrimFunc):
+        return False
+    if not func.attrs:
+        return False
+    if "tir.is_scheduled" not in func.attrs:
+        return False
+    return func.attrs["tir.is_scheduled"] == 1
+
+
 @module_pass(opt_level=0, name="ApplyDefaultSchedule")
 class ApplyDefaultSchedule:  # pylint: disable=too-few-public-methods
     """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs 
in the module."""
@@ -50,9 +60,7 @@ class ApplyDefaultSchedule:  # pylint: 
disable=too-few-public-methods
         target = Target.current(allow_none=False)
         updated_functions = {}
         for g_var, func in mod.functions.items():
-            if isinstance(func, tir.PrimFunc) and (
-                not func.attrs or not func.attrs.get("tir.is_scheduled", 0)
-            ):
+            if not _is_scheduled(func):
                 sch = _apply_rules(func, target, self.rules, tunable=False)
                 if sch is not None:
                     assert len(sch) == 1
diff --git a/python/tvm/dlight/gpu/__init__.py 
b/python/tvm/dlight/gpu/__init__.py
index d5311014b0..098f71d608 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -19,3 +19,4 @@ GPU-generic schedule rules.
 For CUDA/ROCm/Vulkan/Metal-specific rules, use 
`tvm.dlight.cuda/rocm/vulkan/metal` instead
 """
 from .fallback import Fallback
+from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/fallback.py 
b/python/tvm/dlight/gpu/fallback.py
index 354361323c..caefc8d563 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -16,13 +16,12 @@
 # under the License.
 # pylint: disable=missing-docstring
 """A fallback schedule rule for GPU operators."""
-from typing import Callable, List
+from typing import List
 
 from tvm import tir
-from tvm._ffi import get_global_func
 from tvm.target import Target
 
-from ..base import ScheduleRule
+from ..base import BlockInfo, ScheduleRule, try_inline
 
 
 def _max_threads_per_block(target: Target) -> int:
@@ -48,39 +47,28 @@ class Fallback(ScheduleRule):
         _: bool,
     ) -> tir.Schedule:
         max_threads_per_block = _max_threads_per_block(target)
-        get_loop_iter_type = get_global_func("tir.schedule.GetLoopIterType")
 
         sch = tir.Schedule(func)
-        blocks = 
sch.get_child_blocks(sch.get_block(sch.mod["main"].body.block.name_hint))
-
-        while True:
-
-            def _try_inline(func: Callable):
-                for i, block in enumerate(blocks):
-                    try:
-                        func(block)
-                    except:  # pylint: disable=bare-except
-                        continue
-                    return i
-                return None
-
-            i = _try_inline(sch.compute_inline)
-            if i is None:
-                i = _try_inline(sch.reverse_compute_inline)
-            if i is None:
-                break
-            blocks.pop(i)
-
-        for block in blocks:
+        for block in try_inline(
+            sch,
+            [
+                BlockInfo(
+                    sch,
+                    block,
+                )
+                for block in sch.get_child_blocks(sch.get_block("root"))
+            ],
+        ):
             s_loops: List[tir.schedule.LoopRV] = []
             r_loops: List[tir.schedule.LoopRV] = []
             o_loops: List[tir.schedule.LoopRV] = []
-            for loop in sch.get_loops(block):
-                iter_type = get_loop_iter_type(sch, loop)
+            dom_kind = block.dom_kind()
+            block = block.block
+            for loop, iter_type in zip(sch.get_loops(block), dom_kind):
                 {"S": s_loops, "R": r_loops, "O": 
o_loops}[iter_type].append(loop)
 
             if not s_loops:
-                s_loops.append(sch.add_unit_loop(block))
+                s_loops.append(sch.add_unit_loop(block.block))
             sch.reorder(*s_loops, *r_loops, *o_loops)
             bx, tx = sch.split(  # pylint: disable=invalid-name
                 sch.fuse(*s_loops),
diff --git a/python/tvm/dlight/gpu/reduction.py 
b/python/tvm/dlight/gpu/reduction.py
new file mode 100644
index 0000000000..b3cc58c902
--- /dev/null
+++ b/python/tvm/dlight/gpu/reduction.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
+from typing import List, Union
+
+from tvm import tir
+from tvm.target import Target
+
+from ..base import BlockInfo, ScheduleRule, try_inline
+
+
+class Reduction(ScheduleRule):
+    """Reduction rule for operators including softmax, layer norm, RMS norm, 
etc"""
+
+    def apply(  # pylint: disable=too-many-locals
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        if target.kind.name == "cuda":
+            len_tx = 256
+            unroll_depth = 256
+        else:
+            len_tx = 64
+            unroll_depth = 64
+
+        def _inline_all_spatial():
+            blocks = []
+            spatial_blocks = []
+            for block in sch.get_child_blocks(sch.get_block("root")):
+                block = BlockInfo(sch, block)
+                if block.is_spatial():
+                    spatial_blocks.append(block)
+                elif spatial_blocks:
+                    blocks.extend(try_inline(sch, spatial_blocks))
+                    blocks.append(block)
+                    spatial_blocks = []
+                else:
+                    blocks.append(block)
+            if spatial_blocks:
+                blocks.extend(try_inline(sch, spatial_blocks))
+            return blocks
+
+        sch = tir.Schedule(func)
+        blocks = _inline_all_spatial()
+        assert len(blocks) > 0
+
+        dom_kind = blocks[0].dom_kind()
+        num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S"))
+        num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R"))
+        try:
+            for block in blocks[1:-1]:
+                assert block.dom_kind() == dom_kind
+            assert blocks[-1].is_spatial()
+            assert len(blocks[-1].dom_kind()) == len(dom_kind)
+        except AssertionError:
+            print("Mismatch")
+            return None
+
+        loops = sch.get_loops(blocks[-1].block)
+        bx = sch.fuse(*loops[:num_leading_s])  # pylint: disable=invalid-name
+        _, tx = sch.split(loops[-1], [None, len_tx])  # pylint: 
disable=invalid-name
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(tx, "threadIdx.x")
+
+        for block in reversed(blocks[:-1]):
+            block = block.block
+            for i, _ in enumerate(sch.get(block).writes):
+                sch.set_scope(block, buffer_index=i, storage_scope="shared")
+            sch.compute_at(block, bx, preserve_unit_loops=True)
+            r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:])
+            _, tx = sch.split(r_loop, [None, len_tx])  # pylint: 
disable=invalid-name
+            sch.bind(tx, "threadIdx.x")
+
+        sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", 
ann_val=unroll_depth)
+        sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)
+        return sch
diff --git a/tests/python/dlight/test_schedule_rule.py 
b/tests/python/dlight/test_gpu_fallback.py
similarity index 100%
rename from tests/python/dlight/test_schedule_rule.py
rename to tests/python/dlight/test_gpu_fallback.py
diff --git a/tests/python/dlight/test_gpu_reduction.py 
b/tests/python/dlight/test_gpu_reduction.py
new file mode 100644
index 0000000000..99307093c8
--- /dev/null
+++ b/tests/python/dlight/test_gpu_reduction.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from tvm import dlight as dl
+from tvm.ir import IRModule, assert_structural_equal
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def _check(mod_before: IRModule, mod_after: IRModule):
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
+            dl.gpu.Reduction(),
+        )(mod_before)
+    assert_structural_equal(mod, mod_after)
+
+
+def test_softmax():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(p_lv44: T.handle, p_output0: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n, m = T.int64(), T.int64()
+            lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m))
+            var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), 
T.int64(32), n, m), "float16")
+            # with T.block("root"):
+            T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(32), n))
+            T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
+            T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(32), n))
+            var_T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), 
T.int64(32), n, m))
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
+                with T.block("T_softmax_maxelem"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(lv44[v_i0, v_i1, v_i2, v_k])
+                    T.writes(T_softmax_maxelem[v_i0, v_i1, v_i2])
+                    with T.init():
+                        T_softmax_maxelem[v_i0, v_i1, v_i2] = 
T.float32(-3.4028234663852886e+38)
+                    T_softmax_maxelem[v_i0, v_i1, v_i2] = 
T.max(T_softmax_maxelem[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k])
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
+                with T.block("T_softmax_exp"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(lv44[v_i0, v_i1, v_i2, v_i3], 
T_softmax_maxelem[v_i0, v_i1, v_i2])
+                    T.writes(T_softmax_exp[v_i0, v_i1, v_i2, v_i3])
+                    T_softmax_exp[v_i0, v_i1, v_i2, v_i3] = T.exp(lv44[v_i0, 
v_i1, v_i2, v_i3] - T_softmax_maxelem[v_i0, v_i1, v_i2])
+            for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), n, m):
+                with T.block("T_softmax_expsum"):
+                    v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, 
k])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_k])
+                    T.writes(T_softmax_expsum[v_i0, v_i1, v_i2])
+                    with T.init():
+                        T_softmax_expsum[v_i0, v_i1, v_i2] = T.float32(0)
+                    T_softmax_expsum[v_i0, v_i1, v_i2] = 
T_softmax_expsum[v_i0, v_i1, v_i2] + T_softmax_exp[v_i0, v_i1, v_i2, v_k]
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
+                with T.block("T_softmax_norm"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_i2, v_i3], 
T_softmax_expsum[v_i0, v_i1, v_i2])
+                    T.writes(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, 
v_i3])
+                    T.block_attr({"axis": 3})
+                    var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3] = 
T_softmax_exp[v_i0, v_i1, v_i2, v_i3] / T_softmax_expsum[v_i0, v_i1, v_i2]
+            for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(32), n, m):
+                with T.block("compute"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, 
v_i3])
+                    T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
+                    var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = 
T.Cast("float16", var_T_softmax_norm_intermediate[v_i0, v_i1, v_i2, v_i3])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(p_lv44: T.handle, p_output0: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n, m = T.int64(), T.int64()
+            lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m))
+            var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), 
T.int64(32), n, m), "float16")
+            # with T.block("root"):
+            T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1), 
T.int64(32), n), scope="shared")
+            T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(32), 
n), scope="shared")
+            for i0_i1_i2_fused in T.thread_binding(n * T.int64(32), 
thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
+                for ax0, ax1, ax2, ax3_fused_0 in T.grid(T.int64(1), 
T.int64(1), T.int64(1), (m + T.int64(255)) // T.int64(256)):
+                    for ax3_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("T_softmax_maxelem"):
+                            v_i0 = T.axis.spatial(T.int64(1), ax0)
+                            v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused 
// n + ax1)
+                            v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n + ax2)
+                            v_k = T.axis.reduce(m, ax3_fused_0 * T.int64(256) 
+ ax3_fused_1)
+                            T.where(T.int64(0) <= i0_i1_i2_fused // n and 
i0_i1_i2_fused // n < T.int64(32) and T.int64(0) <= i0_i1_i2_fused % n and 
i0_i1_i2_fused % n < n and ax3_fused_0 * T.int64(256) + ax3_fused_1 < m)
+                            T.reads(lv44[v_i0, v_i1, v_i2, v_k])
+                            T.writes(T_softmax_maxelem_shared[v_i0, v_i1, 
v_i2])
+                            with T.init():
+                                T_softmax_maxelem_shared[v_i0, v_i1, v_i2] = 
T.float32(-3.4028234663852886e+38)
+                            T_softmax_maxelem_shared[v_i0, v_i1, v_i2] = 
T.max(T_softmax_maxelem_shared[v_i0, v_i1, v_i2], lv44[v_i0, v_i1, v_i2, v_k])
+                for ax0, ax1, ax2, ax3_fused_0 in T.grid(T.int64(1), 
T.int64(1), T.int64(1), (m + T.int64(255)) // T.int64(256)):
+                    for ax3_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("T_softmax_expsum"):
+                            v_i0 = T.axis.spatial(T.int64(1), ax0)
+                            v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused 
// n + ax1)
+                            v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n + ax2)
+                            v_k = T.axis.reduce(m, ax3_fused_0 * T.int64(256) 
+ ax3_fused_1)
+                            T.where(T.int64(0) <= i0_i1_i2_fused // n and 
i0_i1_i2_fused // n < T.int64(32) and T.int64(0) <= i0_i1_i2_fused % n and 
i0_i1_i2_fused % n < n and ax3_fused_0 * T.int64(256) + ax3_fused_1 < m)
+                            T.reads(lv44[v_i0, v_i1, v_i2, v_k], 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2])
+                            T.writes(T_softmax_expsum_shared[v_i0, v_i1, v_i2])
+                            with T.init():
+                                T_softmax_expsum_shared[v_i0, v_i1, v_i2] = 
T.float32(0)
+                            T_softmax_expsum_shared[v_i0, v_i1, v_i2] = 
T_softmax_expsum_shared[v_i0, v_i1, v_i2] + T.exp(lv44[v_i0, v_i1, v_i2, v_k] - 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2])
+                for i3_0 in range((m + T.int64(255)) // T.int64(256)):
+                    for i3_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("compute"):
+                            v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            v_i1 = T.axis.spatial(T.int64(32), i0_i1_i2_fused 
// n)
+                            v_i2 = T.axis.spatial(n, i0_i1_i2_fused % n)
+                            v_i3 = T.axis.spatial(m, i3_0 * T.int64(256) + 
i3_1)
+                            T.where(i3_0 * T.int64(256) + i3_1 < m)
+                            T.reads(lv44[v_i0, v_i1, v_i2, v_i3], 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2], T_softmax_expsum_shared[v_i0, v_i1, 
v_i2])
+                            T.writes(var_compute_intermediate[v_i0, v_i1, 
v_i2, v_i3])
+                            var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = 
T.Cast("float16", T.exp(lv44[v_i0, v_i1, v_i2, v_i3] - 
T_softmax_maxelem_shared[v_i0, v_i1, v_i2]) / T_softmax_expsum_shared[v_i0, 
v_i1, v_i2])
+    # fmt: on
+    _check(Before, After)
+
+
+def test_layer_norm():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), 
"float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n = T.int64()
+            lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560)))
+            var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), 
n, T.int64(2560)), "float16")
+            # with T.block("root"):
+            A_red_temp_v0 = T.alloc_buffer((T.int64(1), n))
+            A_red_temp_v1 = T.alloc_buffer((T.int64(1), n))
+            var_T_layer_norm_intermediate = T.alloc_buffer((T.int64(1), n, 
T.int64(2560)))
+            for ax0, ax1, k2 in T.grid(T.int64(1), n, T.int64(2560)):
+                with T.block("A_red_temp"):
+                    v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
+                    T.reads(lv6[v_ax0, v_ax1, v_k2])
+                    T.writes(A_red_temp_v0[v_ax0, v_ax1], A_red_temp_v1[v_ax0, 
v_ax1])
+                    with T.init():
+                        A_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
+                        A_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
+                    v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + 
lv6[v_ax0, v_ax1, v_k2]
+                    v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + 
lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2]
+                    A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0
+                    A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1
+            for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)):
+                with T.block("T_layer_norm"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(lv6[v_ax0, v_ax1, v_ax2], A_red_temp_v0[v_ax0, 
v_ax1], A_red_temp_v1[v_ax0, v_ax1], weight1[v_ax2], bias[v_ax2])
+                    T.writes(var_T_layer_norm_intermediate[v_ax0, v_ax1, 
v_ax2])
+                    var_T_layer_norm_intermediate[v_ax0, v_ax1, v_ax2] = 
(lv6[v_ax0, v_ax1, v_ax2] - A_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1[v_ax0, v_ax1] * 
T.float32(0.00039062500000000002) - A_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.00039062500000000002) * (A_red_temp_v0[v_ax0, v_ax1] * 
T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * 
weight1[v_ax2] + bias[v_ax2]
+            for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(2560)):
+                with T.block("compute"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(var_T_layer_norm_intermediate[v_i0, v_i1, v_i2])
+                    T.writes(var_compute_intermediate[v_i0, v_i1, v_i2])
+                    var_compute_intermediate[v_i0, v_i1, v_i2] = 
T.Cast("float16", var_T_layer_norm_intermediate[v_i0, v_i1, v_i2])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), 
"float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle):
+            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+            n = T.int64()
+            lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560)))
+            var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), 
n, T.int64(2560)), "float16")
+            # with T.block("root"):
+            A_red_temp_v0_shared = T.alloc_buffer((T.int64(1), n), 
scope="shared")
+            A_red_temp_v1_shared = T.alloc_buffer((T.int64(1), n), 
scope="shared")
+            for i0_i1_fused in T.thread_binding(n, thread="blockIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), 
T.int64(10)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("A_red_temp"):
+                            v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                            v_ax1 = T.axis.spatial(n, i0_i1_fused + ax1)
+                            v_k2 = T.axis.reduce(T.int64(2560), ax2_fused_0 * 
T.int64(256) + ax2_fused_1)
+                            T.reads(lv6[v_ax0, v_ax1, v_k2])
+                            T.writes(A_red_temp_v0_shared[v_ax0, v_ax1], 
A_red_temp_v1_shared[v_ax0, v_ax1])
+                            with T.init():
+                                A_red_temp_v0_shared[v_ax0, v_ax1] = 
T.float32(0)
+                                A_red_temp_v1_shared[v_ax0, v_ax1] = 
T.float32(0)
+                            v_A_red_temp_v0: T.float32 = 
A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2]
+                            v_A_red_temp_v1: T.float32 = 
A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, 
v_ax1, v_k2]
+                            A_red_temp_v0_shared[v_ax0, v_ax1] = 
v_A_red_temp_v0
+                            A_red_temp_v1_shared[v_ax0, v_ax1] = 
v_A_red_temp_v1
+                for i2_0 in range(T.int64(10)):
+                    for i2_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("compute"):
+                            v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            v_i1 = T.axis.spatial(n, i0_i1_fused)
+                            v_i2 = T.axis.spatial(T.int64(2560), i2_0 * 
T.int64(256) + i2_1)
+                            T.reads(lv6[v_i0, v_i1, v_i2], 
A_red_temp_v0_shared[v_i0, v_i1], A_red_temp_v1_shared[v_i0, v_i1], 
weight1[v_i2], bias[v_i2])
+                            T.writes(var_compute_intermediate[v_i0, v_i1, 
v_i2])
+                            var_compute_intermediate[v_i0, v_i1, v_i2] = 
T.Cast("float16", (lv6[v_i0, v_i1, v_i2] - A_red_temp_v0_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002)) * T.rsqrt(A_red_temp_v1_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002) - A_red_temp_v0_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002) * (A_red_temp_v0_shared[v_i0, v_i1] * 
T.float32(0.00039062500000000002)) + T.float32(1.0000000000000001e-05)) * 
weight1[v_i2] + bias[v_i2])
+    # fmt: on
+    _check(Before, After)
+
+
+def test_rms_norm():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), 
var_rms_norm: T.handle):
+            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+            n = T.int64()
+            A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), 
"float16")
+            rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, 
T.int64(4096)), "float16")
+            # with T.block("root"):
+            Ared_temp = T.alloc_buffer((T.int64(1), n))
+            for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
+                with T.block("Ared_temp"):
+                    v_bsz, v_i, v_k = T.axis.remap("SSR", [bsz, i, k])
+                    T.reads(A[v_bsz, v_i, v_k])
+                    T.writes(Ared_temp[v_bsz, v_i])
+                    with T.init():
+                        Ared_temp[v_bsz, v_i] = T.float32(0)
+                    Ared_temp[v_bsz, v_i] = Ared_temp[v_bsz, v_i] + 
T.Cast("float32", A[v_bsz, v_i, v_k]) * T.Cast("float32", A[v_bsz, v_i, v_k])
+            for bsz, i, k in T.grid(T.int64(1), n, T.int64(4096)):
+                with T.block("rms_norm"):
+                    v_bsz, v_i, v_k = T.axis.remap("SSS", [bsz, i, k])
+                    T.reads(B[v_k], A[v_bsz, v_i, v_k], Ared_temp[v_bsz, v_i])
+                    T.writes(rms_norm_1[v_bsz, v_i, v_k])
+                    rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", 
T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / 
T.sqrt(Ared_temp[v_bsz, v_i] * T.float32(0.000244140625) + 
T.float32(9.9999999999999995e-07))))
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), 
var_rms_norm: T.handle):
+            T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            n = T.int64()
+            A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), 
"float16")
+            rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, 
T.int64(4096)), "float16")
+            # with T.block("root"):
+            Ared_temp_shared = T.alloc_buffer((T.int64(1), n), scope="shared")
+            for bsz_i_fused in T.thread_binding(n, thread="blockIdx.x", 
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                for ax0, ax1, ax2_fused_0 in T.grid(T.int64(1), T.int64(1), 
T.int64(16)):
+                    for ax2_fused_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("Ared_temp"):
+                            v_bsz = T.axis.spatial(T.int64(1), ax0)
+                            v_i = T.axis.spatial(n, bsz_i_fused + ax1)
+                            v_k = T.axis.reduce(T.int64(4096), ax2_fused_0 * 
T.int64(256) + ax2_fused_1)
+                            T.reads(A[v_bsz, v_i, v_k])
+                            T.writes(Ared_temp_shared[v_bsz, v_i])
+                            with T.init():
+                                Ared_temp_shared[v_bsz, v_i] = T.float32(0)
+                            Ared_temp_shared[v_bsz, v_i] = 
Ared_temp_shared[v_bsz, v_i] + T.Cast("float32", A[v_bsz, v_i, v_k]) * 
T.Cast("float32", A[v_bsz, v_i, v_k])
+                for k_0 in range(T.int64(16)):
+                    for k_1 in T.thread_binding(T.int64(256), 
thread="threadIdx.x"):
+                        with T.block("rms_norm"):
+                            v_bsz = T.axis.spatial(T.int64(1), T.int64(0))
+                            v_i = T.axis.spatial(n, bsz_i_fused)
+                            v_k = T.axis.spatial(T.int64(4096), k_0 * 
T.int64(256) + k_1)
+                            T.reads(B[v_k], A[v_bsz, v_i, v_k], 
Ared_temp_shared[v_bsz, v_i])
+                            T.writes(rms_norm_1[v_bsz, v_i, v_k])
+                            rms_norm_1[v_bsz, v_i, v_k] = T.Cast("float16", 
T.Cast("float32", B[v_k]) * (T.Cast("float32", A[v_bsz, v_i, v_k]) / 
T.sqrt(Ared_temp_shared[v_bsz, v_i] * T.float32(0.000244140625) + 
T.float32(9.9999999999999995e-07))))
+    # fmt: on
+    _check(Before, After)
+
+
+if __name__ == "__main__":
+    test_softmax()
+    test_layer_norm()
+    test_rms_norm()


Reply via email to