Rainy-Memory opened a new pull request, #14164: URL: https://github.com/apache/tvm/pull/14164
This PR introduces MemHammer, which performs automatic data movement in MetaSchedule.

**This PR is not ready yet. TODO:**
* Reformat code
* Add unittests
* Pass CI

This PR is a migration of https://github.com/Hzfengsy/asplos-tvm.

Authored-by: Wuwei Lin [[email protected]](mailto:[email protected])
Authored-by: Junru Shao [[email protected]](mailto:[email protected])
Authored-by: Siyuan Feng [[email protected]](mailto:[email protected])
Authored-by: Ruihang Lai [[email protected]](mailto:[email protected])
Authored-by: Bohan Hou [[email protected]](mailto:[email protected])
Authored-by: Hongyi Jin [[email protected]](mailto:[email protected])

Given a data movement description like this:

```
@tvm.script.ir_module
class GlobalToShared:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, [1024, 1024])
        B = T.match_buffer(b, [1024, 1024])
        with T.block("root"):
            T.block_attr({"warp_execution": True})
            for bx in T.thread_binding(8, thread="blockIdx.x"):
                for by in T.thread_binding(8, thread="blockIdx.y"):
                    for ty in T.thread_binding(8, thread="threadIdx.y"):
                        with T.block():
                            A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", scope="shared.dyn")
                            with T.block("A_shared"):
                                T.block_attr({"auto_copy": 1, "vector_bytes": 16})
                                for ax0, ax1 in T.grid(128, 128):
                                    A_shared_dyn[ax0, ax1] = A[bx * 128 + ax0, by * 128 + ax1]
                            with T.block("B"):
                                for ax0, ax1 in T.grid(128, 128):
                                    B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1]
```

By annotating the block with `T.block_attr({"auto_copy": 1})` and other optional arguments, it is lowered to the following code with cooperative fetching, vectorization, and the other specified features:

```
@tvm.script.ir_module
class TransformedGlobalToShared:
    @T.prim_func
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, [1024, 1024])
        B = T.match_buffer(b, [1024, 1024])
        with T.block("root"):
            T.block_attr({"warp_execution": True})
            for bx in T.thread_binding(8, thread="blockIdx.x"):
                for by in T.thread_binding(8, thread="blockIdx.y"):
                    for ty in T.thread_binding(8, thread="threadIdx.y"):
                        with T.block():
                            A_shared_dyn = T.alloc_buffer([128, 128], dtype="float32", strides=[128, 1], scope="shared.dyn")
                            with T.block("A_shared"):
                                T.block_attr({"auto_copy": 1, "vector_bytes": 16})
                                for outer in T.serial(16):
                                    for ty_1 in T.thread_binding(8, thread="threadIdx.y"):
                                        for tx in T.thread_binding(32, thread="threadIdx.x"):
                                            for vec in T.vectorized(4):
                                                A_shared_dyn[(((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128] = A[bx * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) // 128 % 128, by * 128 + (((outer * 8 + ty_1) * 32 + tx) * 4 + vec) % 128]
                            with T.block("B"):
                                for ax0, ax1 in T.grid(128, 128):
                                    B[bx * 128 + ax0, by * 128 + ax1] = A_shared_dyn[ax0, ax1]
```
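The cooperative-fetch index arithmetic in the lowered code can be sanity-checked in plain Python. The sketch below (a hypothetical standalone helper, not part of this PR) replays the `outer`/`threadIdx.y`/`threadIdx.x`/vector-lane loop nest (16 × 8 × 32 × 4 = 128 × 128 iterations) and confirms that the flattened index maps each iteration to a distinct element of the 128×128 shared-memory tile, i.e. the copy is a complete, non-overlapping partition of the tile across threads and lanes:

```python
def tile_index(outer, ty, tx, vec, tile=128):
    # Same flattening as the lowered TVMScript:
    # flat = (((outer * 8 + ty) * 32 + tx) * 4 + vec)
    flat = (((outer * 8 + ty) * 32 + tx) * 4 + vec)
    # Row/column within the 128x128 shared-memory tile.
    return (flat // tile) % tile, flat % tile

# Enumerate the full loop nest: 16 serial iterations, 8 threadIdx.y,
# 32 threadIdx.x, 4 vectorized lanes.
seen = set()
for outer in range(16):
    for ty in range(8):
        for tx in range(32):
            for vec in range(4):
                seen.add(tile_index(outer, ty, tx, vec))

# Every (row, col) of the tile is written exactly once.
assert len(seen) == 128 * 128
```

Note also that in the lowered copy each `vec` lane touches 4 consecutive `float32` elements along the innermost dimension, which is what makes the `"vector_bytes": 16` annotation (4 × 4-byte lanes) realizable as a single vectorized access.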
