Lunderberg commented on a change in pull request #10538:
URL: https://github.com/apache/tvm/pull/10538#discussion_r832492587
##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -2111,6 +2112,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
self, block_or_loop, ann_key
)
+ ########## Schedule: Layout transformation ##########
+
+ @type_checked
+ def transform_layout(
+ self,
+ block: BlockRV,
+ buffer_index: int,
+ is_write_index: bool,
+ index_map: Union[IndexMap, Callable],
+ ) -> None:
+ """Apply a transformation represented by IndexMap to buffer
+ Parameters
+ ----------
+ block_rv : BlockRV
+ The block that accesses the target buffer
+ buffer_index: int
+ The index of the buffer in block's read or write region
+ is_write_index : bool
+ Whether the buffer_index is the index of the block's write region
+ index_map : Union[IndexMap, Callable]
+ The transformation to apply
+
+ Examples
+ --------
+ Before transform_layout, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_transform_layout(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32")
+ B = T.alloc_buffer((128, 128), "float32")
+ C = T.match_buffer(c, (128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ Create the schedule and do transform_layout:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_storage_align)
+ sch.transform_layout(sch.get_block("B"), buffer_index=0,
is_write_index=True,
+ index_map=lambda m, n: (m // 16, n // 16, m %
16, n % 16))
+ print(sch.mod["main"].script())
+
+ After applying transform_layout, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def two_elementwise_transformed_intermediate_buffer(a: T.handle,
c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32")
+ B = T.alloc_buffer((8, 8, 16, 16), "float32")
+ C = T.match_buffer(c, (128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] *
2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] +
1.0
+
+ """
+ if callable(index_map):
+ index_map = IndexMap.from_func(index_map)
Review comment:
That makes sense, and so I'd lean toward avoiding varargs for now.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]