Lunderberg commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r779831427



##########
File path: python/tvm/te/schedule.py
##########
@@ -519,9 +523,149 @@ def rolling_buffer(self):
         """
         _ffi_api.StageRollingBuffer(self)
 
+    def transform_layout(self, mapping_function: Callable[..., 
List[tvm.tir.PrimExpr]]):
+        """Defines the layout transformation for the current stage's tensor.
+
+        The map from initial_indices to final_indices must be an
+        invertible affine transformation.  This method may be called
+        more than once for a given tensor, in which case each
+        transformation is applied sequentially.
+
+        If the stage is a ComputeOp, then the iteration order of the
+        compute stage is rewritten to be a row-major traversal of the
+        tensor, and the new loop iteration variables are returned.
+        For all other stages, the loop iteration order is unmodified,
+        and the return value is None.
+
+        Parameters
+        ----------
+        mapping_function : Callable[..., List[tvm.tir.PrimExpr]]
+
+            A callable that accepts N arguments of type tvm.tir.Var,
+            and outputs a list of PrimExpr.  The input arguments
+            represent the location of a value in the current stage's
+            tensor, using the pre-transformation layout.  The return
+            value of the function gives the location of that value in
+            the current stage's tensor, using the post-transformation
+            layout.
+
+        Returns
+        -------
+        new_iter_vars : Optional[List[tvm.tir.IterVar]]
+
+            If the stage is a ComputeOp, then the return will be the
+            updated loop iteration variables over the data array, in
+            the same order as the output values from the
+            `mapping_function`.
+
+            Otherwise, the return value is None.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            # ``A`` is a tensor whose compute definition is in NHWC
+            # format, and should be transformed into NCHWc format.
+
+            s[A].transform_layout(
+                lambda n,h,w,c: [n, c//4, h, w, c%4]
+            )
+
+
+        .. code-block:: python
+
+            # ``A`` is a tensor whose compute definition is in an
+            # arbitrary format, and should be transformed such that
+            # the last index is split, with the slower-changing index
+            # of the split placed at the slowest changing dimension.
+
+            s[A].transform_layout(
+                lambda *indices, i: [i//4, *indices, i%4]
+            )
+
+        .. code-block:: python
+
+            # ``B`` is a tensor defined by te.compute to be a copy of
+            # ``A`, and should be transformed such that ``B``'s layout
+            # is a transpose of ``A``'s layout.  The loop iteration
+            # that computes ``B`` will correspond to ``B``'s memory
+            # layout.
+
+            A = te.placeholder([n,m])
+            B = te.compute(A.shape, lambda i,j: A[i,j])
+            s = te.create_schedule(B.op)
+
+            s[B].transform_layout(lambda i,j: [j,i])
+
+        """
+
+        args = []
+        var_arg_name = None
+        kwargs = collections.OrderedDict()
+        default_index_dtype = "int32"
+
+        # Make a dummy variable for each explicitly named input index.
+        # We may have some keyword-only arguments, if the function has
+        # *args before the last argument.
+        params = inspect.signature(mapping_function).parameters
+        for name, param in params.items():
+            if param.kind in [
+                inspect.Parameter.POSITIONAL_ONLY,
+                inspect.Parameter.POSITIONAL_OR_KEYWORD,
+            ]:
+                args.append(tvm.tir.Var(name, default_index_dtype))
+
+            elif param.kind == inspect.Parameter.VAR_POSITIONAL:
+                var_arg_name = name
+
+            elif param.kind == inspect.Parameter.KEYWORD_ONLY:
+                kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+
+            elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
+                raise ValueError("transform_layout mapping may not have 
**kwargs")
+
+        ndim = len(self.op.output(0).shape)
+
+        # Now that all the named arguments have been collected,
+        # everything that remains should go to the *args, if
+        # specified.
+        if var_arg_name is not None:
+            num_var_args = ndim - len(args) - len(kwargs)
+            for i in range(num_var_args):
+                args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", 
default_index_dtype))
+
+        initial_indices = args + list(kwargs.values())
+        if len(initial_indices) != ndim:
+            raise ValueError(
+                f"transform_layout mapping accepts {len(params)} initial 
indices, "
+                f"but {self.op.name} is {len(self.op.shape)}-dimensional"
+            )
+
+        mapping = mapping_function(*args, **kwargs)
+
+        final_indices = []
+        axis_separators = []
+        for val in mapping:
+            if isinstance(val, tvm.ir.PrimExpr):
+                final_indices.append(val)
+            elif val is AXIS_SEPARATOR:
+                axis_separators.append(len(final_indices))
+            else:
+                raise TypeError(
+                    "Expected mapping function to return list of "
+                    "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR.  "
+                    "Instead received {val} of type {type(val)}."
+                )
+
+        new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, 
final_indices)
+        _ffi_api.StageSetAxisSeparators(self, axis_separators)

Review comment:
       I'd prefer to keep this as a single user-facing call, primarily for 
readability.  If the layout transformation and the axis separators were set in 
separate API calls, it would be ambiguous where the axis separators are placed, 
and the user would have to read the documentation to learn which convention is used.
   
   ```python
   # Without looking at the documentation, ambiguous whether [1] sets an
   # axis separator before axis 1 or after axis 1.
   s[A].transform_layout(lambda i,j: [i//4, j, i%4])
   s[A].set_axis_separators([1])
   
   # Without looking at the documentation, ambiguous whether [1] sets an
   # axis separator in the pre-transformation axes or the
   # post-transformation axes.
   s[A].set_axis_separators([1])
   s[A].transform_layout(lambda i,j: [j, i//4, i%4])
   
   # Setting both in the same user-facing API call implicitly tells the
   # user that the axis separator occurs between two specific axes, and
   # applies to the transformed axes.
   s[A].transform_layout(lambda i,j: [i//4, j, AXIS_SEPARATOR, i%4])
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to