vinx13 commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r779086842
##########
File path: src/tir/transforms/lower_custom_datatypes.cc
##########
@@ -103,32 +103,59 @@ class CustomDatatypesLowerer : public StmtExprMutator {
}
}
- PrimExpr VisitExpr_(const LoadNode* load) final {
- bool to_be_lowered =
datatype::Registry::Global()->GetTypeRegistered(load->dtype.code());
- PrimExpr expr = StmtExprMutator::VisitExpr_(load);
- load = expr.as<LoadNode>();
- if (to_be_lowered) {
- auto new_load_type = DataType::UInt(load->dtype.bits());
- auto buffer_var = load->buffer_var;
- auto it = var_remap_.find(buffer_var);
- if (it != var_remap_.end()) {
- buffer_var = it->second;
- }
- return Load(new_load_type, buffer_var, load->index, load->predicate);
- }
- return expr;
+ PrimExpr VisitExpr_(const LoadNode* op) final {
+ LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use
BufferLoadNode instead.";
Review comment:
comment on this and other passes: is this needed, since the base class
`StmtMutator` has a default implementation that throws an error?
##########
File path: include/tvm/tir/buffer.h
##########
@@ -55,8 +55,48 @@ class BufferNode : public Object {
Var data;
/*! \brief data type in the content of the tensor */
DataType dtype;
- /*! \brief The shape of the buffer */
+ /*! \brief The shape of the buffer
+ *
+ * This contains the shape as it is accessed by
+ * BufferLoad/BufferStore nodes, and used by the low-level code
+ * generators.
+ */
Array<PrimExpr> shape;
+ /*! \brief The shape of the buffer prior to flattening
+ *
+ * This contains the shape as it exists prior to flattening, and is
+ * used for validating the shape of the tensor passed into the
+ * packed API.
+ *
+ * TODO(Lunderberg): Should this be a reference to the entire
Review comment:
What's the plan for this TODO? I feel that keeping a reference to the
pre-flattened Buffer could be simpler.
##########
File path: src/printer/tir_text_printer.cc
##########
@@ -223,6 +223,9 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf,
Doc doc) {
if (!is_zero(buf->elem_offset)) {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
+ if (buf->axis_separators.size()) {
+ doc << ", axis_separators=" << Print(buf->axis_separators);
+ }
Review comment:
Note on future work: to support `axis_separator` from TVM script, the
TVM script printer and the parser also need updates.
##########
File path: python/tvm/te/schedule.py
##########
@@ -519,9 +523,149 @@ def rolling_buffer(self):
"""
_ffi_api.StageRollingBuffer(self)
+ def transform_layout(self, mapping_function: Callable[...,
List[tvm.tir.PrimExpr]]):
+ """Defines the layout transformation for the current stage's tensor.
+
+ The map from initial_indices to final_indices must be an
+ invertible affine transformation. This method may be called
+ more than once for a given tensor, in which case each
+ transformation is applied sequentially.
+
+ If the stage is a ComputeOp, then the iteration order of the
+ compute stage is rewritten to be a row-major traversal of the
+ tensor, and the new loop iteration variables are returned.
+ For all other stages, the loop iteration order is unmodified,
+ and the return value is None.
+
+ Parameters
+ ----------
+ mapping_function : Callable[..., List[tvm.tir.PrimExpr]]
+
+ A callable that accepts N arguments of type tvm.tir.Var,
+ and outputs a list of PrimExpr. The input arguments
+ represent the location of a value in the current stage's
+ tensor, using the pre-transformation layout. The return
+ value of the function gives the location of that value in
+ the current stage's tensor, using the post-transformation
+ layout.
+
+ Returns
+ -------
+ new_iter_vars : Optional[List[tvm.tir.IterVar]]
+
+ If the stage is a ComputeOp, then the return will be the
+ updated loop iteration variables over the data array, in
+ the same order as the output values from the
+ `mapping_function`.
+
+ Otherwise, the return value is None.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ # ``A`` is a tensor whose compute definition is in NHWC
+ # format, and should be transformed into NCHWc format.
+
+ s[A].transform_layout(
+ lambda n,h,w,c: [n, c//4, h, w, c%4]
+ )
+
+
+ .. code-block:: python
+
+ # ``A`` is a tensor whose compute definition is in an
+ # arbitrary format, and should be transformed such that
+ # the last index is split, with the slower-changing index
+ # of the split placed at the slowest changing dimension.
+
+ s[A].transform_layout(
+ lambda *indices, i: [i//4, *indices, i%4]
+ )
+
+ .. code-block:: python
+
+ # ``B`` is a tensor defined by te.compute to be a copy of
+ # ``A``, and should be transformed such that ``B``'s layout
+ # is a transpose of ``A``'s layout. The loop iteration
+ # that computes ``B`` will correspond to ``B``'s memory
+ # layout.
+
+ A = te.placeholder([n,m])
+ B = te.compute(A.shape, lambda i,j: A[i,j])
+ s = te.create_schedule(B.op)
+
+ s[B].transform_layout(lambda i,j: [j,i])
+
+ """
+
+ args = []
+ var_arg_name = None
+ kwargs = collections.OrderedDict()
+ default_index_dtype = "int32"
+
+ # Make a dummy variable for each explicitly named input index.
+ # We may have some keyword-only arguments, if the function has
+ # *args before the last argument.
+ params = inspect.signature(mapping_function).parameters
+ for name, param in params.items():
+ if param.kind in [
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ ]:
+ args.append(tvm.tir.Var(name, default_index_dtype))
+
+ elif param.kind == inspect.Parameter.VAR_POSITIONAL:
+ var_arg_name = name
+
+ elif param.kind == inspect.Parameter.KEYWORD_ONLY:
+ kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+
+ elif param.kind in [inspect.Parameter.VAR_KEYWORD]:
+ raise ValueError("transform_layout mapping may not have
**kwargs")
+
+ ndim = len(self.op.output(0).shape)
+
+ # Now that all the named arguments have been collected,
+ # everything that remains should go to the *args, if
+ # specified.
+ if var_arg_name is not None:
+ num_var_args = ndim - len(args) - len(kwargs)
+ for i in range(num_var_args):
+ args.append(tvm.tir.Var(f"{var_arg_name}[{i}]",
default_index_dtype))
+
+ initial_indices = args + list(kwargs.values())
+ if len(initial_indices) != ndim:
+ raise ValueError(
+ f"transform_layout mapping accepts {len(params)} initial
indices, "
+ f"but {self.op.name} is {len(self.op.shape)}-dimensional"
+ )
+
+ mapping = mapping_function(*args, **kwargs)
+
+ final_indices = []
+ axis_separators = []
+ for val in mapping:
+ if isinstance(val, tvm.ir.PrimExpr):
+ final_indices.append(val)
+ elif val is AXIS_SEPARATOR:
+ axis_separators.append(len(final_indices))
+ else:
+ raise TypeError(
+ "Expected mapping function to return list of "
+ "either tvm.ir.PrimExpr or tvm.te.AXIS_SEPARATOR. "
+ "Instead received {val} of type {type(val)}."
+ )
+
+ new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices,
final_indices)
+ _ffi_api.StageSetAxisSeparators(self, axis_separators)
Review comment:
Since the axis separator is only an attribute of the buffer that affects
flattening, would it be better to separate these two API calls into two
schedule primitives, even at the user-facing level? (Setting the axis
separator doesn't have to be involved with the layout transform.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]