This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1637b1436f [Unity][Dlight] Avoid TransformBlockLayout in GEMV Rule 
(#15248)
1637b1436f is described below

commit 1637b1436f18d6026b68a9e2e8dd57d45756161a
Author: Junru Shao <[email protected]>
AuthorDate: Wed Jul 5 18:27:34 2023 -0700

    [Unity][Dlight] Avoid TransformBlockLayout in GEMV Rule (#15248)
    
    This PR made two changes:
    - Replace the use of `sch.transform_block_layout` with classic
      split/fuse/reorder in DecodeGEMV rule. The primitive itself is
      designed to be overkill, but in fact doesn't support symbolic
      bounds very well. Meanwhile, generally it wouldn't be necessary to
      use this advanced primitive as it might hurt subsequent analysis,
      although in our case there isn't any.
    - Always make sure there is at least one spatial loop during the initial
      PrimFunc normalization. This could help cases of `n -> 1` reduction
      be scheduled more conveniently.
    
    NOTE: `sch.transform_block_layout` is still used in the PrimFunc
    normalization to remove unit loops and ensure at least one spatial loop.
    There is no reorder/fusion/split happening in this specific use, and it
    could help with all subsequent analysis to make sure all iters are
    non-trivial.
---
 python/tvm/dlight/gpu/decode_gemv.py        |  68 ++++++++--------
 src/tir/schedule/transform.cc               |   8 +-
 tests/python/dlight/test_gpu_decode_gemv.py | 116 +++++++++++++++++++++-------
 3 files changed, 128 insertions(+), 64 deletions(-)

diff --git a/python/tvm/dlight/gpu/decode_gemv.py 
b/python/tvm/dlight/gpu/decode_gemv.py
index b9e8b44ef2..69da1b0b0a 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -80,7 +80,8 @@ class DecodeGEMV(ScheduleRule):
         if not isinstance(func, tir.PrimFunc):
             return None
         sch = tir.Schedule(func)
-        block_infos = try_inline_contiguous_spatial(sch, 
normalize_prim_func(sch))
+        block_infos = normalize_prim_func(sch)
+        block_infos = try_inline_contiguous_spatial(sch, block_infos)
         if block_infos is None or len(block_infos) > 2:
             return None
 
@@ -117,53 +118,48 @@ class DecodeGEMV(ScheduleRule):
             sch.reverse_compute_at(block_infos[1].block_rv, 
sch.get_loops(block)[0])
         return sch
 
-    def _normalize(
+    def _normalize(  # pylint: disable=too-many-branches
         self,
         sch: tir.Schedule,
         block_info: BlockInfo,
-        iter_sum: arith.IterSumExpr,
+        access: arith.IterSumExpr,
     ) -> Tuple[Optional[bool], Optional[int]]:
-        if iter_sum.base != 0:
+        if access.base != 0:
             return None, None
         iter_to_info = {i.var: i for i in block_info.iters}
-        s_dom, r_dom, c_dom, c_factor = None, None, None, None
-        for split in iter_sum.args:
-            var = split.source.source
-            info = iter_to_info[var]
-            dom = info.dom
+        s_loops, r_loops, c_loops, c_factor = [], [], [], None
+        for split_expr in access.args:
+            var = split_expr.source.source
+            info = iter_to_info.pop(var)
+            loop = info.loop_rv
             is_inner_reduction = info.kind == "R"
-            if split.lower_factor > 1:
-                if c_dom is not None:
+            if split_expr.lower_factor > 1:
+                if c_loops:
                     return None, None
-                c_dom = tir.floormod(var, split.lower_factor)
-                var = tir.floordiv(var, split.lower_factor)
-                dom = tir.floordiv(dom, split.lower_factor)
+                loop, c_loop = sch.split(loop, factors=[None, 
split_expr.lower_factor])
+                c_loops.append(c_loop)
                 if not is_inner_reduction:
-                    c_factor = split.lower_factor
+                    c_factor = split_expr.lower_factor
             if is_inner_reduction:
-                if r_dom is None:
-                    r_dom = var
-                else:
-                    r_dom = r_dom * dom + var
+                r_loops.append(loop)
             else:
-                if s_dom is None:
-                    s_dom = var
+                s_loops.append(loop)
+
+        if iter_to_info:
+            for var, info in iter_to_info.items():
+                if info.kind == "S" and info.dom == 1:
+                    s_loops.append(info.loop_rv)
                 else:
-                    s_dom = s_dom * dom + var
-
-        assert r_dom is not None
-        if s_dom is None:
-            s_dom = tir.const(1, r_dom.dtype)
-        if c_dom is None:
-            c_dom = tir.const(1, r_dom.dtype)
-        sch.transform_block_layout(
-            block_info.block_rv,
-            tir.IndexMap(
-                [i.var for i in block_info.iters],
-                [s_dom, r_dom, c_dom],
-                None,
-            ),
-        )
+                    return None, None
+        assert s_loops
+        assert r_loops
+        if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]):
+            return None, None
+        if not c_loops:
+            c_loops = [sch.add_unit_loop(block_info.block_rv)]
+        sch.reorder(*s_loops, *r_loops, *c_loops)
+        sch.fuse(*s_loops)
+        sch.fuse(*r_loops)
         return is_inner_reduction, c_factor
 
     def _sch_inner_reduction(
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 35bf7b7669..8c46af4a30 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -456,6 +456,7 @@ Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
   Array<IntImm> block_is_reduction;
   for (const BlockRV& block : blocks) {
     Array<IterVar> iters = sch->Get(block)->iter_vars;
+    bool has_spatial_iter = false;
     Array<Var> index_map_inputs;
     Array<PrimExpr> index_map_outputs;
     for (const IterVar& iter : sch->Get(block)->iter_vars) {
@@ -463,10 +464,13 @@ Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
       index_map_inputs.push_back(var);
       if (!is_one(iter->dom->extent)) {
         index_map_outputs.push_back(var);
+        if (iter->iter_type == IterVarType::kDataPar) {
+          has_spatial_iter = true;
+        }
       }
     }
-    if (index_map_outputs.empty()) {
-      index_map_outputs.push_back(make_zero(DataType::Int(64)));
+    if (index_map_outputs.empty() || !has_spatial_iter) {
+      index_map_outputs.insert(index_map_outputs.begin(), 
tir::make_const(DataType::Int(64), 0));
     }
     sch->TransformBlockLayout(block, IndexMap(index_map_inputs, 
index_map_outputs));
     block_loops.push_back(sch->GetLoops(block));
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py 
b/tests/python/dlight/test_gpu_decode_gemv.py
index 303b16809e..ba923a7d21 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -30,6 +30,7 @@ def test_decode_gemv_1():
         @T.prim_func
         def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
             T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+            # with T.block("root"):
             B = T.alloc_buffer((4096, 4096), "float16")
             for i, j in T.grid(4096, 4096):
                 with T.block("decode"):
@@ -64,7 +65,7 @@ def test_decode_gemv_1():
                         with T.block("matmul_rf_update"):
                             vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
                             v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", 
[i2_i0_i1_fused, k_0_fused_0, k_1])
-                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 + 
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1], 
T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 + 
vk_0_fused_1) // 4])
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + 
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
                 for ax1_ax2_ax3_fused in range(1): # pylint: 
disable=unused-variable
                     for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
                         with T.block("matmul"):
@@ -126,7 +127,7 @@ def test_decode_gemv_2():
                                 vk_0_fused_1 = T.axis.spatial(16, k_0_fused_1)
                                 v_i2 = T.axis.spatial(4096, i2_i0_i1_fused_0 * 
16 + i2_i0_i1_fused_1)
                                 vk_0_fused_0, vk_1 = T.axis.remap("RR", 
[k_0_fused_0, k_1])
-                                C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 16 + 
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[vk_0_fused_0 * 16 + vk_0_fused_1, v_i2], 
T.Cast("uint32", ((vk_0_fused_0 * 16 + vk_0_fused_1) * 8 + vk_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 16 + 
vk_0_fused_1) // 4, v_i2])
+                                C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 128 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[(vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) // 
8, v_i2], T.Cast("uint32", (vk_0_fused_0 * 128 + vk_0_fused_1 * 8 + vk_1) % 8) 
* T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[(vk_0_fused_0 * 128 + 
vk_0_fused_1 * 8 + vk_1) // 32, v_i2])
                 for ax1_ax2_ax3_fused in T.thread_binding(16, 
thread="threadIdx.x"):
                     for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
                         with T.block("matmul"):
@@ -182,26 +183,23 @@ def test_decode_gemv_3():
                     for i2_1_init in range(8):
                         with T.block("matmul_rf_init"):
                             vk_fused_1 = T.axis.spatial(256, k_fused_1)
-                            v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
-                            v_i2 = T.axis.spatial(8, i2_1_init)
-                            C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] = 
T.float16(0)
+                            v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + 
i2_1_init)
+                            C_rf_local[vk_fused_1, 0, 0, v_i2] = T.float16(0)
                     for k_fused_0, i2_1 in T.grid(16, 8):
                         with T.block("matmul_rf_update"):
                             vk_fused_1 = T.axis.spatial(256, k_fused_1)
-                            v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
+                            v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 8 + 
i2_1)
                             vk_fused_0 = T.axis.reduce(16, k_fused_0)
-                            v_i2 = T.axis.spatial(8, i2_1)
-                            C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2] + V[0, 0, vk_fused_0 * 256 + 
vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i1, 
vk_fused_0 * 256 + vk_fused_1], T.Cast("uint32", (v_i1 * 8 + v_i2) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i1 // 4, vk_fused_0 * 256 + 
vk_fused_1])
+                            C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 256 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[v_i2 // 8, vk_fused_0 * 256 + 
vk_fused_1], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[v_i2 // 32, vk_fused_0 * 256 + vk_fused_1])
                 for ax1_ax2_ax3_fused_0 in range(1):
                     for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
                         for ax1_ax2_ax3_fused_1 in range(8):
                             with T.block("matmul"):
                                 vk_fused_1 = T.axis.reduce(256, ax0_fused)
-                                v_i1 = T.axis.spatial(512, i2_0_i0_i1_fused)
-                                v_i2 = T.axis.spatial(8, ax1_ax2_ax3_fused_0 * 
8 + ax1_ax2_ax3_fused_1)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused * 
8 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
                                 with T.init():
-                                    C[0, 0, v_i1 * 8 + v_i2] = T.float16(0)
-                                C[0, 0, v_i1 * 8 + v_i2] = C[0, 0, v_i1 * 8 + 
v_i2] + C_rf_local[vk_fused_1, 0, 0, v_i1 * 8 + v_i2]
+                                    C[0, 0, v_i2] = T.float16(0)
+                                C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_fused_1, 0, 0, v_i2]
 
     # fmt: on
 
@@ -242,6 +240,7 @@ def test_decode_gemv_4():
         @T.prim_func
         def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), 
"float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), 
"float16")):
             T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+            # with T.block("root"):
             C_rf_local = T.alloc_buffer((16, 1, 1, 4096), "float16", 
scope="local")
             for i2_0_i0_i1_fused_0 in T.thread_binding(32, 
thread="blockIdx.x"):
                 for i2_0_i0_i1_fused_1 in T.thread_binding(16, 
thread="threadIdx.x"):
@@ -249,26 +248,23 @@ def test_decode_gemv_4():
                         for i2_1_init in range(8):
                             with T.block("matmul_rf_init"):
                                 vk_fused_1 = T.axis.spatial(16, k_fused_1)
-                                v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 
16 + i2_0_i0_i1_fused_1)
-                                v2 = T.axis.spatial(8, i2_1_init)
-                                C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] = 
T.float16(0)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1_init)
+                                C_rf_local[vk_fused_1, 0, 0, v_i2] = 
T.float16(0)
                         for k_fused_0, i2_1 in T.grid(256, 8):
                             with T.block("matmul_rf_update"):
                                 vk_fused_1 = T.axis.spatial(16, k_fused_1)
-                                v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 
16 + i2_0_i0_i1_fused_1)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + i2_0_i0_i1_fused_1 * 8 + i2_1)
                                 vk_fused_0 = T.axis.reduce(256, k_fused_0)
-                                v2 = T.axis.spatial(8, i2_1)
-                                C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] = 
C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2] + V[0, 0, vk_fused_0 * 16 + 
vk_fused_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 
+ vk_fused_1, v1], T.Cast("uint32", (v1 * 8 + v2) % 8) * T.uint32(4)), 
T.uint32(15))) - T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v1 // 4])
+                                C_rf_local[vk_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_fused_1, 0, 0, v_i2] + V[0, 0, vk_fused_0 * 16 + vk_fused_1] * 
((T.Cast("float16", T.bitwise_and(T.shift_right(W[vk_fused_0 * 16 + vk_fused_1, 
v_i2 // 8], T.Cast("uint32", v_i2 % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[vk_fused_0 * 16 + vk_fused_1, v_i2 // 32])
                 for ax1_ax2_ax3_fused_0 in T.thread_binding(16, 
thread="threadIdx.x"):
                     for ax0_fused in T.thread_binding(16, 
thread="threadIdx.y"):
                         for ax1_ax2_ax3_fused_1 in range(8):
                             with T.block("matmul"):
                                 vk_fused_1 = T.axis.reduce(16, ax0_fused)
-                                v1 = T.axis.spatial(512, i2_0_i0_i1_fused_0 * 
16 + (ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1) // 8)
-                                v2 = T.axis.spatial(8, (ax1_ax2_ax3_fused_0 * 
8 + ax1_ax2_ax3_fused_1) % 8)
+                                v_i2 = T.axis.spatial(4096, i2_0_i0_i1_fused_0 
* 128 + ax1_ax2_ax3_fused_0 * 8 + ax1_ax2_ax3_fused_1)
                                 with T.init():
-                                    C[0, 0, v1 * 8 + v2] = T.float16(0)
-                                C[0, 0, v1 * 8 + v2] = C[0, 0, v1 * 8 + v2] + 
C_rf_local[vk_fused_1, 0, 0, v1 * 8 + v2]
+                                    C[0, 0, v_i2] = T.float16(0)
+                                C[0, 0, v_i2] = C[0, 0, v_i2] + 
C_rf_local[vk_fused_1, 0, 0, v_i2]
 
     # fmt: on
 
@@ -328,8 +324,8 @@ def test_decode_gemv_sigmoid():
                         with T.block("matmul_rf_update"):
                             vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
                             v_i2, vk_0_fused_0, vk_1 = T.axis.remap("SRR", 
[i2_i0_i1_fused, k_0_fused_0, k_1])
-                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, (vk_0_fused_0 * 256 + 
vk_0_fused_1) * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, vk_0_fused_0 * 256 + vk_0_fused_1], 
T.Cast("uint32", ((vk_0_fused_0 * 256 + vk_0_fused_1) * 8 + vk_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 256 + 
vk_0_fused_1) // 4])
-                for ax1_ax2_ax3_fused in range(1): # pylint: 
disable=unused-variable
+                            C_rf_local[vk_0_fused_1, 0, 0, v_i2] = 
C_rf_local[vk_0_fused_1, 0, 0, v_i2] + V[0, 0, vk_0_fused_0 * 2048 + 
vk_0_fused_1 * 8 + vk_1] * ((T.Cast("float16", 
T.bitwise_and(T.shift_right(W[v_i2, (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + 
vk_1) // 8], T.Cast("uint32", (vk_0_fused_0 * 2048 + vk_0_fused_1 * 8 + vk_1) % 
8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v_i2, (vk_0_fused_0 * 
2048 + vk_0_fused_1 * 8 + vk_1) // 32])
+                for ax1_ax2_ax3_fused in range(1):  # pylint: 
disable=unused-variable
                     for ax0_fused in T.thread_binding(256, 
thread="threadIdx.x"):
                         with T.block("matmul"):
                             vk_0_fused_1 = T.axis.reduce(256, ax0_fused)
@@ -400,8 +396,10 @@ def test_decode_gemv_1_fp32():
                     for ax1_0_fused_0, ax1_1 in T.grid(2, 8):
                         with T.block("matmul_rf_update"):
                             vax1_0_fused_1, v0, vax1_0_fused_0, vax1_1 = 
T.axis.remap("SSRR", [ax1_0_fused_1, ax0_fused, ax1_0_fused_0, ax1_1])
-                            C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = 
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, 
(vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 + vax1_1]) * T.Cast("float32", 
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, vax1_0_fused_0 * 256 + 
vax1_0_fused_1], T.Cast("uint32", ((vax1_0_fused_0 * 256 + vax1_0_fused_1) * 8 
+ vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * S[v0, 
(vax1_0_fused_0 * 256 + vax1_0_fused_1) // 4])
-                for ax1_fused in range(1):
+                            T.reads(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0], 
V[0, 0, vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1], W[v0, 
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 8], S[v0, 
(vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1) // 32])
+                            T.writes(C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0])
+                            C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] = 
C_fp32_rf_local[vax1_0_fused_1, 0, 0, v0] + T.Cast("float32", V[0, 0, 
vax1_0_fused_0 * 2048 + vax1_0_fused_1 * 8 + vax1_1]) * T.Cast("float32", 
(T.Cast("float16", T.bitwise_and(T.shift_right(W[v0, (vax1_0_fused_0 * 2048 + 
vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 2048 + 
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * S[v0, (vax1_0_fused_0 * 2048 + vax1_0 [...]
+                for ax1_fused in range(1):  # pylint: disable=unused-variable
                     for ax0_fused_1 in T.thread_binding(256, 
thread="threadIdx.x"):
                         with T.block("matmul"):
                             vax1_0_fused_1, v0 = T.axis.remap("RS", 
[ax0_fused_1, ax0_fused])
@@ -424,6 +422,71 @@ def test_decode_gemv_1_fp32():
     assert_structural_equal(mod, After)
 
 
+def test_reduction_no_spatial():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), 
"float16"), rms_norm: T.Buffer((1, 4096), "float16")):
+            Ared_temp = T.alloc_buffer((1, 1))
+            for ax0 in range(4096):
+                with T.block("Ared_temp"):
+                    v0 = T.axis.reduce(4096, ax0)
+                    with T.init():
+                        Ared_temp[0, 0] = T.float32(0)
+                    Ared_temp[0, 0] = Ared_temp[0, 0] + T.Cast("float32", A[0, 
0, v0]) * T.Cast("float32", A[0, 0, v0])
+            for ax0 in range(4096):
+                with T.block("rms_norm"):
+                    v0 = T.axis.spatial(4096, ax0)
+                    rms_norm[0, v0] = T.Cast("float16", T.Cast("float32", 
B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp[0, 0] * 
T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), 
"float16"), rms_norm: T.Buffer((1, 4096), "float16")):
+            T.func_attr({"tir.is_scheduled": 1})
+            # with T.block("root"):
+            Ared_temp_local = T.alloc_buffer((1, 1), scope="local")
+            Ared_temp_rf_local = T.alloc_buffer((256, 1, 1), scope="local")
+            for ax0_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.x"): # pylint: disable=unused-variable
+                for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block("Ared_temp_rf_init"):
+                        vax1_fused_1 = T.axis.spatial(256, ax1_fused_1)
+                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                        T.reads()
+                        T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0])
+                        Ared_temp_rf_local[vax1_fused_1, 0, 0] = T.float32(0)
+                    for ax1_fused_0, u in T.grid(16, 1): # pylint: 
disable=unused-variable
+                        with T.block("Ared_temp_rf_update"):
+                            vax1_fused_1 = T.axis.spatial(256, ax1_fused_1)
+                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            vax1_fused_0 = T.axis.reduce(16, ax1_fused_0)
+                            T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0], 
A[0, 0, vax1_fused_0 * 256 + vax1_fused_1])
+                            T.writes(Ared_temp_rf_local[vax1_fused_1, 0, 0])
+                            Ared_temp_rf_local[vax1_fused_1, 0, 0] = 
Ared_temp_rf_local[vax1_fused_1, 0, 0] + T.Cast("float32", A[0, 0, vax1_fused_0 
* 256 + vax1_fused_1]) * T.Cast("float32", A[0, 0, vax1_fused_0 * 256 + 
vax1_fused_1])
+                for ax1_fused in range(T.int64(1)): # pylint: 
disable=unused-variable
+                    for ax0 in T.thread_binding(256, thread="threadIdx.x"):
+                        with T.block("Ared_temp"):
+                            vax1_fused_1 = T.axis.reduce(256, ax0)
+                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                            T.reads(Ared_temp_rf_local[vax1_fused_1, 0, 0])
+                            T.writes(Ared_temp_local[0, 0])
+                            with T.init():
+                                Ared_temp_local[0, 0] = T.float32(0)
+                            Ared_temp_local[0, 0] = Ared_temp_local[0, 0] + 
Ared_temp_rf_local[vax1_fused_1, 0, 0]
+                for ax0 in range(4096):
+                    with T.block("rms_norm"):
+                        v0 = T.axis.spatial(4096, ax0)
+                        T.reads(B[v0], A[0, 0, v0], Ared_temp_local[0, 0])
+                        T.writes(rms_norm[0, v0])
+                        rms_norm[0, v0] = T.Cast("float16", T.Cast("float32", 
B[v0]) * (T.Cast("float32", A[0, 0, v0]) / T.sqrt(Ared_temp_local[0, 0] * 
T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))))
+    # fmt: on
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.DecodeGEMV())(Before)  # pylint: 
disable=not-callable
+    assert_structural_equal(mod, After)
+
+
 if __name__ == "__main__":
     test_decode_gemv_1()
     test_decode_gemv_2()
@@ -431,3 +494,4 @@ if __name__ == "__main__":
     test_decode_gemv_4()
     test_decode_gemv_sigmoid()
     test_decode_gemv_1_fp32()
+    test_reduction_no_spatial()

Reply via email to