This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 814f5501bf [TIR][Schedule] Transform layout quality of life (#11269)
814f5501bf is described below

commit 814f5501bf7d65f759135d214572388b0ddadefc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed May 25 19:19:33 2022 -0500

    [TIR][Schedule] Transform layout quality of life (#11269)
    
    * [TIR][Schedule] Added Schedule.transform_layout_sugared
    
    * [TE][TIR] Reduced duplication in TE/TIR layout transformations
    
    Previously, the implementations of `tir.IndexMap.from_func` and
    `te.Stage.transform_layout` had significant duplication to handle
    argument parsing.  This commit extracts the shared logic into
    `tir.IndexMap`.
    
    * Enabled *args in Schedule.transform_layout_sugared
    
    * Fix lint error
    
    * Allow Schedule.transform_layout_sugared to set axis separators
    
    * Merged transform_layout_sugared functionality into transform_layout
    
    * Fix lint errors
    
    * Fix lint error
    
    * Fixed docstring errors
    
    * Updated/tested TransformLayoutTraits::UnpackedAsPython
    
    * Disabled exec-used check for running trace.as_python()
    
    * Updated SetAxisSeparatorTraits::UnpackedAsPython
    
    * Updated unit test that was added in merge commit
    
    * Fixed the argument name for TensorizeTraits
    
    This wasn't checked before, but was the only other issue caught by the
    updates to verify_trace_roundtrip.
    
    * Re-enable type checks of transform_layout/set_axis_separator
    
    Disabled while waiting for https://github.com/apache/tvm/pull/11289,
    which was required for the `Tuple` argument.
    
    * Updated a few additional transform_layout usages from main
---
 python/tvm/te/schedule.py                          |  70 ++-------
 python/tvm/tir/function.py                         | 103 ++++++++++++-
 python/tvm/tir/schedule/schedule.py                | 165 +++++++++++++++++----
 python/tvm/tir/schedule/testing.py                 |  30 +++-
 src/tir/schedule/primitive/blockize_tensorize.cc   |   2 +-
 .../schedule/primitive/layout_transformation.cc    |  22 +--
 .../test_tir_schedule_set_axis_separator.py        |  40 +++--
 .../test_tir_schedule_tensorize_ldmatrix_mma.py    |   6 +-
 .../unittest/test_tir_schedule_transform_layout.py |  79 ++++++++--
 9 files changed, 385 insertions(+), 132 deletions(-)

diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index fdd08f9208..50f9a22ec2 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -25,7 +25,7 @@ from tvm._ffi.base import string_types
 
 from tvm.runtime import Object, convert
 from tvm.ir import container as _container
-from tvm.tir import IterVar, Buffer, Var
+from tvm.tir import IterVar, Buffer, Var, IndexMap
 
 from . import tensor as _tensor
 from . import _ffi_api
@@ -599,65 +599,12 @@ class Stage(Object):
 
         """
 
-        args = []
-        var_arg_name = None
-        kwargs = collections.OrderedDict()
-        default_index_dtype = "int32"
-
-        # Make a dummy variable for each explicitly named input index.
-        # We may have some keyword-only arguments, if the function has
-        # *args before the last argument.
-        params = inspect.signature(mapping_function).parameters
-        for name, param in params.items():
-            if param.kind in [
-                inspect.Parameter.POSITIONAL_ONLY,
-                inspect.Parameter.POSITIONAL_OR_KEYWORD,
-            ]:
-                args.append(tvm.tir.Var(name, default_index_dtype))
-
-            elif param.kind == inspect.Parameter.VAR_POSITIONAL:
-                var_arg_name = name
-
-            elif param.kind == inspect.Parameter.KEYWORD_ONLY:
-                kwargs[name] = tvm.tir.Var(name, default_index_dtype)
-
-            elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
-                raise ValueError("transform_layout mapping may not have 
**kwargs")
-
         ndim = len(self.op.output(0).shape)
+        index_map, axis_separators = 
IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
 
-        # Now that all the named arguments have been collected,
-        # everything that remains should go to the *args, if
-        # specified.
-        if var_arg_name is not None:
-            num_var_args = ndim - len(args) - len(kwargs)
-            for i in range(num_var_args):
-                args.append(tvm.tir.Var(f"{var_arg_name}[{i}]", 
default_index_dtype))
-
-        initial_indices = args + list(kwargs.values())
-        if len(initial_indices) != ndim:
-            raise ValueError(
-                f"transform_layout mapping accepts {len(params)} initial 
indices, "
-                f"but {self.op.name} is {len(self.op.shape)}-dimensional"
-            )
-
-        mapping = mapping_function(*args, **kwargs)
-
-        final_indices = []
-        axis_separators = []
-        for val in mapping:
-            if isinstance(val, tvm.ir.PrimExpr):
-                final_indices.append(val)
-            elif val is AXIS_SEPARATOR:
-                axis_separators.append(len(final_indices))
-            else:
-                raise TypeError(
-                    "Expected mapping function to return list of "
-                    "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR.  "
-                    "Instead received {val} of type {type(val)}."
-                )
-
-        new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, 
final_indices)
+        new_iter_vars = _ffi_api.StageTransformLayout(
+            self, index_map.initial_indices, index_map.final_indices
+        )
         _ffi_api.StageSetAxisSeparators(self, axis_separators)
 
         return new_iter_vars or None
@@ -700,9 +647,10 @@ class SpecializedCondition(Object):
 
 
 # Sentinel value used to indicate which groups of pre-flattening axes
-# should be used to post-flattening axes axes.  See
-# Stage.transform_layout for more details.
-AXIS_SEPARATOR = "axis_separator"
+# should be grouped into post-flattening axes.  Moved from
+# te.AXIS_SEPARATOR to tir.IndexMap.AXIS_SEPARATOR for general use,
+# maintained here for backwards compatibility.
+AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR
 
 
 tvm._ffi._init_api("schedule", __name__)
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index d84513e072..a921c5b9fc 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -16,8 +16,9 @@
 # under the License.
 """Function data types."""
 
-from typing import Callable, List, Mapping, Optional, Union, Tuple
+import collections
 import inspect
+from typing import Callable, List, Mapping, Optional, Union, Tuple
 
 import tvm
 import tvm._ffi
@@ -258,6 +259,11 @@ class IndexMap(Object):
     initial_indices: List[Var]
     final_indices: List[PrimExpr]
 
+    # Sentinel value used to indicate which groups of pre-flattening axes
+    # should be grouped into post-flattening axes.  See
+    # Stage.transform_layout for more details.
+    AXIS_SEPARATOR = "axis_separator"
+
     def __init__(self, initial_indices, final_indices):
         self.__init_handle_by_constructor__(_ffi_api.IndexMap, 
initial_indices, final_indices)
 
@@ -268,34 +274,117 @@ class IndexMap(Object):
         Parameters
         ----------
         mapping_function : Callable
-            The function to map from source indices to target indices
+
+            The function to map from source indices to target indices.
+            The function should accept `tir.Var` parameters and return
+            a list. Each element of the returned list should be a
+            `tir.PrimExpr`.
+
+        ndim: Optional[int]
+
+            The dimensionality of the buffer to which this
+            transformation should be applied.  If mapping_function uses
+            variadic argument `*args`, `ndim` must be specified.  If
+            mapping_function does not use variadic arguments, ndim is
+            optional.
+
+        Returns
+        -------
+        index_map: IndexMap
+
+            Returns an IndexMap representing the `mapping_function`.
+
+        """
+        index_map, axis_separators = 
IndexMap.from_func_with_separators(mapping_function, ndim)
+        assert not axis_separators, (
+            "The mapping_function provided to IndexMap.from_func "
+            "may not return IndexMap.AXIS_SEPARATOR.  "
+            "If required, please use IndexMap.from_func_with_separators 
instead."
+        )
+        return index_map
+
+    @staticmethod
+    def from_func_with_separators(mapping_function: Callable, ndim: 
Optional[int] = None):
+        """Create an index map from a function
+
+        Parameters
+        ----------
+        mapping_function : Callable
+
+            The function to map from source indices to target indices.
+            The function should accept tir.Var parameters and return a
+            list. Each element of the returned list should be either a
+            `tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`.
+
+        ndim: Optional[int]
+
+            The dimensionality of the buffer to which this
+            transformation should be applied.  If mapping_function uses
+            variadic argument `*args`, ndim must be specified.  If
+            mapping_function does not use variadic arguments, ndim is
+            optional.
+
+        Returns
+        -------
+        ret: Tuple[IndexMap, List[int]]
+
+            Returns a tuple whose first element is an IndexMap
+            representing the `mapping_function`, and whose second index
+            is a list of indices at which `IndexMap.AXIS_SEPARATOR`
+            occurred.
+
         """
         params = inspect.signature(mapping_function).parameters
-        default_index_dtype = "int32"
+
         args = []
         var_arg_name = None
+        kwargs = collections.OrderedDict()
+        default_index_dtype = "int32"
+
         for name, param in params.items():
             if param.kind in [
                 inspect.Parameter.POSITIONAL_ONLY,
                 inspect.Parameter.POSITIONAL_OR_KEYWORD,
             ]:
                 args.append(tvm.tir.Var(name, default_index_dtype))
+
             elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                 var_arg_name = name
+
+            elif param.kind == inspect.Parameter.KEYWORD_ONLY:
+                kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+
             else:
-                raise ValueError("transform_layout mapping may not have *args 
or **kwargs")
+                raise ValueError("transform_layout mapping may not have *args")
 
         # Now that all the named arguments have been collected,
         # everything that remains should go to the *args, if
         # specified.
         if var_arg_name is not None:
             assert ndim is not None, "ndim must be specified when *args is 
used"
-            num_var_args = ndim - len(args)
+            num_var_args = ndim - len(args) - len(kwargs)
             for i in range(num_var_args):
                 args.append(tvm.tir.Var(f"{var_arg_name}_{i}", 
default_index_dtype))
 
-        final_indices = mapping_function(*args)
-        return IndexMap(args, final_indices)
+        mapping = mapping_function(*args, **kwargs)
+
+        initial_indices = args + list(kwargs.values())
+
+        final_indices = []
+        axis_separators = []
+        for val in mapping:
+            if isinstance(val, tvm.ir.PrimExpr):
+                final_indices.append(val)
+            elif val is IndexMap.AXIS_SEPARATOR:
+                axis_separators.append(len(final_indices))
+            else:
+                raise TypeError(
+                    "Expected mapping function to return list of "
+                    "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR.  "
+                    "Instead received {val} of type {type(val)}."
+                )
+
+        return IndexMap(initial_indices, final_indices), axis_separators
 
     def is_equivalent_to(self, other_map: "IndexMap") -> bool:
         """Return if the index maps are equivalent.
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 6474ba0baa..dc687b1eae 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -15,13 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """The TensorIR schedule class"""
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union, Tuple
 
 from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
+from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer
 from ..function import IndexMap
 
 from . import _ffi_api
@@ -2114,25 +2114,111 @@ class Schedule(Object):
 
     ########## Schedule: Layout transformation ##########
 
+    def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV:
+        if isinstance(block, str):
+            return self.get_block(block)
+
+        return block
+
+    def _normalize_buffer_arg(
+        self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
+    ) -> Tuple[str, int, Buffer]:
+
+        block_name = self.get(block).name_hint
+
+        def iter_buffers():
+            block_obj = self.get(block)
+            for i, read in enumerate(block_obj.reads):
+                yield "read", i, read.buffer
+            for i, write in enumerate(block_obj.writes):
+                yield "write", i, write.buffer
+
+        if isinstance(buffer, str):
+            possible_buffers = {}
+            # String lookup requires ensuring that the name is unique
+            for buffer_index, buffer_index_type, buf in iter_buffers():
+                if buf.name == buffer:
+                    possible_buffers[buf] = (buffer_index_type, buffer_index)
+
+            assert possible_buffers, f"Could not find buffer '{buffer}' in 
block '{block_name}'"
+            assert (
+                len(possible_buffers) == 1
+            ), f"Multiple buffers named '{buffer}' in block '{block_name}'"
+            buffer_obj, (buffer_index, buffer_index_type) = 
next(iter(possible_buffers.items()))
+
+        elif isinstance(buffer, Buffer):
+            # Buffer lookup has unique id, can break out early
+            found = False
+            for buffer_index, buffer_index_type, buffer_obj in iter_buffers():
+                if buffer_obj.same_as(buffer):
+                    found = True
+                    break
+
+            assert found, "Could not find buffer '{buffer.name}' in block 
'{block_name}'"
+
+        elif isinstance(buffer, tuple):
+            buffer_index_type, buffer_index = buffer
+            assert buffer_index_type in ["read", "write",], (
+                f"Invalid buffer_index_type.  "
+                f"Expected 'read' or 'write', "
+                f"but received {buffer_index_type}"
+            )
+            buffer_list = (
+                self.get(block).reads if buffer_index_type == "read" else 
self.get(block).writes
+            )
+            assert 0 <= buffer_index < len(buffer_list), (
+                f"Invalid buffer_index {buffer_index}.  "
+                f"Block {block_name} has only "
+                f"{len(buffer_list)} {buffer_index_type} buffers."
+            )
+            buffer_obj = buffer_list[buffer_index].buffer
+
+        else:
+            raise TypeError(f"Invalid type for argument 'buffer': 
{type(buffer)}")
+
+        return (buffer_index_type, buffer_index, buffer_obj)
+
     @type_checked
     def transform_layout(
         self,
-        block: BlockRV,
-        buffer_index: int,
-        buffer_index_type: str,
+        block: Union[BlockRV, str],
+        buffer: Union[Tuple[str, int], str, Buffer],
         index_map: Union[IndexMap, Callable],
     ) -> None:
         """Apply a transformation represented by IndexMap to buffer
+
         Parameters
         ----------
-        block : BlockRV
-            The block that accesses the target buffer
-        buffer_index: int
-            The index of the buffer in block's read or write region
-        buffer_index_type : str
-            Type of the buffer index, "read" or "write"
+        block : Union[BlockRV, str]
+
+            The block that accesses the target buffer.  If a string,
+            this must uniquely identify a block.
+
+        buffer: Union[Tuple[str,int], Buffer, str]
+
+            The buffer to be transformed, or a specification of how to
+            identify the buffer to be transformed.
+
+            If `buffer` is a tuple of ``(str,int)``, the first item
+            should be either "read" or "write", and the second item is
+            an index into the block's read or write regions.
+
+            If `buffer` is a string, it is the name of the buffer,
+            which must exist within the reads/writes of the block.  In
+            addition, the reads/writes of the block may not contain
+            more than one buffer with this name.
+
+            If `buffer` is a Buffer object, it must exist within the
+            reads/writes of the block.
+
         index_map : Union[IndexMap, Callable]
-            The transformation to apply
+
+            The transformation to apply.
+
+            If `index_map` is a callable, and the returned list
+            contains IndexMap.AXIS_SEPARATOR, the SetAxisSeparators
+            primitive will be called in addition to the
+            TransformLayout primitive.
 
         Examples
         --------
@@ -2159,7 +2245,7 @@ class Schedule(Object):
         .. code-block:: python
 
             sch = tir.Schedule(before_storage_align)
-            sch.transform_layout(sch.get_block("B"), buffer_index=0, "write",
+            sch.transform_layout(sch.get_block("B"), buffer=("write",0),
                                  index_map=lambda m, n: (m // 16, n // 16, m % 
16, n % 16))
             print(sch.mod["main"].script())
 
@@ -2182,20 +2268,29 @@ class Schedule(Object):
                         C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 
1.0
 
         """
+        block = self._normalize_block_arg(block)
+        buffer_index_type, buffer_index, buffer_obj = 
self._normalize_buffer_arg(block, buffer)
+
+        ndim = len(buffer_obj.shape)
         if callable(index_map):
-            index_map = IndexMap.from_func(index_map)
-        assert buffer_index_type in ["read", "write"], "Invalid 
buffer_index_type"
+            index_map, axis_separators = 
IndexMap.from_func_with_separators(index_map, ndim=ndim)
+        else:
+            axis_separators = []
+
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
         _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: 
disable=no-member
             self, block, buffer_index, buffer_index_type_enum, index_map
         )
+        if axis_separators:
+            _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: 
disable=no-member
+                self, block, buffer_index, buffer_index_type_enum, 
axis_separators
+            )
 
     @type_checked
     def set_axis_separator(
         self,
-        block: BlockRV,
-        buffer_index: int,
-        buffer_index_type: str,
+        block: Union[BlockRV, str],
+        buffer: Union[Tuple[str, int], str, Buffer],
         axis_separators: Optional[List[int]],
     ) -> None:
         """Set the axis separator of a buffer, where the buffer is specified 
by a block and a read
@@ -2203,13 +2298,30 @@ class Schedule(Object):
 
         Parameters
         ----------
-        block : BlockRV
-            The block that accesses the target buffer
-        buffer_index: int
-            The index of the buffer in block's read or write region
-        buffer_index_type : str
-            Type of the buffer index, "read" or "write"
+        block : Union[BlockRV, str]
+
+            The block that accesses the target buffer.  If a string,
+            this must uniquely identify a block.
+
+        buffer: Union[Tuple[str,int], Buffer, str]
+
+            The buffer to be transformed, or a specification of how to
+            identify the buffer to be transformed.
+
+            If `buffer` is a tuple of ``(str,int)``, the first item
+            should be either "read" or "write", and the second item is
+            an index into the block's read or write regions.
+
+            If `buffer` is a string, it is the name of the buffer,
+            which must exist within the reads/writes of the block.  In
+            addition, the reads/writes of the block may not contain
+            more than one buffer with this name.
+
+            If `buffer` is a Buffer object, it must exist within the
+            reads/writes of the block.
+
         axis_separators : Optional[List[int]]
+
             The axis separators.
 
         Examples
@@ -2263,7 +2375,10 @@ class Schedule(Object):
                         C[vi, vj] = B[vi, vj] + T.float32(1)
         """
         axis_separators = axis_separators or []
-        assert buffer_index_type in ["read", "write"], "Invalid 
buffer_index_type"
+
+        block = self._normalize_block_arg(block)
+        buffer_index_type, buffer_index, _ = self._normalize_buffer_arg(block, 
buffer)
+
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
         _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: 
disable=no-member
             self, block, buffer_index, buffer_index_type_enum, axis_separators
diff --git a/python/tvm/tir/schedule/testing.py 
b/python/tvm/tir/schedule/testing.py
index 04cbffcd4d..3689f756e8 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Testing utilities for the TensorIR schedule API"""
-from typing import Union
+from typing import Union, Sequence
 
+import tvm
 from tvm.ir import IRModule, structural_equal
 from tvm.tir import PrimFunc
 from tvm.tir.schedule import Trace, Schedule
@@ -27,6 +28,7 @@ def verify_trace_roundtrip(
     mod: Union[PrimFunc, IRModule],
     *,
     debug_mask: Union[str, int] = "all",
+    text_format: Union[str, Sequence[str]] = ["python", "json"],
 ) -> Schedule:
     """Serialize a traced schedule to JSON, then replay the JSON trace by 
applying to
     a fresh new schedule, verifying the reproducibility of scheduling.
@@ -44,18 +46,36 @@ def verify_trace_roundtrip(
         1) "all" - Turn on all the checks
         2) "none" - Turn off all the checks
         3) An integer - Turn on checks according to the bitmasks provided in 
ScheduleDebugMask
+    text_format: Union[str, Sequence[str]]
+        The text format or formats whose round-trip behavior should be
+        validated.  If a single string, validate round-trips through
+        that format only; if a sequence of strings, validate round-trips
+        through each format in turn.
     """
-    # Step 1. Serialize the trace to JSON
+    if not isinstance(text_format, str):
+        for opt in text_format:
+            new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask, 
text_format=opt)
+        return new_sch
+
     trace = sch.trace
     assert trace is not None
-    json_obj = trace.as_json()
-    # Step 2. Apply the JSON trace to a new schedule, then check if it 
reproduces the scheduling
+
+    # Step 1. Perform a round-trip through the text-format
     new_sch = Schedule(mod=mod, debug_mask=debug_mask)
-    Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
+    if text_format == "json":
+        json_obj = trace.as_json()
+        Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
+    elif text_format == "python":
+        py_trace = "\n".join(trace.as_python())
+        exec(py_trace, tvm.tir.__dict__, {"sch": new_sch})  # pylint: 
disable=exec-used
+    else:
+        assert text_format in ("json", "python"), f"Unknown text format: 
{text_format}"
+
+    # Step 2. Verify that the round-trip produced the same scheduling
     assert structural_equal(new_sch.mod, sch.mod)
+
     # Step 3. Check the consistency of the text format between the old and new 
traces
     py_repr = "\n".join(trace.as_python())
     new_py_repr = "\n".join(new_sch.trace.as_python())
     assert py_repr == new_py_repr
+
     # Step 4. Return the new schedule in case it could be useful
     return new_sch
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc 
b/src/tir/schedule/primitive/blockize_tensorize.cc
index 331d098347..7ed80a1c5b 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -699,7 +699,7 @@ struct TensorizeTraits : public 
UnpackedInstTraits<TensorizeTraits> {
   static String UnpackedAsPython(Array<String> outputs, String 
block_or_loop_rv, String intrin) {
     PythonAPICall py("tensorize");
     py.Input("block_or_loop", block_or_loop_rv);
-    py.Input("intrin", intrin);
+    py.Input("tensor_intrin", intrin);
     return py.Str();
   }
 
diff --git a/src/tir/schedule/primitive/layout_transformation.cc 
b/src/tir/schedule/primitive/layout_transformation.cc
index fb63b1b289..cf95665ee8 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -291,11 +291,12 @@ struct TransformLayoutTraits : public 
UnpackedInstTraits<TransformLayoutTraits>
                                  Integer buffer_index_type, IndexMap 
index_map) {
     PythonAPICall py("transform_layout");
     py.Input("block", block_rv);
-    py.Input("buffer_index", buffer_index);
-    py.Input("buffer_index_type", '"' +
-                                      std::string(BufferIndexType2Str(
-                                          
static_cast<BufferIndexType>(buffer_index_type->value))) +
-                                      '"');
+
+    std::ostringstream os;
+    os << "(\"" << 
BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
+       << "\", " << buffer_index << ")";
+    py.Input("buffer", os.str());
+
     py.Input("index_map", index_map->ToPythonString());
     return py.Str();
   }
@@ -343,11 +344,12 @@ struct SetAxisSeparatorTraits : public 
UnpackedInstTraits<SetAxisSeparatorTraits
                                  Integer buffer_index_type, Array<IntImm> 
axis_separators) {
     PythonAPICall py("set_axis_separator");
     py.Input("block", block_rv);
-    py.Input("buffer_index", buffer_index);
-    py.Input("buffer_index_type", '"' +
-                                      std::string(BufferIndexType2Str(
-                                          
static_cast<BufferIndexType>(buffer_index_type->value))) +
-                                      '"');
+
+    std::ostringstream os;
+    os << "(\"" << 
BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
+       << "\", " << buffer_index << ")";
+    py.Input("buffer", os.str());
+
     py.Input("axis_separators", axis_separators);
     return py.Str();
   }
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py 
b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
index 8c3d1e6735..102b3d1cd7 100644
--- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -20,6 +20,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import tir
+from tvm.tir import IndexMap
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
 
@@ -102,11 +103,19 @@ def element_wise_subregion_match_set_axis_separator(A: 
T.Buffer[(128, 128), "flo
 
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
+use_sugared_transform = tvm.testing.parameter(
+    by_dict={"set_axis_separators": False, "transform_layout_sugared": True}
+)
 
-def test_set_axis_separator():
+def test_set_axis_separator(use_sugared_transform):
     func = element_wise
     s = tir.Schedule(func, debug_mask='all')
-    s.set_axis_separator(s.get_block("B"), 0, "write", [1])
+
+    if use_sugared_transform:
+        s.set_axis_separator(s.get_block("B"), ("write",0), [1])
+    else:
+        s.transform_layout(block='B', buffer='B', index_map=lambda i,j: 
[i,IndexMap.AXIS_SEPARATOR,j])
+
     tvm.ir.assert_structural_equal(element_wise_set_axis_separator, 
s.mod["main"])
     verify_trace_roundtrip(sch=s, mod=func)
 
@@ -114,24 +123,35 @@ def test_set_axis_separator():
 def test_set_scope_fail_on_index_out_of_bound():
     func = element_wise
     s = tir.Schedule(func, debug_mask='all')
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.set_axis_separator(s.get_block("B"), 1, "write",[1])
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.set_axis_separator(s.get_block("B"), -1, "read",[1])
+    with pytest.raises(AssertionError):
+        s.set_axis_separator(s.get_block("B"), ("write",1),[1])
+    with pytest.raises(AssertionError):
+        s.set_axis_separator(s.get_block("B"), ("read",-1),[1])
 
 
-def test_set_axis_separator_input_buffer():
+def test_set_axis_separator_input_buffer(use_sugared_transform):
     func = element_wise
     s = tir.Schedule(func, debug_mask='all')
-    s.set_axis_separator(s.get_block("B"), 0, "read", [1])
+
+    if use_sugared_transform:
+        s.transform_layout(block='B', buffer='A', index_map=lambda i,j: 
[i,IndexMap.AXIS_SEPARATOR,j])
+    else:
+        s.set_axis_separator(s.get_block("B"), ("read",0), [1])
+
+
     
tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer, 
s.mod["main"])
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def test_set_axis_separator_subregion():
+def test_set_axis_separator_subregion(use_sugared_transform):
     func = element_wise_subregion_match
     s = tir.Schedule(func, debug_mask='all')
-    s.set_axis_separator(s.get_block("B"), 0, "write", [1])
+
+    if use_sugared_transform:
+        s.transform_layout(block='B', buffer='B', index_map=lambda i,j: 
[i,IndexMap.AXIS_SEPARATOR,j])
+    else:
+        s.set_axis_separator(s.get_block("B"), ("write",0), [1])
+
     
tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, 
s.mod["main"])
     verify_trace_roundtrip(sch=s, mod=func)
 
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py 
b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index 67e8ae0ad8..e9ee990a24 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -177,9 +177,9 @@ def run_test(
     else:
         loop_b = tile_wmma_fragment(B_warp, k_inner, 16)
 
-    sch.transform_layout(A_warp, 0, "write", index_map_A)
-    sch.transform_layout(B_warp, 0, "write", index_map_B)
-    sch.transform_layout(C_warp, 0, "read", index_map_C)
+    sch.transform_layout(A_warp, ("write", 0), index_map_A)
+    sch.transform_layout(B_warp, ("write", 0), index_map_B)
+    sch.transform_layout(C_warp, ("read", 0), index_map_C)
 
     sch.tensorize(loop_a, ldmatrix_a_intrin)
     sch.tensorize(loop_b, ldmatrix_b_intrin)
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py 
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 9e7cad4d85..699eaf1236 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -94,27 +94,58 @@ def two_elementwise_transformed_output_buffer(
 # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
 # fmt: on
 
+use_sugared_transform = tvm.testing.parameter(
+    by_dict={"transform_layout": False, "transform_layout_sugared": True}
+)
 
-def test_two_elementwise_transform_intermediate_buffer():
+
+def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform):
     sch = tir.Schedule(two_elementwise, debug_mask="all")
-    block = sch.get_block("B")
-    sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m % 16, n % 16))
+
+    if use_sugared_transform:
+        sch.transform_layout(
+            block="B",
+            buffer="B",
+            index_map=packed_index_map_func,
+        )
+    else:
+        block = sch.get_block("B")
+        sch.transform_layout(block, ("write", 0), packed_index_map_func)
+
     
tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer, 
sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=two_elementwise)
 
 
-def test_two_elementwise_transform_input_buffer():
+def test_two_elementwise_transform_input_buffer(use_sugared_transform):
     sch = tir.Schedule(two_elementwise, debug_mask="all")
-    block = sch.get_block("B")
-    sch.transform_layout(block, 0, "read", packed_index_map_func)
+
+    if use_sugared_transform:
+        sch.transform_layout(
+            index_map=packed_index_map_func,
+            block="B",
+            buffer="A",
+        )
+    else:
+        block = sch.get_block("B")
+        sch.transform_layout(block, ("read", 0), packed_index_map_func)
+
     tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=two_elementwise)
 
 
-def test_two_elementwise_transform_output_buffer():
+def test_two_elementwise_transform_output_buffer(use_sugared_transform):
     sch = tir.Schedule(two_elementwise, debug_mask="all")
-    block = sch.get_block("C")
-    sch.transform_layout(block, 0, "write", packed_index_map_func)
+
+    if use_sugared_transform:
+        sch.transform_layout(
+            index_map=packed_index_map_func,
+            block="C",
+            buffer="C",
+        )
+    else:
+        block = sch.get_block("C")
+        sch.transform_layout(block, ("write", 0), packed_index_map_func)
+
     tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer, sch.mod["main"])
     verify_trace_roundtrip(sch=sch, mod=two_elementwise)
 
@@ -136,7 +167,7 @@ def test_simplify():
     block_outer = sch.blockize(i_inner)
 
     B = sch.cache_read(block_outer, 0, "global")
-    sch.transform_layout(B, 0, "write", lambda i, j: (i // 16, j // 16, i % 16, j % 16))
+    sch.transform_layout(B, ("write", 0), lambda i, j: (i // 16, j // 16, i % 16, j % 16))
 
     @T.prim_func
     def ref(B: T.Buffer[(8, 8, 16, 16), "float32"], C: T.Buffer[(128, 128), "float32"]):
@@ -159,5 +190,33 @@ def test_simplify():
     tvm.ir.assert_structural_equal(ref.body.block.body, sch.get(sch.get_loops(block_outer)[0]))
 
 
+def test_var_args_sugar():
+    @T.prim_func
+    def summation_3d(
+        A: T.Buffer[(1024, 1024, 32), "float32"], B: T.Buffer[(1,), "float32"]
+    ) -> None:
+        B[0] = 0
+        for i, j, k in T.grid(1024, 1024, 32):
+            with T.block("compute"):
+                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                B[0] = B[0] + A[vi, vj, vk]
+
+    @T.prim_func
+    def summation_3d_split(
+        A: T.Buffer[(1024, 1024, 8, 4), "float32"], B: T.Buffer[(1,), "float32"]
+    ) -> None:
+        B[0] = 0
+        for i, j, k in T.grid(1024, 1024, 32):
+            with T.block("compute"):
+                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                B[0] = B[0] + A[vi, vj, vk // 4, vk % 4]
+
+    sch = tir.Schedule(summation_3d, debug_mask="all")
+    sch.transform_layout(
+        index_map=lambda *indices, k: [*indices, k // 4, k % 4], block="compute", buffer="A"
+    )
+    tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to