Lunderberg commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r770107423
##########
File path: python/tvm/te/schedule.py
##########
@@ -519,9 +523,149 @@ def rolling_buffer(self):
"""
_ffi_api.StageRollingBuffer(self)
+ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr]]):
+ """Defines the layout transformation for the current stage's tensor.
+
+ The map from initial_indices to final_indices must be an
+ invertible affine transformation. This method may be called
+ more than once for a given tensor, in which case each
+ transformation is applied sequentially.
+
+ If the stage is a ComputeOp, then the iteration order of the
+ compute stage is rewritten to be a row-major traversal of the
+ tensor, and the new loop iteration variables are returned.
+ For all other stages, the loop iteration order is unmodified,
+ and the return value is None.
+
+ Parameters
+ ----------
+ mapping_function : Callable[..., List[tvm.tir.PrimExpr]]
+
+ A callable that accepts N arguments of type tvm.tir.Var,
+ and outputs a list of PrimExpr. The input arguments
+ represent the location of a value in the current stage's
+ tensor, using the pre-transformation layout. The return
+ value of the function gives the location of that value in
+ the current stage's tensor, using the post-transformation
+ layout.
+
+ Returns
+ -------
+ new_iter_vars : Optional[List[tvm.tir.IterVar]]
+
+ If the stage is a ComputeOp, then the return will be the
+ updated loop iteration variables over the data array, in
+ the same order as the output values from the
+ `mapping_function`.
+
+ Otherwise, the return value is None.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # ``A`` is a tensor whose compute definition is in NHWC
+ # format, and should be transformed into NCHWc format.
+
+ s[A].transform_layout(
+ lambda n,h,w,c: [n, c//4, h, w, c%4]
+ )
+
+
+ .. code-block:: python
+
+ # ``A`` is a tensor whose compute definition is in an
+ # arbitrary format, and should be transformed such that
+ # the last index is split, with the slower-changing index
+ # of the split placed at the slowest changing dimension.
+
+ s[A].transform_layout(
+ lambda *indices, i: [i//4, *indices, i%4]
+ )
Review comment:
I'm cheating a little bit here to avoid delaying the production of the
`IndexMap`. Since the dimensionality of the stage `A` is known, the `IndexMap`
corresponding to the lambda function can be created immediately, without
waiting until later in the build/lower step. (e.g. if `len(A.shape)==3`, then
this call would generate an `IndexMap` with
`initial_indices = [indices0, indices1, i]` and
`final_indices = [i//4, indices0, indices1, i%4]`.) The same lambda function
could produce different `IndexMap` objects depending on whether it is passed to
`s[A].transform_layout` or `s[B].transform_layout`, rather than producing a
single `IndexMap`.
That said, if I'm wrong about the implications and there are edge cases that
require delaying the `IndexMap` creation, I definitely don't want to require
tracking a partially made `IndexMap` through some of the lowering steps. As it
is, this seemed the cleanest way to expose the Python API, and it has
similarities to how `te.compute`'s interface works.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]