krishnaraj36 commented on code in PR #16932:
URL: https://github.com/apache/tvm/pull/16932#discussion_r1582557756


##########
python/tvm/dlight/gpu/gemv.py:
##########
@@ -551,6 +560,192 @@ def sch_outer_reduction(  # pylint: 
disable=too-many-arguments, invalid-name, un
         block: tir.schedule.BlockRV,
         vector_input_buffers: List[tir.Buffer],
         epilogue_info: Optional[BlockInfo],
+    ):
+        """Schedule the inner reduction block."""
+
+        def get_max_factor(n, factors):
+            factors = sorted(factors, reverse=True)
+            for factor in factors:
+                if n % factor == 0:
+                    return factor
+            return 1
+
+        def apply(
+            sch: tir.Schedule,
+            gemv,
+            TAG_S,
+            TAG_R,
+            TS,
+            TR,
+            SCALE_PACK,
+            DEC_PACK,
+            VEC_LOAD,
+            VEC_C,
+            LOAD_V_SHARED,
+            LOAD_V_VEC,
+            UNROLL,
+            LOAD_V_TILE,
+        ):
+            # rfactor: reduce to tx * vec_c
+            batch, s, r, c = sch.get_loops(block=gemv)
+            s = sch.fuse(batch, s)
+            r = sch.fuse(r, c)
+            bx, ts = sch.split(s, factors=[None, TS], preserve_unit_iters=True)
+            r, v_tile, tr, tile_r, vec_c = sch.split(
+                r, factors=[None, LOAD_V_TILE, TR, SCALE_PACK, DEC_PACK], 
preserve_unit_iters=True
+            )
+            sch.reorder(bx, ts, r, v_tile, tile_r, tr, vec_c)
+            tr_vec_c = sch.fuse(tr, vec_c)
+            rf = sch.rfactor(tr_vec_c, 0)
+
+            # rfactor: reduce to tx
+            bx, ts, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], 
preserve_unit_iters=True)
+            rf2 = sch.rfactor(tr, 0)
+
+            # bind, vectorize compute
+            bx, ts, r, v_tile, tile_r, tr_vec_c = sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, DEC_PACK])
+            sch.reorder(bx, ts, tr, r, v_tile, tile_r, vec_c)
+            # sch.bind(batch, "blockIdx.z")
+            sch.bind(bx, "blockIdx.x")
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+            sch.vectorize(vec_c)
+
+            # decompose independent scale read to outer loop
+            block_rf_stmt = sch.get(rf)
+            if len(block_rf_stmt.reads) >= 3:
+                As_local = sch.cache_read(rf, read_buffer_index=2, 
storage_scope="local")
+                sch.compute_at(As_local, v_tile, preserve_unit_loops=True)
+                # *tile_thr, vec_s = sch.get_loops(block=As_local)
+                # sch.vectorize(vec_s)
+
+            Aq_local = sch.cache_read(rf, read_buffer_index=1, 
storage_scope="local")
+            sch.compute_at(Aq_local, tile_r, preserve_unit_loops=True)
+            # *tile_thr, vec_s = sch.get_loops(block=Aq_local)
+            # sch.vectorize(vec_s)
+
+            if LOAD_V_SHARED:
+                V_shared = sch.cache_read(rf, read_buffer_index=0, 
storage_scope="shared")
+                sch.compute_at(V_shared, r, preserve_unit_loops=True)
+                l = sch.get_loops(block=V_shared)[-1]
+                _, v_tile, tx, ty, vec = sch.split(
+                    l, factors=[None, LOAD_V_TILE, TS, TR, LOAD_V_VEC], 
preserve_unit_iters=True
+                )
+                sch.bind(ty, "threadIdx.y")
+                sch.bind(tx, "threadIdx.x")
+                sch.vectorize(vec)
+
+            # reduce tile_s * tr * vec to tile_s * tr
+            sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+            tr, vec_c, ts = sch.get_loops(block=rf2)[1:]
+            sch.reorder(ts, tr, vec_c)
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+
+            # reduce tile_s * tr to tile_s
+            sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+            tr, ts = sch.get_loops(block=gemv)[1:]
+            sch.reorder(ts, tr)
+            sch.bind(ts, "threadIdx.x")
+            sch.bind(tr, "threadIdx.y")
+
+            sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[2])
+            sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+            sch.set_scope(rf, buffer_index=0, storage_scope="local")
+            sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3],
+                ann_key="pragma_auto_unroll_max_step",
+                ann_val=DEC_PACK,
+            )
+            sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3], 
ann_key="pragma_unroll_explicit", ann_val=1
+            )
+
+            # Schedule epilogue
+            if epilogue_info is not None:
+                epilogue = epilogue_info.block_rv
+                if is_broadcast_epilogue(sch, block, epilogue):
+                    sch.reverse_compute_at(epilogue, bx)
+                    sch.set_scope(block, 0, "shared")
+                    _, _, *s = sch.get_loops(epilogue)  # pylint: 
disable=invalid-name
+                    _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+                    sch.bind(tx, "threadIdx.x")
+                else:
+                    sch.reverse_compute_at(epilogue, bx, 
preserve_unit_loops=True)
+                    ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
+                    ts_tile_s = sch.get_loops(epilogue)[-1]
+                    ts, _ = sch.split(ts_tile_s, factors=[TS, None], 
preserve_unit_iters=True)
+                    sch.bind(ts, "threadIdx.x")
+                    sch.set_scope(block, 0, "local")
+            return sch
+            # return sch.mod["main"].with_attr("tir.is_scheduled", 1)

Review Comment:
   Removed unwanted code, thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to