This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 814f5501bf [TIR][Schedule] Transform layout quality of life (#11269)
814f5501bf is described below
commit 814f5501bf7d65f759135d214572388b0ddadefc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed May 25 19:19:33 2022 -0500
[TIR][Schedule] Transform layout quality of life (#11269)
* [TIR][Schedule] Added Schedule.transform_layout_sugared
* [TE][TIR] Reduced duplication in TE/TIR layout transformations
Previously, the implementations of `tir.IndexMap.from_func` and
`te.Stage.transform_layout` had significant duplication to handle
argument parsing. This commit extracts the shared logic into
`tir.IndexMap`.
* Enabled *args in Schedule.transform_layout_sugared
* Fix lint error
* Allow Schedule.transform_layout_sugared to set axis separators
* Merged transform_layout_sugared functionality into transform_layout
* Fix lint errors
* Fix lint error
* Fixed docstring errors
* Updated/tested TransformLayoutTraits::UnpackedAsPython
* Disabled exec-used check for running trace.as_python()
* Updated SetAxisSeparatorTraits::UnpackedAsPython
* Updated unit test that was added in merge commit
* Fixed the argument name for TensorizeTraits
This wasn't checked before, but was the only other issue caught by the
updates to verify_trace_roundtrip.
* Re-enable type checks of transform_layout/set_axis_separator
Disabled while waiting for https://github.com/apache/tvm/pull/11289,
which was required for the `Tuple` argument.
* Updated a few additional transform_layout usages from main
---
python/tvm/te/schedule.py | 70 ++-------
python/tvm/tir/function.py | 103 ++++++++++++-
python/tvm/tir/schedule/schedule.py | 165 +++++++++++++++++----
python/tvm/tir/schedule/testing.py | 30 +++-
src/tir/schedule/primitive/blockize_tensorize.cc | 2 +-
.../schedule/primitive/layout_transformation.cc | 22 +--
.../test_tir_schedule_set_axis_separator.py | 40 +++--
.../test_tir_schedule_tensorize_ldmatrix_mma.py | 6 +-
.../unittest/test_tir_schedule_transform_layout.py | 79 ++++++++--
9 files changed, 385 insertions(+), 132 deletions(-)
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index fdd08f9208..50f9a22ec2 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -25,7 +25,7 @@ from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import container as _container
-from tvm.tir import IterVar, Buffer, Var
+from tvm.tir import IterVar, Buffer, Var, IndexMap
from . import tensor as _tensor
from . import _ffi_api
@@ -599,65 +599,12 @@ class Stage(Object):
"""
- args = []
- var_arg_name = None
- kwargs = collections.OrderedDict()
- default_index_dtype = "int32"
-
- # Make a dummy variable for each explicitly named input index.
- # We may have some keyword-only arguments, if the function has
- # *args before the last argument.
- params = inspect.signature(mapping_function).parameters
- for name, param in params.items():
- if param.kind in [
- inspect.Parameter.POSITIONAL_ONLY,
- inspect.Parameter.POSITIONAL_OR_KEYWORD,
- ]:
- args.append(tvm.tir.Var(name, default_index_dtype))
-
- elif param.kind == inspect.Parameter.VAR_POSITIONAL:
- var_arg_name = name
-
- elif param.kind == inspect.Parameter.KEYWORD_ONLY:
- kwargs[name] = tvm.tir.Var(name, default_index_dtype)
-
- elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
- raise ValueError("transform_layout mapping may not have
**kwargs")
-
ndim = len(self.op.output(0).shape)
+ index_map, axis_separators =
IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
- # Now that all the named arguments have been collected,
- # everything that remains should go to the *args, if
- # specified.
- if var_arg_name is not None:
- num_var_args = ndim - len(args) - len(kwargs)
- for i in range(num_var_args):
- args.append(tvm.tir.Var(f"{var_arg_name}[{i}]",
default_index_dtype))
-
- initial_indices = args + list(kwargs.values())
- if len(initial_indices) != ndim:
- raise ValueError(
- f"transform_layout mapping accepts {len(params)} initial
indices, "
- f"but {self.op.name} is {len(self.op.shape)}-dimensional"
- )
-
- mapping = mapping_function(*args, **kwargs)
-
- final_indices = []
- axis_separators = []
- for val in mapping:
- if isinstance(val, tvm.ir.PrimExpr):
- final_indices.append(val)
- elif val is AXIS_SEPARATOR:
- axis_separators.append(len(final_indices))
- else:
- raise TypeError(
- "Expected mapping function to return list of "
- "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. "
- "Instead received {val} of type {type(val)}."
- )
-
- new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices,
final_indices)
+ new_iter_vars = _ffi_api.StageTransformLayout(
+ self, index_map.initial_indices, index_map.final_indices
+ )
_ffi_api.StageSetAxisSeparators(self, axis_separators)
return new_iter_vars or None
@@ -700,9 +647,10 @@ class SpecializedCondition(Object):
# Sentinel value used to indicate which groups of pre-flattening axes
-# should be used to post-flattening axes axes. See
-# Stage.transform_layout for more details.
-AXIS_SEPARATOR = "axis_separator"
+# are flattened together into each post-flattening axis. Moved from
+# te.AXIS_SEPARATOR to tir.IndexMap.AXIS_SEPARATOR for general use,
+# maintained here for backwards compatibility.
+AXIS_SEPARATOR = IndexMap.AXIS_SEPARATOR
tvm._ffi._init_api("schedule", __name__)
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index d84513e072..a921c5b9fc 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -16,8 +16,9 @@
# under the License.
"""Function data types."""
-from typing import Callable, List, Mapping, Optional, Union, Tuple
+import collections
import inspect
+from typing import Callable, List, Mapping, Optional, Union, Tuple
import tvm
import tvm._ffi
@@ -258,6 +259,11 @@ class IndexMap(Object):
initial_indices: List[Var]
final_indices: List[PrimExpr]
+ # Sentinel value used to indicate which groups of pre-flattening axes
+ # are flattened together into each post-flattening axis. See
+ # IndexMap.from_func_with_separators and Stage.transform_layout.
+ AXIS_SEPARATOR = "axis_separator"
+
def __init__(self, initial_indices, final_indices):
self.__init_handle_by_constructor__(_ffi_api.IndexMap,
initial_indices, final_indices)
@@ -268,34 +274,117 @@ class IndexMap(Object):
Parameters
----------
mapping_function : Callable
- The function to map from source indices to target indices
+
+ The function to map from source indices to target indices.
+ The function should accept `tir.Var` parameters and return
+ a list. Each element of the returned list should be a
+ `tir.PrimExpr`.
+
+ ndim: Optional[int]
+
+ The dimensionality of the buffer to which this
+ transformation should be applied. If mapping_function uses
+ variadic argument `*args`, `ndim` must be specified. If
+ mapping_function does not use variadic arguments, ndim is
+ optional.
+
+ Returns
+ -------
+ index_map: IndexMap
+
+ Returns an IndexMap representing the `mapping_function`.
+
+ """
+ index_map, axis_separators =
IndexMap.from_func_with_separators(mapping_function, ndim)
+ assert not axis_separators, (
+ "The mapping_function provided to IndexMap.from_func "
+ "may not return IndexMap.AXIS_SEPARATOR. "
+ "If required, please use IndexMap.from_func_with_separators
instead."
+ )
+ return index_map
+
+ @staticmethod
+ def from_func_with_separators(mapping_function: Callable, ndim:
Optional[int] = None):
+ """Create an index map from a function
+
+ Parameters
+ ----------
+ mapping_function : Callable
+
+ The function to map from source indices to target indices.
+ The function should accept tir.Var parameters and return a
+ list. Each element of the returned list should be either a
+ `tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`.
+
+ ndim: Optional[int]
+
+ The dimensionality of the buffer to which this
+ transformation should be applied. If mapping_function uses
+ variadic argument `*args`, ndim must be specified. If
+ mapping_function does not use variadic arguments, ndim is
+ optional. If given, it must match the number of indices mapping_function accepts (ValueError otherwise).
+
+ Returns
+ -------
+ ret: Tuple[IndexMap, List[int]]
+
+ Returns a tuple whose first element is an IndexMap
+ representing the `mapping_function`, and whose second index
+ is a list of indices at which `IndexMap.AXIS_SEPARATOR`
+ occurred.
+
"""
params = inspect.signature(mapping_function).parameters
- default_index_dtype = "int32"
+
args = []
var_arg_name = None
+ kwargs = collections.OrderedDict()
+ default_index_dtype = "int32"
+
for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
args.append(tvm.tir.Var(name, default_index_dtype))
+
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name
+
+ elif param.kind == inspect.Parameter.KEYWORD_ONLY:
+ kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+
else:
- raise ValueError("transform_layout mapping may not have *args
or **kwargs")
+ raise ValueError("transform_layout mapping may not have **kwargs")
# Now that all the named arguments have been collected,
# everything that remains should go to the *args, if
# specified.
if var_arg_name is not None:
assert ndim is not None, "ndim must be specified when *args is
used"
- num_var_args = ndim - len(args)
+ num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
args.append(tvm.tir.Var(f"{var_arg_name}_{i}",
default_index_dtype))
- final_indices = mapping_function(*args)
- return IndexMap(args, final_indices)
+ initial_indices = args + list(kwargs.values())
+ if ndim is not None and len(initial_indices) != ndim:
+ raise ValueError(f"transform_layout mapping accepts {len(initial_indices)} initial indices, but ndim is {ndim}")
+ mapping = mapping_function(*args, **kwargs)
+ final_indices = []
+ axis_separators = []
+ for val in mapping:
+ if isinstance(val, tvm.ir.PrimExpr):
+ final_indices.append(val)
+ elif val is IndexMap.AXIS_SEPARATOR:
+ axis_separators.append(len(final_indices))
+ else:
+ raise TypeError(
+ "Expected mapping function to return list of "
+ "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
+ f"Instead received {val} of type {type(val)}."
+ )
+
+ return IndexMap(initial_indices, final_indices), axis_separators
def is_equivalent_to(self, other_map: "IndexMap") -> bool:
"""Return if the index maps are equivalent.
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 6474ba0baa..dc687b1eae 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""The TensorIR schedule class"""
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union, Tuple
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
-from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc
+from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer
from ..function import IndexMap
from . import _ffi_api
@@ -2114,25 +2114,111 @@ class Schedule(Object):
########## Schedule: Layout transformation ##########
+ def _normalize_block_arg(self, block: Union[BlockRV, str]) -> BlockRV:
+ if isinstance(block, str):
+ return self.get_block(block)
+
+ return block
+
+ def _normalize_buffer_arg(
+ self, block: BlockRV, buffer: Union[Tuple[str, int], str, Buffer]
+ ) -> Tuple[str, int, Buffer]:
+ """Resolve `buffer` to ``(buffer_index_type, buffer_index, Buffer)`` as accessed by `block`."""
+ block_name = self.get(block).name_hint
+ # Yields (buffer_index_type, buffer_index, Buffer): every read region, then every write region.
+ def iter_buffers():
+ block_obj = self.get(block)
+ for i, read in enumerate(block_obj.reads):
+ yield "read", i, read.buffer
+ for i, write in enumerate(block_obj.writes):
+ yield "write", i, write.buffer
+
+ if isinstance(buffer, str):
+ possible_buffers = {}
+ # Name must be unique; keyed by Buffer, so a buffer seen in several regions counts once (last region wins).
+ for buffer_index_type, buffer_index, buf in iter_buffers():
+ if buf.name == buffer:
+ possible_buffers[buf] = (buffer_index_type, buffer_index)
+
+ assert possible_buffers, f"Could not find buffer '{buffer}' in
block '{block_name}'"
+ assert (
+ len(possible_buffers) == 1
+ ), f"Multiple buffers named '{buffer}' in block '{block_name}'"
+ buffer_obj, (buffer_index_type, buffer_index) =
next(iter(possible_buffers.items()))
+
+ elif isinstance(buffer, Buffer):
+ # Buffer lookup has unique id, can break out early
+ found = False
+ for buffer_index_type, buffer_index, buffer_obj in iter_buffers():
+ if buffer_obj.same_as(buffer):
+ found = True
+ break
+
+ assert found, f"Could not find buffer '{buffer.name}' in block
'{block_name}'"
+
+ elif isinstance(buffer, tuple):
+ buffer_index_type, buffer_index = buffer
+ assert buffer_index_type in ["read", "write",], (
+ f"Invalid buffer_index_type. "
+ f"Expected 'read' or 'write', "
+ f"but received {buffer_index_type}"
+ )
+ buffer_list = (
+ self.get(block).reads if buffer_index_type == "read" else
self.get(block).writes
+ )
+ assert 0 <= buffer_index < len(buffer_list), (
+ f"Invalid buffer_index {buffer_index}. "
+ f"Block {block_name} has only "
+ f"{len(buffer_list)} {buffer_index_type} buffers."
+ )
+ buffer_obj = buffer_list[buffer_index].buffer
+
+ else:
+ raise TypeError(f"Invalid type for argument 'buffer':
{type(buffer)}")
+
+ return (buffer_index_type, buffer_index, buffer_obj)
+
@type_checked
def transform_layout(
self,
- block: BlockRV,
- buffer_index: int,
- buffer_index_type: str,
+ block: Union[BlockRV, str],
+ buffer: Union[Tuple[str, int], str, Buffer],
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to buffer
+
Parameters
----------
- block : BlockRV
- The block that accesses the target buffer
- buffer_index: int
- The index of the buffer in block's read or write region
- buffer_index_type : str
- Type of the buffer index, "read" or "write"
+ block : Union[BlockRV, str]
+
+ The block that accesses the target buffer. If a string,
+ this must uniquely identify a block.
+
+ buffer: Union[Tuple[str,int], Buffer, str]
+
+ The buffer to be transformed, or a specification of how to
+ identify the buffer to be transformed.
+
+ If `buffer` is a tuple of ``(str,int)``, the first item
+ should be either "read" or "write", and the second item is
+ an index into the block's read or write regions.
+
+ If `buffer` is a string, it is the name of the buffer,
+ which must exist within the reads/writes of the block. In
+ addition, the reads/writes of the block may not contain
+ more than one buffer with this name.
+
+ If `buffer` is a Buffer object, it must exist within the
+ reads/writes of the block.
+
index_map : Union[IndexMap, Callable]
- The transformation to apply
+
+ The transformation to apply.
+
+ If `index_map` is a callable, and the returned list
+ contains IndexMap.AXIS_SEPARATOR, the SetAxisSeparators
+ primitive will be called in addition to the
+ TransformLayout primitive.
Examples
--------
@@ -2159,7 +2245,7 @@ class Schedule(Object):
.. code-block:: python
sch = tir.Schedule(before_storage_align)
- sch.transform_layout(sch.get_block("B"), buffer_index=0, "write",
+ sch.transform_layout(sch.get_block("B"), buffer=("write",0),
index_map=lambda m, n: (m // 16, n // 16, m %
16, n % 16))
print(sch.mod["main"].script())
@@ -2182,20 +2268,29 @@ class Schedule(Object):
C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] +
1.0
"""
+ block = self._normalize_block_arg(block)
+ buffer_index_type, buffer_index, buffer_obj =
self._normalize_buffer_arg(block, buffer)
+
+ ndim = len(buffer_obj.shape)
if callable(index_map):
- index_map = IndexMap.from_func(index_map)
- assert buffer_index_type in ["read", "write"], "Invalid
buffer_index_type"
+ index_map, axis_separators =
IndexMap.from_func_with_separators(index_map, ndim=ndim)
+ else:
+ axis_separators = []
+
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint:
disable=no-member
self, block, buffer_index, buffer_index_type_enum, index_map
)
+ if axis_separators:
+ _ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint:
disable=no-member
+ self, block, buffer_index, buffer_index_type_enum,
axis_separators
+ )
@type_checked
def set_axis_separator(
self,
- block: BlockRV,
- buffer_index: int,
- buffer_index_type: str,
+ block: Union[BlockRV, str],
+ buffer: Union[Tuple[str, int], str, Buffer],
axis_separators: Optional[List[int]],
) -> None:
"""Set the axis separator of a buffer, where the buffer is specified
by a block and a read
@@ -2203,13 +2298,30 @@ class Schedule(Object):
Parameters
----------
- block : BlockRV
- The block that accesses the target buffer
- buffer_index: int
- The index of the buffer in block's read or write region
- buffer_index_type : str
- Type of the buffer index, "read" or "write"
+ block : Union[BlockRV, str]
+
+ The block that accesses the target buffer. If a string,
+ this must uniquely identify a block.
+
+ buffer: Union[Tuple[str,int], Buffer, str]
+
+ The buffer whose axis separators should be set, or a
+ specification of how to identify that buffer.
+
+ If `buffer` is a tuple of ``(str,int)``, the first item
+ should be either "read" or "write", and the second item is
+ an index into the block's read or write regions.
+
+ If `buffer` is a string, it is the name of the buffer,
+ which must exist within the reads/writes of the block. In
+ addition, the reads/writes of the block may not contain
+ more than one buffer with this name.
+
+ If `buffer` is a Buffer object, it must exist within the
+ reads/writes of the block.
+
axis_separators : Optional[List[int]]
+
The axis separators.
Examples
@@ -2263,7 +2375,10 @@ class Schedule(Object):
C[vi, vj] = B[vi, vj] + T.float32(1)
"""
axis_separators = axis_separators or []
- assert buffer_index_type in ["read", "write"], "Invalid
buffer_index_type"
+
+ block = self._normalize_block_arg(block)
+ buffer_index_type, buffer_index, _ = self._normalize_buffer_arg(block,
buffer)
+
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
_ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint:
disable=no-member
self, block, buffer_index, buffer_index_type_enum, axis_separators
diff --git a/python/tvm/tir/schedule/testing.py
b/python/tvm/tir/schedule/testing.py
index 04cbffcd4d..3689f756e8 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities for the TensorIR schedule API"""
-from typing import Union
+from typing import Union, Sequence
+import tvm
from tvm.ir import IRModule, structural_equal
from tvm.tir import PrimFunc
from tvm.tir.schedule import Trace, Schedule
@@ -27,6 +28,7 @@ def verify_trace_roundtrip(
mod: Union[PrimFunc, IRModule],
*,
debug_mask: Union[str, int] = "all",
+ text_format: Union[str, Sequence[str]] = ("python", "json"),
) -> Schedule:
"""Serialize a traced schedule to JSON, then replay the JSON trace by
applying to
a fresh new schedule, verifying the reproducibility of scheduling.
@@ -44,18 +46,36 @@ def verify_trace_roundtrip(
1) "all" - Turn on all the checks
2) "none" - Turn off all the checks
3) An integer - Turn on checks according to the bitmasks provided in
ScheduleDebugMask
+ text_format: Union[str, Sequence[str]]
+ The text format or formats whose round-trip behavior should be
+ validated. If a single string ("json" or "python"), validate round-trips through that format only; otherwise validate each listed format in turn and return the schedule from the last one.
"""
- # Step 1. Serialize the trace to JSON
+ if not isinstance(text_format, str):
+ for opt in text_format:
+ new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask,
text_format=opt)
+ return new_sch
+
trace = sch.trace
assert trace is not None
- json_obj = trace.as_json()
- # Step 2. Apply the JSON trace to a new schedule, then check if it
reproduces the scheduling
+
+ # Step 1. Perform a round-trip through the text-format
new_sch = Schedule(mod=mod, debug_mask=debug_mask)
- Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
+ if text_format == "json":
+ json_obj = trace.as_json()
+ Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
+ elif text_format == "python":
+ py_trace = "\n".join(trace.as_python())
+ exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) # pylint:
disable=exec-used
+ else:
+ raise ValueError(f"Unknown text format: {text_format}")
+
+ # Step 2. Verify that the round-trip produced the same scheduling
assert structural_equal(new_sch.mod, sch.mod)
+
# Step 3. Check the consistency of the text format between the old and new
traces
py_repr = "\n".join(trace.as_python())
new_py_repr = "\n".join(new_sch.trace.as_python())
assert py_repr == new_py_repr
+
# Step 4. Return the new schedule in case it could be useful
return new_sch
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc
b/src/tir/schedule/primitive/blockize_tensorize.cc
index 331d098347..7ed80a1c5b 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -699,7 +699,7 @@ struct TensorizeTraits : public
UnpackedInstTraits<TensorizeTraits> {
static String UnpackedAsPython(Array<String> outputs, String
block_or_loop_rv, String intrin) {
PythonAPICall py("tensorize");
py.Input("block_or_loop", block_or_loop_rv);
- py.Input("intrin", intrin);
+ py.Input("tensor_intrin", intrin);
return py.Str();
}
diff --git a/src/tir/schedule/primitive/layout_transformation.cc
b/src/tir/schedule/primitive/layout_transformation.cc
index fb63b1b289..cf95665ee8 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -291,11 +291,12 @@ struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits>
Integer buffer_index_type, IndexMap
index_map) {
PythonAPICall py("transform_layout");
py.Input("block", block_rv);
- py.Input("buffer_index", buffer_index);
- py.Input("buffer_index_type", '"' +
- std::string(BufferIndexType2Str(
-
static_cast<BufferIndexType>(buffer_index_type->value))) +
- '"');
+
+ std::ostringstream os;
+ os << "(\"" <<
BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
+ << "\", " << buffer_index << ")";
+ py.Input("buffer", os.str());
+
py.Input("index_map", index_map->ToPythonString());
return py.Str();
}
@@ -343,11 +344,12 @@ struct SetAxisSeparatorTraits : public
UnpackedInstTraits<SetAxisSeparatorTraits
Integer buffer_index_type, Array<IntImm>
axis_separators) {
PythonAPICall py("set_axis_separator");
py.Input("block", block_rv);
- py.Input("buffer_index", buffer_index);
- py.Input("buffer_index_type", '"' +
- std::string(BufferIndexType2Str(
-
static_cast<BufferIndexType>(buffer_index_type->value))) +
- '"');
+
+ std::ostringstream os;
+ os << "(\"" <<
BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
+ << "\", " << buffer_index << ")";
+ py.Input("buffer", os.str());
+
py.Input("axis_separators", axis_separators);
return py.Str();
}
diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
index 8c3d1e6735..102b3d1cd7 100644
--- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py
+++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py
@@ -20,6 +20,7 @@ import pytest
import tvm
import tvm.testing
from tvm import tir
+from tvm.tir import IndexMap
from tvm.script import tir as T
from tvm.tir.schedule.testing import verify_trace_roundtrip
@@ -102,11 +103,19 @@ def element_wise_subregion_match_set_axis_separator(A:
T.Buffer[(128, 128), "flo
# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+use_sugared_transform = tvm.testing.parameter(
+ by_dict={"set_axis_separators": False, "transform_layout_sugared": True}
+)
-def test_set_axis_separator():
+def test_set_axis_separator(use_sugared_transform):
func = element_wise
s = tir.Schedule(func, debug_mask='all')
- s.set_axis_separator(s.get_block("B"), 0, "write", [1])
+
+ if use_sugared_transform:
+ s.transform_layout(block='B', buffer='B', index_map=lambda i,j:
[i,IndexMap.AXIS_SEPARATOR,j])
+ else:
+ s.set_axis_separator(s.get_block("B"), ("write",0), [1])
+
tvm.ir.assert_structural_equal(element_wise_set_axis_separator,
s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
@@ -114,24 +123,35 @@ def test_set_axis_separator():
def test_set_scope_fail_on_index_out_of_bound():
func = element_wise
s = tir.Schedule(func, debug_mask='all')
- with pytest.raises(tvm.tir.ScheduleError):
- s.set_axis_separator(s.get_block("B"), 1, "write",[1])
- with pytest.raises(tvm.tir.ScheduleError):
- s.set_axis_separator(s.get_block("B"), -1, "read",[1])
+ with pytest.raises(AssertionError):
+ s.set_axis_separator(s.get_block("B"), ("write",1),[1])
+ with pytest.raises(AssertionError):
+ s.set_axis_separator(s.get_block("B"), ("read",-1),[1])
-def test_set_axis_separator_input_buffer():
+def test_set_axis_separator_input_buffer(use_sugared_transform):
func = element_wise
s = tir.Schedule(func, debug_mask='all')
- s.set_axis_separator(s.get_block("B"), 0, "read", [1])
+
+ if use_sugared_transform:
+ s.transform_layout(block='B', buffer='A', index_map=lambda i,j:
[i,IndexMap.AXIS_SEPARATOR,j])
+ else:
+ s.set_axis_separator(s.get_block("B"), ("read",0), [1])
+
+
tvm.ir.assert_structural_equal(element_wise_set_axis_separator_input_buffer,
s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
-def test_set_axis_separator_subregion():
+def test_set_axis_separator_subregion(use_sugared_transform):
func = element_wise_subregion_match
s = tir.Schedule(func, debug_mask='all')
- s.set_axis_separator(s.get_block("B"), 0, "write", [1])
+
+ if use_sugared_transform:
+ s.transform_layout(block='B', buffer='B', index_map=lambda i,j:
[i,IndexMap.AXIS_SEPARATOR,j])
+ else:
+ s.set_axis_separator(s.get_block("B"), ("write",0), [1])
+
tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator,
s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index 67e8ae0ad8..e9ee990a24 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -177,9 +177,9 @@ def run_test(
else:
loop_b = tile_wmma_fragment(B_warp, k_inner, 16)
- sch.transform_layout(A_warp, 0, "write", index_map_A)
- sch.transform_layout(B_warp, 0, "write", index_map_B)
- sch.transform_layout(C_warp, 0, "read", index_map_C)
+ sch.transform_layout(A_warp, ("write", 0), index_map_A)
+ sch.transform_layout(B_warp, ("write", 0), index_map_B)
+ sch.transform_layout(C_warp, ("read", 0), index_map_C)
sch.tensorize(loop_a, ldmatrix_a_intrin)
sch.tensorize(loop_b, ldmatrix_b_intrin)
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 9e7cad4d85..699eaf1236 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -94,27 +94,58 @@ def two_elementwise_transformed_output_buffer(
# pylint:
enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
# fmt: on
+use_sugared_transform = tvm.testing.parameter(
+ by_dict={"transform_layout": False, "transform_layout_sugared": True}
+)
-def test_two_elementwise_transform_intermediate_buffer():
+
+def test_two_elementwise_transform_intermediate_buffer(use_sugared_transform):
sch = tir.Schedule(two_elementwise, debug_mask="all")
- block = sch.get_block("B")
- sch.transform_layout(block, 0, "write", lambda m, n: (m // 16, n // 16, m
% 16, n % 16))
+
+ if use_sugared_transform:
+ sch.transform_layout(
+ block="B",
+ buffer="B",
+ index_map=packed_index_map_func,
+ )
+ else:
+ block = sch.get_block("B")
+ sch.transform_layout(block, ("write", 0), packed_index_map_func)
+
tvm.ir.assert_structural_equal(two_elementwise_transformed_intermediate_buffer,
sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=two_elementwise)
-def test_two_elementwise_transform_input_buffer():
+def test_two_elementwise_transform_input_buffer(use_sugared_transform):
sch = tir.Schedule(two_elementwise, debug_mask="all")
- block = sch.get_block("B")
- sch.transform_layout(block, 0, "read", packed_index_map_func)
+
+ if use_sugared_transform:
+ sch.transform_layout(
+ index_map=packed_index_map_func,
+ block="B",
+ buffer="A",
+ )
+ else:
+ block = sch.get_block("B")
+ sch.transform_layout(block, ("read", 0), packed_index_map_func)
+
tvm.ir.assert_structural_equal(two_elementwise_transformed_input_buffer,
sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=two_elementwise)
-def test_two_elementwise_transform_output_buffer():
+def test_two_elementwise_transform_output_buffer(use_sugared_transform):
sch = tir.Schedule(two_elementwise, debug_mask="all")
- block = sch.get_block("C")
- sch.transform_layout(block, 0, "write", packed_index_map_func)
+
+ if use_sugared_transform:
+ sch.transform_layout(
+ index_map=packed_index_map_func,
+ block="C",
+ buffer="C",
+ )
+ else:
+ block = sch.get_block("C")
+ sch.transform_layout(block, ("write", 0), packed_index_map_func)
+
tvm.ir.assert_structural_equal(two_elementwise_transformed_output_buffer,
sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=two_elementwise)
@@ -136,7 +167,7 @@ def test_simplify():
block_outer = sch.blockize(i_inner)
B = sch.cache_read(block_outer, 0, "global")
- sch.transform_layout(B, 0, "write", lambda i, j: (i // 16, j // 16, i %
16, j % 16))
+ sch.transform_layout(B, ("write", 0), lambda i, j: (i // 16, j // 16, i %
16, j % 16))
@T.prim_func
def ref(B: T.Buffer[(8, 8, 16, 16), "float32"], C: T.Buffer[(128, 128),
"float32"]):
@@ -159,5 +190,33 @@ def test_simplify():
tvm.ir.assert_structural_equal(ref.body.block.body,
sch.get(sch.get_loops(block_outer)[0]))
+def test_var_args_sugar():
+ @T.prim_func
+ def summation_3d(
+ A: T.Buffer[(1024, 1024, 32), "float32"], B: T.Buffer[(1,), "float32"]
+ ) -> None:
+ B[0] = 0
+ for i, j, k in T.grid(1024, 1024, 32):
+ with T.block("compute"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ B[0] = B[0] + A[vi, vj, vk]
+
+ @T.prim_func
+ def summation_3d_split(
+ A: T.Buffer[(1024, 1024, 8, 4), "float32"], B: T.Buffer[(1,),
"float32"]
+ ) -> None:
+ B[0] = 0
+ for i, j, k in T.grid(1024, 1024, 32):
+ with T.block("compute"):
+ vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+ B[0] = B[0] + A[vi, vj, vk // 4, vk % 4]
+
+ sch = tir.Schedule(summation_3d, debug_mask="all")
+ sch.transform_layout(
+ index_map=lambda *indices, k: [*indices, k // 4, k % 4],
block="compute", buffer="A"
+ )
+ tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"])
+
+
if __name__ == "__main__":
tvm.testing.main()