Hzfengsy commented on code in PR #16932:
URL: https://github.com/apache/tvm/pull/16932#discussion_r1582533172
##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction( # pylint:
disable=too-many-arguments, invalid-name, un
block: tir.schedule.BlockRV,
vector_input_buffers: List[tir.Buffer],
epilogue_info: Optional[BlockInfo],
+ ):
+ """Schedule the inner reduction block."""
Review Comment:
```suggestion
"""Schedule the outer reduction block."""
```
##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction( # pylint:
disable=too-many-arguments, invalid-name, un
block: tir.schedule.BlockRV,
vector_input_buffers: List[tir.Buffer],
epilogue_info: Optional[BlockInfo],
+ ):
+ """Schedule the inner reduction block."""
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+ return 1
+
+ def apply(
+ sch: tir.Schedule,
+ gemv,
+ TAG_S,
+ TAG_R,
+ TS,
+ TR,
+ SCALE_PACK,
+ DEC_PACK,
+ VEC_LOAD,
+ VEC_C,
+ LOAD_V_SHARED,
+ LOAD_V_VEC,
+ UNROLL,
+ LOAD_V_TILE,
+ ):
+ # rfactor: reduce to tx * vec_c
+ batch, s, r, c = sch.get_loops(block=gemv)
+ s = sch.fuse(batch, s)
+ r = sch.fuse(r, c)
+ bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
+ r, v_tile, tr, tile_r, vec_c = sch.split(
+ r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK],
preserve_unit_iters=True
+ )
+ sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
+ tr_vec_c = sch.fuse(tr, vec_c)
+ rf = sch.rfactor(tr_vec_c, 0)
+
+ # rfactor: reduce to tx
+ bx, ts, tr_vec_c = sch.get_loops(block=gemv)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, None],
preserve_unit_iters=True)
+ rf2 = sch.rfactor(tr, 0)
+
+ # bind, vectorize compute
+ bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
+ sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
+ # sch.bind(batch, "blockIdx.z")
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+ sch.vectorize(vec_c)
+
+ # decompose independent scale read to outer loop
+ block_rf_stmt = sch.get(rf)
+ if len(block_rf_stmt.reads) >= 3:
+ As_local = sch.cache_read(rf, read_buffer_index=2,
storage_scope="local")
+ sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
+ # *tile_thr, vec_s = sch.get_loops(block=As_local)
+ # sch.vectorize(vec_s)
+
+ Aq_local = sch.cache_read(rf, read_buffer_index=1,
storage_scope="local")
+ sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
+ # *tile_thr, vec_s = sch.get_loops(block=Aq_local)
+ # sch.vectorize(vec_s)
+
+ if LOAD_V_SHARED:
+ V_shared = sch.cache_read(rf, read_buffer_index=0,
storage_scope="shared")
+ sch.compute_at(V_shared, r, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ _, v_tile, tx, ty, vec = sch.split(
+ l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC],
preserve_unit_iters=True
+ )
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+
+ # reduce tile_s * tr * vec to tile_s * tr
+ sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+ tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
+ sch.reorder(ts, tr, vec_c)
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+
+ # reduce tile_s * tr to tile_s
+ sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+ tr, ts = sch.get_loops(block=gemv)[1:]
+ sch.reorder(ts, tr)
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+
+ sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
+ sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+ sch.set_scope(rf, buffer_index=0, storage_scope="local")
+ sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=DEC_PACK,
+ )
+ sch.annotate(
Review Comment:
If I remember correctly, the explicit `pragma_unroll_explicit=1` annotation is
not needed, as it will not unroll by default.
##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction( # pylint:
disable=too-many-arguments, invalid-name, un
block: tir.schedule.BlockRV,
vector_input_buffers: List[tir.Buffer],
epilogue_info: Optional[BlockInfo],
+ ):
+ """Schedule the inner reduction block."""
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+ return 1
+
+ def apply(
+ sch: tir.Schedule,
+ gemv,
+ TAG_S,
+ TAG_R,
+ TS,
+ TR,
+ SCALE_PACK,
+ DEC_PACK,
+ VEC_LOAD,
+ VEC_C,
+ LOAD_V_SHARED,
+ LOAD_V_VEC,
+ UNROLL,
+ LOAD_V_TILE,
+ ):
+ # rfactor: reduce to tx * vec_c
+ batch, s, r, c = sch.get_loops(block=gemv)
+ s = sch.fuse(batch, s)
+ r = sch.fuse(r, c)
+ bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
+ r, v_tile, tr, tile_r, vec_c = sch.split(
+ r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK],
preserve_unit_iters=True
+ )
+ sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
+ tr_vec_c = sch.fuse(tr, vec_c)
+ rf = sch.rfactor(tr_vec_c, 0)
+
+ # rfactor: reduce to tx
+ bx, ts, tr_vec_c = sch.get_loops(block=gemv)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, None],
preserve_unit_iters=True)
+ rf2 = sch.rfactor(tr, 0)
+
+ # bind, vectorize compute
+ bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
+ sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
+ # sch.bind(batch, "blockIdx.z")
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+ sch.vectorize(vec_c)
+
+ # decompose independent scale read to outer loop
+ block_rf_stmt = sch.get(rf)
+ if len(block_rf_stmt.reads) >= 3:
+ As_local = sch.cache_read(rf, read_buffer_index=2,
storage_scope="local")
+ sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
+ # *tile_thr, vec_s = sch.get_loops(block=As_local)
+ # sch.vectorize(vec_s)
+
+ Aq_local = sch.cache_read(rf, read_buffer_index=1,
storage_scope="local")
+ sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
+ # *tile_thr, vec_s = sch.get_loops(block=Aq_local)
+ # sch.vectorize(vec_s)
+
+ if LOAD_V_SHARED:
+ V_shared = sch.cache_read(rf, read_buffer_index=0,
storage_scope="shared")
+ sch.compute_at(V_shared, r, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ _, v_tile, tx, ty, vec = sch.split(
+ l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC],
preserve_unit_iters=True
+ )
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+
+ # reduce tile_s * tr * vec to tile_s * tr
+ sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+ tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
+ sch.reorder(ts, tr, vec_c)
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+
+ # reduce tile_s * tr to tile_s
+ sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+ tr, ts = sch.get_loops(block=gemv)[1:]
+ sch.reorder(ts, tr)
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+
+ sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
+ sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+ sch.set_scope(rf, buffer_index=0, storage_scope="local")
+ sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=DEC_PACK,
+ )
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ # Schedule epilogue
+ if epilogue_info is not None:
+ epilogue = epilogue_info.block_rv
+ if is_broadcast_epilogue(sch, block, epilogue):
+ sch.reverse_compute_at(epilogue, bx)
+ sch.set_scope(block, 0, "shared")
+ _, _, *s = sch.get_loops(epilogue) # pylint:
disable=invalid-name
+ _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+ sch.bind(tx, "threadIdx.x")
+ else:
+ sch.reverse_compute_at(epilogue, bx,
preserve_unit_loops=True)
+ ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
+ ts_tile_s = sch.get_loops(epilogue)[-1]
+ ts, _ = sch.split(ts_tile_s, factors=[TS, None],
preserve_unit_iters=True)
+ sch.bind(ts, "threadIdx.x")
+ sch.set_scope(block, 0, "local")
+ return sch
+ # return sch.mod["main"].with_attr("tir.is_scheduled", 1)
Review Comment:
Please remove this commented-out line.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]