Lunderberg commented on a change in pull request #10538:
URL: https://github.com/apache/tvm/pull/10538#discussion_r831192552
##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -2111,6 +2112,82 @@ def after_unannotate(a: T.handle, b: T.handle) -> None:
self, block_or_loop, ann_key
)
+ ########## Schedule: Layout transformation ##########
+
+ @type_checked
+ def transform_layout(
+ self,
+ block: BlockRV,
+ buffer_index: int,
+ is_write_index: bool,
+ index_map: Union[IndexMap, Callable],
+ ) -> None:
+ """Apply a transformation represented by IndexMap to buffer
+ Parameters
+ ----------
+        block : BlockRV
+ The block that accesses the target buffer
+ buffer_index: int
+ The index of the buffer in block's read or write region
+ is_write_index : bool
+ Whether the buffer_index is the index of the block's write region
+ index_map : Union[IndexMap, Callable]
+ The transformation to apply
+
+ Examples
+ --------
+ Before transform_layout, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_transform_layout(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32")
+ B = T.alloc_buffer((128, 128), "float32")
+ C = T.match_buffer(c, (128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 1.0
+
+ Create the schedule and do transform_layout:
+
+ .. code-block:: python
+
+                sch = tir.Schedule(before_transform_layout)
+                sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True,
+                                     index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16))
+ print(sch.mod["main"].script())
+
+ After applying transform_layout, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+            def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128), "float32")
+ B = T.alloc_buffer((8, 8, 16, 16), "float32")
+ C = T.match_buffer(c, (128, 128), "float32")
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0
+ for i, j in T.grid(128, 128):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+                        C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0
+
+ """
+ if callable(index_map):
+ index_map = IndexMap.from_func(index_map)
Review comment:
I think I understand. So, the issue is that a user who passes `lambda
*indices: A[n, *indices]`, or some other transformation that uses varargs,
would expect it to apply to whatever buffer is present at that spot. However,
since `self.get` isn't recognized by the meta schedule, it would only apply to
buffers that have the same dimension as the untuned buffer. By having the user
specify the number of dimensions explicitly, it prevents the user from making
the false assumption. Is that correct?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]