vinx13 commented on a change in pull request #10538:
URL: https://github.com/apache/tvm/pull/10538#discussion_r828539700
##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -2111,6 +2112,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
self, block_or_loop, ann_key
)
+ ########## Schedule: Layout transformation ##########
+
+ @type_checked
+ def transform_layout(
+ self,
+ block: BlockRV,
+ buffer_index: int,
+ is_write_index: bool,
+ index_map: Union[IndexMap, Callable],
+ ) -> None:
+        """Apply a transformation, represented by an IndexMap, to the buffer
+ Parameters
+ ----------
+        block : BlockRV
+ The block that accesses the target buffer
+ buffer_index: int
+ The index of the buffer in block's read or write region
+ is_write_index : bool
+ Whether the buffer_index is the index of the block's write region
+ index_map : Union[IndexMap, Callable]
+ The transformation to apply
+
+ Examples
+ --------
+ Before transform_layout, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_transform_layout(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32")
+ B = T.alloc_buffer((128, 128), "float32")
+ C = T.match_buffer(c, (128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ Create the schedule and do transform_layout:
+
+ .. code-block:: python
+
+                sch = tir.Schedule(before_transform_layout)
+                sch.transform_layout(sch.get_block("B"), buffer_index=0,
+                                     is_write_index=True,
+                                     index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
+ print(sch.mod["main"].script())
+
+ After applying transform_layout, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+            def after_transform_layout(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32")
+ B = T.alloc_buffer((8, 8, 16, 16), "float32")
+ C = T.match_buffer(c, (128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
+
+ """
+ if callable(index_map):
+ index_map = IndexMap.from_func(index_map)
Review comment:
Though `self.get` can be used to inspect a block, it doesn't work well
with meta schedule tuning, because calls to `self.get` can't be traced by
the meta schedule. Alternatively, we could allow an optional `ndim`
parameter in `IndexMap.from_func` and ask users to pass it explicitly —
would that be helpful?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]