Hzfengsy commented on code in PR #16932:
URL: https://github.com/apache/tvm/pull/16932#discussion_r1582533172


##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction(  # pylint: 
disable=too-many-arguments, invalid-name, un
         block: tir.schedule.BlockRV,
         vector_input_buffers: List[tir.Buffer],
         epilogue_info: Optional[BlockInfo],
+    ):
+        """Schedule the inner reduction block."""

Review Comment:
   ```suggestion
           """Schedule the outer reduction block."""
   ```



##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction(  # pylint: 
disable=too-many-arguments, invalid-name, un
         block: tir.schedule.BlockRV,
         vector_input_buffers: List[tir.Buffer],
         epilogue_info: Optional[BlockInfo],
+    ):
+        """Schedule the inner reduction block."""
+
+        def get_max_factor(n, factors):
+            factors = sorted(factors, reverse=True)
+            for factor in factors:
+                if n % factor == 0:
+                    return factor
+            return 1
+
+        def apply(
+            sch: tir.Schedule,
+            gemv,
+            TAG_S,
+            TAG_R,
+            TS,
+            TR,
+            SCALE_PACK,
+            DEC_PACK,
+            VEC_LOAD,
+            VEC_C,
+            LOAD_V_SHARED,
+            LOAD_V_VEC,
+            UNROLL,
+            LOAD_V_TILE,
+        ):
+            # rfactor: reduce to tx * vec_c
+            batch, s, r, c = sch.get_loops(block=gemv)
+            s = sch.fuse(batch, s)
+            r = sch.fuse(r, c)
+            bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
+            r, v_tile, tr, tile_r, vec_c = sch.split(
+                r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK], 
preserve_unit_iters=True
+            )
+            sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
+            tr_vec_c = sch.fuse(tr, vec_c)
+            rf = sch.rfactor(tr_vec_c, 0)
+
+            # rfactor: reduce to tx
+            bx, ts, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            rf2 = sch.rfactor(tr, 0)
+
+            # bind, vectorize compute
+            bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
+            sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
+            # sch.bind(batch, "blockIdx.z")
+            sch.bind(bx, "blockIdx.x")
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+            sch.vectorize(vec_c)
+
+            # decompose independent scale read to outer loop
+            block_rf_stmt = sch.get(rf)
+            if len(block_rf_stmt.reads) >= 3:
+                As_local = sch.cache_read(rf, read_buffer_index=2, 
storage_scope="local")
+                sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
+                # *tile_thr, vec_s = sch.get_loops(block=As_local)
+                # sch.vectorize(vec_s)
+
+            Aq_local = sch.cache_read(rf, read_buffer_index=1, 
storage_scope="local")
+            sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
+            # *tile_thr, vec_s = sch.get_loops(block=Aq_local)
+            # sch.vectorize(vec_s)
+
+            if LOAD_V_SHARED:
+                V_shared = sch.cache_read(rf, read_buffer_index=0, 
storage_scope="shared")
+                sch.compute_at(V_shared, r, preserve_unit_loops=True)
+                l = sch.get_loops(block=V_shared)[-1]
+                _, v_tile, tx, ty, vec = sch.split(
+                    l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], 
preserve_unit_iters=True
+                )
+                sch.bind(ty, "threadIdx.y")
+                sch.bind(tx, "threadIdx.x")
+                sch.vectorize(vec)
+
+            # reduce tile_s * tr * vec to tile_s * tr
+            sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+            tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
+            sch.reorder(ts, tr, vec_c)
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+
+            # reduce tile_s * tr to tile_s
+            sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+            tr, ts = sch.get_loops(block=gemv)[1:]
+            sch.reorder(ts, tr)
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+
+            sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
+            sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+            sch.set_scope(rf, buffer_index=0, storage_scope="local")
+            sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=DEC_PACK,
+            )
+            sch.annotate(

Review Comment:
   If I remember correctly, explicitly setting `pragma_unroll_explicit` to 1 
is not needed, as it will not unroll by default.



##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction(  # pylint: 
disable=too-many-arguments, invalid-name, un
         block: tir.schedule.BlockRV,
         vector_input_buffers: List[tir.Buffer],
         epilogue_info: Optional[BlockInfo],
+    ):
+        """Schedule the inner reduction block."""
+
+        def get_max_factor(n, factors):
+            factors = sorted(factors, reverse=True)
+            for factor in factors:
+                if n % factor == 0:
+                    return factor
+            return 1
+
+        def apply(
+            sch: tir.Schedule,
+            gemv,
+            TAG_S,
+            TAG_R,
+            TS,
+            TR,
+            SCALE_PACK,
+            DEC_PACK,
+            VEC_LOAD,
+            VEC_C,
+            LOAD_V_SHARED,
+            LOAD_V_VEC,
+            UNROLL,
+            LOAD_V_TILE,
+        ):
+            # rfactor: reduce to tx * vec_c
+            batch, s, r, c = sch.get_loops(block=gemv)
+            s = sch.fuse(batch, s)
+            r = sch.fuse(r, c)
+            bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
+            r, v_tile, tr, tile_r, vec_c = sch.split(
+                r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK], 
preserve_unit_iters=True
+            )
+            sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
+            tr_vec_c = sch.fuse(tr, vec_c)
+            rf = sch.rfactor(tr_vec_c, 0)
+
+            # rfactor: reduce to tx
+            bx, ts, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            rf2 = sch.rfactor(tr, 0)
+
+            # bind, vectorize compute
+            bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
+            sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
+            # sch.bind(batch, "blockIdx.z")
+            sch.bind(bx, "blockIdx.x")
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+            sch.vectorize(vec_c)
+
+            # decompose independent scale read to outer loop
+            block_rf_stmt = sch.get(rf)
+            if len(block_rf_stmt.reads) >= 3:
+                As_local = sch.cache_read(rf, read_buffer_index=2, 
storage_scope="local")
+                sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
+                # *tile_thr, vec_s = sch.get_loops(block=As_local)
+                # sch.vectorize(vec_s)
+
+            Aq_local = sch.cache_read(rf, read_buffer_index=1, 
storage_scope="local")
+            sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
+            # *tile_thr, vec_s = sch.get_loops(block=Aq_local)
+            # sch.vectorize(vec_s)
+
+            if LOAD_V_SHARED:
+                V_shared = sch.cache_read(rf, read_buffer_index=0, 
storage_scope="shared")
+                sch.compute_at(V_shared, r, preserve_unit_loops=True)
+                l = sch.get_loops(block=V_shared)[-1]
+                _, v_tile, tx, ty, vec = sch.split(
+                    l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], 
preserve_unit_iters=True
+                )
+                sch.bind(ty, "threadIdx.y")
+                sch.bind(tx, "threadIdx.x")
+                sch.vectorize(vec)
+
+            # reduce tile_s * tr * vec to tile_s * tr
+            sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+            tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
+            sch.reorder(ts, tr, vec_c)
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+
+            # reduce tile_s * tr to tile_s
+            sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+            tr, ts = sch.get_loops(block=gemv)[1:]
+            sch.reorder(ts, tr)
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+
+            sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
+            sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+            sch.set_scope(rf, buffer_index=0, storage_scope="local")
+            sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=DEC_PACK,
+            )
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3], 
ann_key="pragma_unroll_explicit", ann_val=1
+            )
+
+            # Schedule epilogue
+            if epilogue_info is not None:
+                epilogue = epilogue_info.block_rv
+                if is_broadcast_epilogue(sch, block, epilogue):
+                    sch.reverse_compute_at(epilogue, bx)
+                    sch.set_scope(block, 0, "shared")
+                    _, _, *s = sch.get_loops(epilogue)  # pylint: 
disable=invalid-name
+                    _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+                    sch.bind(tx, "threadIdx.x")
+                else:
+                    sch.reverse_compute_at(epilogue, bx, 
preserve_unit_loops=True)
+                    ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
+                    ts_tile_s = sch.get_loops(epilogue)[-1]
+                    ts, _ = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+                    sch.bind(ts, "threadIdx.x")
+                    sch.set_scope(block, 0, "local")
+            return sch
+            # return sch.mod["main"].with_attr("tir.is_scheduled", 1)

Review Comment:
   Please remove this line



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to