junrushao commented on code in PR #15490:
URL: https://github.com/apache/tvm/pull/15490#discussion_r1285375275
##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -195,106 +201,255 @@ def sch_inner_reduction( # pylint:
disable=too-many-arguments
epilogue_info: Optional[BlockInfo],
):
"""Schedule the inner reduction block."""
- # pylint: disable=invalid-name
- _, s, r, _ = sch.get_loops(block)
- # TODO: make it tunable
- vec_bytes = 16 if target.kind.name == "cuda" else 8
- unroll_number = 256 if target.kind.name == "cuda" else 64
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+
+ def apply(
+ sch: tir.Schedule,
+ gemv,
+ TAG_S,
+ TAG_R,
+ TS,
+ TR,
+ TILE_S,
+ TILE_R,
+ VEC_LOAD,
+ VEC_C,
+ LOAD_V_SHARED,
+ LOAD_V_VEC,
+ UNROLL,
+ ):
+ # rfactor: reduce to tx * vec_c
+ _, s, r, c = sch.get_loops(block=gemv)
+ s = sch.fuse(_, s)
+ r = sch.fuse(r, c)
+ bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S],
preserve_unit_iters=True)
+ r, tr, tile_r_vec_n, vec_c = sch.split(
+ r, factors=[None, TR, TILE_R // VEC_C, VEC_C],
preserve_unit_iters=True
+ )
+ sch.reorder(r, tile_r_vec_n, tr, vec_c)
+ tr_vec_c = sch.fuse(tr, vec_c)
+ rf = sch.rfactor(tr_vec_c, 0)
+
+ # rfactor: reduce to tx
+ bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, None],
preserve_unit_iters=True)
+ rf2 = sch.rfactor(tr, 0)
+
+ # bind, vectorize compute
+ bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf)
+ tr, vec_c = sch.split(tr_vec_c, factors=[TR, None],
preserve_unit_iters=True)
+ sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c)
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+ sch.vectorize(vec_c)
+
+ # vectorize load A
+ # (TODO) this is now actually problematic since the number of
loops is dependent on the
+ # number of dimensions of A_q
+ Aq_local = sch.cache_read(rf, read_buffer_index=1,
storage_scope="local")
+ sch.compute_at(Aq_local, r, preserve_unit_loops=True)
+ s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
+ s_local, vec_load = sch.split(
+ s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+ )
+ sch.reorder(s_local, r_local, vec_load) # either s_local or
r_local should be 1
+ sch.vectorize(vec_load)
+
+ # load vector into shared memory, shape should be the whole vector
+ if LOAD_V_SHARED:
+ V_shared = sch.cache_read(rf, read_buffer_index=0,
storage_scope="shared")
+ sch.compute_at(V_shared, tr, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ loop: tir.For = sch.get(l)
+ if isinstance(loop.extent, tir.IntImm):
+ # avoid introducing predicates when vector length is too
large
+ vec_length = min(int(loop.extent) // TR // TS, LOAD_V_VEC)
+ else:
+ vec_length = LOAD_V_VEC
+ if TAG_R == "threadIdx.x":
+ _, ty, tx, vec = sch.split(
+ l, factors=[None, TS, TR, vec_length],
preserve_unit_iters=True
+ )
+ else:
+ _, ty, tx, vec = sch.split(
+ l, factors=[None, TR, TS, vec_length],
preserve_unit_iters=True
+ )
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+
+ # reduce tile_s * tr * vec to tile_s * tr
+ sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+ tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:]
+ ts_tile_s = sch.fuse(*ts_tile_s)
+ ts, tile_s = sch.split(ts_tile_s, factors=[TS, None],
preserve_unit_iters=True)
+ tile_s, vec_s = sch.split(
+ tile_s,
+ factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
+ preserve_unit_iters=True,
+ )
+ sch.reorder(ts, tr, tile_s, vec_s, vec_c)
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+ sch.vectorize(vec_s)
+
+ # reduce tile_s * tr to tile_s
+ sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+ tr, *ts_tile_s = sch.get_loops(block=gemv)[1:]
+ ts_tile_s = sch.fuse(*ts_tile_s)
+ ts, tile_s = sch.split(ts_tile_s, factors=[TS, None],
preserve_unit_iters=True)
+ sch.reorder(tile_s, ts, tr)
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+
+ sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3])
+ sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+ sch.set_scope(rf, buffer_index=0, storage_scope="local")
+ sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+ unroll_factor = UNROLL
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf)[3],
ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ if LOAD_V_SHARED:
+ sch.annotate(
+ block_or_loop=sch.get_loops(V_shared)[-4],
+ ann_key="pragma_unroll_explicit",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+ block_or_loop=sch.get_loops(V_shared)[-4],
ann_key="pragma_vectorize", ann_val=1
+ )
+
+ # Schedule epilogue
+ if epilogue_info is not None:
+ epilogue = epilogue_info.block_rv
+ if is_broadcast_epilogue(sch, block, epilogue):
+ sch.reverse_compute_at(epilogue, bx)
+ sch.set_scope(block, 0, "shared")
+ _, _, *s = sch.get_loops(epilogue) # pylint:
disable=invalid-name
+ _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+ sch.bind(tx, "threadIdx.x")
+ else:
+ sch.reverse_compute_at(epilogue, bx,
preserve_unit_loops=True)
+ ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
+ ts_tile_s = sch.get_loops(epilogue)[-1]
+ ts, tile_s = sch.split(ts_tile_s, factors=[TS, None],
preserve_unit_iters=True)
+ sch.bind(ts, TAG_S)
+ sch.set_scope(block, 0, "local")
+ # pylint: enable=invalid-name
+ return sch
def get_extent(loop_rv: tir.schedule.LoopRV):
loop: tir.For = sch.get(loop_rv)
- return loop.extent.value if isinstance(loop.extent, tir.IntImm)
else 1
+ return loop.extent.value if isinstance(loop.extent, tir.IntImm)
else loop.extent
# Specify the `len_tx` and `len_ty` according to the loop extent
- len_s, len_r = get_extent(s), get_extent(r)
- if len_r >= 4096 and len_r % 128 == 0:
- len_tx = 128
- elif 1024 < len_r <= 2048 and len_r % 64 == 0:
- len_tx = 64
- else:
- len_tx = 32
-
- if len_s >= 4096:
- len_ty = 8
- else:
- len_ty = min(len_s, 4)
-
- # Use `split_k` to prevent too large shared memory usage
- split_k: int = 4
-
- _, tx = sch.split(r, [None, len_tx], preserve_unit_iters=True)
- # Schedule the RF block
- rf = sch.rfactor(tx, 0)
- batch, bx, r, tx, _ = sch.get_loops(rf)
- sch.reorder(bx, tx, r)
- ro, ri = sch.split(r, [split_k, None], preserve_unit_iters=True)
- bx, ty = sch.split(bx, [None, len_ty], preserve_unit_iters=True)
-
- sch.bind(batch, "blockIdx.y")
- sch.bind(bx, "blockIdx.x")
- sch.bind(ty, "threadIdx.y")
- sch.bind(tx, "threadIdx.x")
- sch.annotate(ro, "pragma_auto_unroll_max_step", unroll_number)
- sch.annotate(ro, "pragma_unroll_explicit", 1)
-
+ batch, s, r, c = sch.get_loops(block=block)
+ len_batch, len_s, len_r, len_c = (
+ get_extent(batch),
+ get_extent(s),
+ get_extent(r),
+ get_extent(c),
+ )
+ len_S = len_batch * len_s
+ len_R = len_r * len_c
+
+ TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
if target.kind.name == "cuda":
Review Comment:
The dlight folder structure places generic GPU schedules in `dlight.gpu` and
CUDA-specific ones in `dlight.cuda`. Would it be possible to move the
architecture-specific code into the corresponding folders, or do you think
that would not be a good design here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]