krishnaraj36 commented on code in PR #16932:
URL: https://github.com/apache/tvm/pull/16932#discussion_r1582558769
##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction( # pylint:
disable=too-many-arguments, invalid-name, un
block: tir.schedule.BlockRV,
vector_input_buffers: List[tir.Buffer],
epilogue_info: Optional[BlockInfo],
+ ):
+ """Schedule the inner reduction block."""
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+ return 1
+
+ def apply(
+ sch: tir.Schedule,
+ gemv,
+ TAG_S,
+ TAG_R,
+ TS,
+ TR,
+ SCALE_PACK,
+ DEC_PACK,
+ VEC_LOAD,
+ VEC_C,
+ LOAD_V_SHARED,
+ LOAD_V_VEC,
+ UNROLL,
+ LOAD_V_TILE,
+ ):
+ # rfactor: reduce to tx * vec_c
+ batch, s, r, c = sch.get_loops(block=gemv)
+ s = sch.fuse(batch, s)
+ r = sch.fuse(r, c)
+ bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
+ r, v_tile, tr, tile_r, vec_c = sch.split(
+ r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK],
preserve_unit_iters=True
+ )
+ sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
+ tr_vec_c = sch.fuse(tr, vec_c)
+ rf = sch.rfactor(tr_vec_c, 0)
+
+ # rfactor: reduce to tx
+ bx, ts, tr_vec_c = sch.get_loops(block=gemv)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, None],
preserve_unit_iters=True)
+ rf2 = sch.rfactor(tr, 0)
+
+ # bind, vectorize compute
+ bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
+ sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
+ # sch.bind(batch, "blockIdx.z")
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+ sch.vectorize(vec_c)
+
+ # decompose independent scale read to outer loop
+ block_rf_stmt = sch.get(rf)
+ if len(block_rf_stmt.reads) >= 3:
+ As_local = sch.cache_read(rf, read_buffer_index=2,
storage_scope="local")
+ sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
+ # *tile_thr, vec_s = sch.get_loops(block=As_local)
+ # sch.vectorize(vec_s)
+
+ Aq_local = sch.cache_read(rf, read_buffer_index=1,
storage_scope="local")
+ sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
+ # *tile_thr, vec_s = sch.get_loops(block=Aq_local)
+ # sch.vectorize(vec_s)
+
+ if LOAD_V_SHARED:
+ V_shared = sch.cache_read(rf, read_buffer_index=0,
storage_scope="shared")
+ sch.compute_at(V_shared, r, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ _, v_tile, tx, ty, vec = sch.split(
+ l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC],
preserve_unit_iters=True
+ )
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+
+ # reduce tile_s * tr * vec to tile_s * tr
+ sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+ tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
+ sch.reorder(ts, tr, vec_c)
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+
+ # reduce tile_s * tr to tile_s
+ sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+ tr, ts = sch.get_loops(block=gemv)[1:]
+ sch.reorder(ts, tr)
+ sch.bind(ts, "threadIdx.x")
+ sch.bind(tr, "threadIdx.y")
+
+ sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
+ sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+ sch.set_scope(rf, buffer_index=0, storage_scope="local")
+ sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=DEC_PACK,
+ )
+ sch.annotate(
Review Comment:
@Hzfengsy
As I understand it, explicit unroll applies the unroll to the whole loop, and
the default fallback value is specified as ann_val=1.
It is used in the same way in the inner reduction schedules.
Thanks for the review.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]