This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d4a396825b [TIR] Add schedule primitive TransformBlockLayout (#11485)
d4a396825b is described below
commit d4a396825bead4c617a4867a10dd6eff7797add4
Author: Wuwei Lin <[email protected]>
AuthorDate: Sun May 29 09:12:17 2022 -0700
[TIR] Add schedule primitive TransformBlockLayout (#11485)
* [TIR] Add schedule primitive TransformBlockLayout
* fixup! [TIR] Add schedule primitive TransformBlockLayout
Fix doc
---
include/tvm/tir/schedule/schedule.h | 10 +
python/tvm/tir/schedule/schedule.py | 61 +++++
src/tir/schedule/analysis.h | 11 +
src/tir/schedule/analysis/analysis.cc | 29 ++
src/tir/schedule/concrete_schedule.cc | 8 +
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/primitive.h | 12 +
.../schedule/primitive/layout_transformation.cc | 304 +++++++++++++++++++++
src/tir/schedule/primitive/loop_transformation.cc | 29 +-
src/tir/schedule/schedule.cc | 2 +
src/tir/schedule/traced_schedule.cc | 10 +
src/tir/schedule/traced_schedule.h | 1 +
src/tir/schedule/transform.cc | 31 +++
src/tir/schedule/transform.h | 39 +++
.../unittest/test_tir_schedule_transform_layout.py | 113 ++++++++
15 files changed, 635 insertions(+), 26 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 18e15d1670..48014280a5 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -545,6 +545,16 @@ class ScheduleNode : public runtime::Object {
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const
IndexMap& index_map) = 0;
+ /*!
+ * \brief Apply a transformation represented by IndexMap to block
+ * \details The block iters and the block body are transformed by the given
index_map.
+ * Outer loops corresponding to each new block iter are regenerated.
+ * The index_map is required to be bijective affine since we need its
inverse mapping.
+ * \param block_rv The block to be transformed
+ * \param index_map The transformation to apply.
+ */
+ virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap&
index_map) = 0;
+
/*!
* \brief Set the axis separator of a buffer, where the buffer is specified
by a block and a read
* or write index
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index dc687b1eae..f86228848b 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2286,6 +2286,67 @@ class Schedule(Object):
self, block, buffer_index, buffer_index_type_enum,
axis_separators
)
+ @type_checked
+ def transform_block_layout(
+ self,
+ block: BlockRV,
+ index_map: Union[IndexMap, Callable],
+ ) -> None:
+ """Apply a transformation represented by IndexMap to block
+
+ Parameters
+ ----------
+ block : BlockRV
+ The block to be transformed
+
+ index_map : Union[IndexMap, Callable]
+ The transformation to apply.
+
+ Examples
+ --------
+
+ Before transform_block_layout, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_transform_block_layout(
+ A: T.Buffer[(16, 16), "float32"],
+ B: T.Buffer[(16, 16), "float32"]
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ Create the schedule and do transform_block_layout:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_transform_block_layout)
+ sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i *
16 + j,))
+ print(sch.mod["main"].script())
+
+ After applying transform_block_layout, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_transform_block_layout(
+ A: T.Buffer[(16, 16), "float32"],
+ B: T.Buffer[(16, 16), "float32"]
+ ) -> None:
+ for i in range(256):
+ with T.block("B"):
+ vi, = T.axis.remap("S", [i])
+ B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
+ """
+ if callable(index_map):
+ index_map = IndexMap.from_func(index_map)
+ _ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint:
disable=no-member
+ self, block, index_map
+ )
+
@type_checked
def set_axis_separator(
self,
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index c9c3d72ae0..0574cfefad 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -277,6 +277,17 @@ bool GetVarsTouchedByBlockIters(const BlockRealize&
block_realize,
std::unordered_set<const VarNode*>*
data_par_vars,
std::unordered_set<const VarNode*>*
reduce_vars);
+/******** Loop properties ********/
+/*!
+ * \brief Check the loop starts with zero.
+ * \param self The schedule state
+ * \param loop_sref The StmtSRef that points to the loop to be checked
+ * \param analyzer The arithmetic analyzer
+ * \throw ScheduleError If the loop doesn't start with zero.
+ */
+void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef&
loop_sref,
+ arith::Analyzer* analyzer);
+
/******** Block-loop relation ********/
/*!
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 4777ee2657..c4719015da 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -686,6 +686,35 @@ bool GetVarsTouchedByBlockIters(const BlockRealize&
block_realize,
return has_block_vars_of_other_types;
}
+/******** Loop properties ********/
+
+void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef&
loop_sref,
+ arith::Analyzer* analyzer) {
+ class LoopNotStartWithZeroError : public ScheduleError {
+ public:
+ explicit LoopNotStartWithZeroError(IRModule mod, For loop)
+ : mod_(mod), loop_(std::move(loop)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The primitive only supports loop starting with 0";
+ }
+
+ String DetailRenderTemplate() const final {
+ return "The loop {0} does not start with 0, which is not supported";
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+ IRModule mod_;
+ For loop_;
+ };
+ const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+ if (!analyzer->CanProve(loop->min == 0)) {
+ throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
+ }
+}
+
/******** Block-loop relation ********/
Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self,
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 7b953220f2..8066d85a8e 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -693,6 +693,14 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV&
block_rv, int buffer_i
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}
+void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
+ const IndexMap& index_map) {
+ TVM_TIR_SCHEDULE_BEGIN();
+ tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map);
+ this->state_->DebugVerify();
+ TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_);
+}
+
void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int
buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>&
axis_separators) {
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 9293aa3493..8e83aac2ce 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -134,6 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map) override;
+ void TransformBlockLayout(const BlockRV& block_rv, const IndexMap&
index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index d55b896934..50dedf71ff 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -442,6 +442,18 @@ TVM_DLL void Unannotate(ScheduleState self, const
StmtSRef& sref, const String&
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref,
int buffer_index,
BufferIndexType buffer_index_type, const
IndexMap& index_map);
+/*!
+ * \brief Apply a transformation represented by IndexMap to block
+ * \details The block iters and the block body are transformed by the given
index_map.
+ * Outer loops corresponding to each new block iter are regenerated.
+ * The index_map is required to be bijective affine since we need its inverse
mapping.
+ * \param self The state of the schedule
+ * \param block_sref The block sref that refers to the block to be transformed
+ * \param index_map The transformation to apply.
+ */
+TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef&
block_sref,
+ const IndexMap& index_map);
+
/******** Schedule: Misc ********/
} // namespace tir
diff --git a/src/tir/schedule/primitive/layout_transformation.cc
b/src/tir/schedule/primitive/layout_transformation.cc
index cf95665ee8..6da796fc95 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -192,6 +192,269 @@ void TransformLayout(ScheduleState self, const StmtSRef&
block_sref, int buffer_
self->Replace(scope_sref, new_scope_block, block_sref_reuse);
}
+/*!
+ * \brief Detect the block iter type associated with the expression
+ *
+ * This function collects block iters in the expression and checks if the block
iters have the same
+ * iter type. The detected iter type is the iter type of the block iters in
the expression
+ * if they have the same iter type, otherwise the detected iter type will be
kOpaque.
+ *
+ * \param expr The expression
+ * \param block_iter_type_map The mapping from block iter to iter type
+ * \return The detected block iter type
+ */
+IterVarType DetectNewBlockIterType(
+ const PrimExpr& expr,
+ const std::unordered_map<const VarNode*, IterVarType>&
block_iter_type_map) {
+ IterVarType result{kOpaque};
+ bool found = false;
+ PostOrderVisit(expr, [&](const ObjectRef& obj) {
+ if (const VarNode* var = obj.as<VarNode>()) {
+ auto it = block_iter_type_map.find(var);
+ if (it != block_iter_type_map.end()) {
+ if (!found) {
+ found = true;
+ result = it->second;
+ } else if (result != it->second) {
+ result = kOpaque;
+ return false;
+ }
+ }
+ }
+ return true;
+ });
+ return result;
+}
+
+class NotBijectiveAffineIndexMapError : public ScheduleError {
+ public:
+ NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map)
+ : mod_(std::move(mod)), index_map_(std::move(index_map)) {}
+ String FastErrorString() const final {
+ return "ScheduleError: The index map is not bijective affine.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The index map " << index_map_->ToPythonString() << " is not
bijective affine.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+ IRModule mod_;
+ IndexMap index_map_;
+};
+
+class IndexMapNotApplicableToBlockIterError : public ScheduleError {
+ public:
+ static void Check(const IRModule mod, const Block& block, const IndexMap&
index_map) {
+ if (index_map->initial_indices.size() != block->iter_vars.size()) {
+ throw IndexMapNotApplicableToBlockIterError(mod, block, index_map);
+ }
+ }
+ explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block,
IndexMap index_map)
+ : mod_(std::move(mod)), block_(std::move(block)),
index_map_(std::move(index_map)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: The index map can't be applied to block iters
because the number of "
+ "parameters mismatch.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The index map " << index_map_->ToPythonString()
+ << " can't be applied to block iters of {0} because the number of
parameters mismatch. "
+ "Expected: "
+ << index_map_->initial_indices.size() << ", actual: " <<
block_->iter_vars.size();
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+
+ Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+ private:
+ IRModule mod_;
+ Block block_;
+ IndexMap index_map_;
+};
+
+class NotTrivialBindingError : public ScheduleError {
+ public:
+ explicit NotTrivialBindingError(IRModule mod, Block block)
+ : mod_(std::move(mod)), block_(std::move(block)) {}
+
+ static void CheckBlockHasTrivialBinding(const IRModule& mod, const
BlockRealize& block_realize,
+ std::unordered_set<const VarNode*>
outer_loop_vars) {
+ // Step 2: Check all the binding values are loop vars
+ for (const PrimExpr& iter_value : block_realize->iter_values) {
+ const VarNode* loop_var = iter_value.as<VarNode>();
+ if (!loop_var || !outer_loop_vars.count(loop_var)) {
+ throw NotTrivialBindingError(mod, block_realize->block);
+ }
+ }
+ }
+
+ String FastErrorString() const final {
+ return "ScheduleError: The binding values of the block are not variables
of outer loops.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "The binding values of the {0} are not variables of outer loops.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+ IRModule mod_;
+ Block block_;
+};
+
+class OpaqueNewIterTypeError : public ScheduleError {
+ public:
+ explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr
iter_value)
+ : mod_(std::move(mod)), block_(std::move(block)),
iter_value_(std::move(iter_value)) {}
+
+ String FastErrorString() const final {
+ return "ScheduleError: Cannot detect the new block iter type because it
contains more than one "
+ "type of original iter vars.";
+ }
+
+ String DetailRenderTemplate() const final {
+ std::ostringstream os;
+ os << "Cannot detect the block iter type for new iter value " <<
PrettyPrint(iter_value_)
+ << " in {0} because it contains more than one type of original iter
vars.";
+ return os.str();
+ }
+
+ IRModule mod() const final { return mod_; }
+ Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+ IRModule mod_;
+ Block block_;
+ PrimExpr iter_value_;
+};
+
+void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
+ const IndexMap& index_map) {
+ const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
+ const Block& block = GetRef<Block>(block_ptr);
+ arith::Analyzer analyzer;
+
+ // Step 1: Collect outer loops and loop vars
+ Array<StmtSRef> loops = GetLoops(block_sref); // outer loops of the block
+ std::unordered_set<const VarNode*> loop_vars; // loop vars of the outer
loops
+ for (const StmtSRef& loop_sref : loops) {
+ CheckLoopStartsWithZero(self, loop_sref, &analyzer);
+ loop_vars.emplace(loop_sref->StmtAs<ForNode>()->loop_var.get());
+ }
+
+ // Step 2: Check that all outer loops have a single child and the block
bindings are trivial (all
+ // binding values are loop vars)
+ StmtSRef scope_sref{nullptr}; // the scope statement for replacement
+ if (!loops.empty()) {
+ scope_sref = loops.front();
+ CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
+ } else {
+ scope_sref = block_sref;
+ }
+
+ BlockRealize block_realize = GetBlockRealize(self, block_sref);
+ NotTrivialBindingError::CheckBlockHasTrivialBinding(self->mod,
block_realize, loop_vars);
+
+ // Step 3: Collect information of block iter vars
+ Array<PrimExpr> block_vars; // iter_var->var of each block iter
+ Map<Var, Range> block_iter_dom; // domain of block iter
+ std::unordered_map<const VarNode*, IterVarType> block_iter_type; // iter
type of block iter
+
+ Array<PrimExpr>
+ block_iter_range_array; // array of block iter extents in the same
order as block iters
+ for (const auto& iter_var : block->iter_vars) {
+ block_vars.push_back(iter_var->var);
+ block_iter_dom.Set(iter_var->var, iter_var->dom);
+ block_iter_type[iter_var->var.get()] = iter_var->iter_type;
+ ICHECK(is_zero(iter_var->dom->min));
+ block_iter_range_array.push_back(iter_var->dom->extent);
+ }
+
+ // Step 4: Apply the IndexMap to block iters.
+ IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map);
+ Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
+ Array<PrimExpr> new_block_iter_range =
index_map->MapShape(block_iter_range_array);
+
+ auto iter_map = arith::DetectIterMap(
+ /*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom,
/*predicate=*/Bool(true),
+ /*require_bijective=*/true, &analyzer,
/*simplify_trivial_iterators=*/true);
+ if (iter_map.empty()) {
+ throw NotBijectiveAffineIndexMapError(self->mod, index_map);
+ }
+
+ // Step 5: Create the new block after transformation.
+
+ // Step 5.1: Create new block iters. After applying the IndexMap f to block
iters ax_0, ..., ax_n,
+ // create a block iter for each expression in f(ax_0, ..., ax_n).
+ Array<IterVar> new_block_iters; // new block iters
+ Array<PrimExpr> new_block_vars; // iter_var->var of new block iters
+ for (size_t i = 0; i < index_map->final_indices.size(); ++i) {
+ Var new_block_var{"v" + std::to_string(i), DataType::Int(32)};
+ new_block_vars.push_back(new_block_var);
+ IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i],
block_iter_type);
+ if (iter_type == kOpaque) {
+ throw OpaqueNewIterTypeError(self->mod, GetRef<Block>(block_ptr),
transformed_block_iters[i]);
+ }
+ new_block_iters.push_back(IterVar(/*dom=*/Range::FromMinExtent(0,
new_block_iter_range[i]),
+ /*var=*/std::move(new_block_var),
/*iter_type=*/iter_type));
+ }
+
+ // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace
the original block iters
+ // in the body.
+
+ auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars);
+ // Trivial block iters will be simplified in DetectIterMap, they should be
mapped to constant
+ // zero.
+ for (const auto& iter_var : block_ptr->iter_vars) {
+ if (inverse_map.find(iter_var->var) == inverse_map.end()) {
+ ICHECK(is_one(iter_var->dom->extent));
+ inverse_map.Set(iter_var->var, 0);
+ }
+ }
+
+ Block new_block = Downcast<Block>(Substitute(GetRef<Block>(block_ptr),
inverse_map));
+ new_block.CopyOnWrite()->iter_vars = new_block_iters;
+ new_block = Downcast<Block>(BlockBufferAccessSimplifier::Simplify(new_block,
&analyzer));
+
+ // Step 5.3: Create outer loops for each new block iter.
+
+ // Make new loop vars
+ Array<PrimExpr> new_loop_vars;
+ for (int i = 0; i < static_cast<int>(new_block_iters.size()); ++i) {
+ new_loop_vars.push_back(Var("ax" + std::to_string(i), DataType::Int(32)));
+ }
+
+ // Make new block realize
+ BlockRealizeNode* new_block_realize = block_realize.CopyOnWrite();
+ new_block_realize->iter_values = new_loop_vars;
+ new_block_realize->block = new_block;
+
+ // Generate outer loops
+ Stmt body = GetRef<Stmt>(new_block_realize);
+ for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; --i) {
+ body = For(Downcast<Var>(new_loop_vars[i]), 0, new_block_iter_range[i],
ForKind::kSerial,
+ std::move(body));
+ }
+
+ // Step 6: Do the actual replacement
+ self->Replace(scope_sref, body, {{block, new_block}});
+}
+
class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
public:
static Block Mutate(const Block& scope_block, const Buffer& old_buffer,
Buffer new_buffer,
@@ -270,6 +533,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef&
block_sref, int buffer
// Step 4: Replace the scope block with the new block
self->Replace(scope_sref, new_scope_block, block_sref_reuse);
}
+
/******** InstructionKind Registration ********/
struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits> {
@@ -324,6 +588,45 @@ struct TransformLayoutTraits : public
UnpackedInstTraits<TransformLayoutTraits>
friend struct ::tvm::tir::UnpackedInstTraits;
};
+struct TransformBlockLayoutTraits : public
UnpackedInstTraits<TransformBlockLayoutTraits> {
+ static constexpr const char* kName = "TransformBlockLayout";
+ static constexpr bool kIsPure = false;
+
+ private:
+ static constexpr size_t kNumInputs = 1;
+ static constexpr size_t kNumAttrs = 1;
+ static constexpr size_t kNumDecisions = 0;
+
+ static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap
index_map) {
+ return sch->TransformBlockLayout(block_rv, index_map);
+ }
+
+ static String UnpackedAsPython(Array<String> outputs, String block_rv,
IndexMap index_map) {
+ PythonAPICall py("transform_block_layout");
+ py.Input("block", block_rv);
+ py.Input("index_map", index_map->ToPythonString());
+ return py.Str();
+ }
+
+ public:
+ static ObjectRef AttrsAsJSON(const Array<ObjectRef>& attrs) {
+ Array<ObjectRef> attrs_record;
+ attrs_record.reserve(kNumAttrs);
+ attrs_record.push_back(String(::tvm::SaveJSON(attrs[0])));
+ return std::move(attrs_record);
+ }
+
+ static Array<ObjectRef> AttrsFromJSON(const ObjectRef& attrs_record_) {
+ Array<ObjectRef> attrs_record = Downcast<Array<ObjectRef>>(attrs_record_);
+ Array<ObjectRef> attrs;
+ attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[0])));
+ return attrs;
+ }
+
+ template <typename>
+ friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
struct SetAxisSeparatorTraits : public
UnpackedInstTraits<SetAxisSeparatorTraits> {
static constexpr const char* kName = "SetAxisSeparator";
static constexpr bool kIsPure = false;
@@ -359,6 +662,7 @@ struct SetAxisSeparatorTraits : public
UnpackedInstTraits<SetAxisSeparatorTraits
};
TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits);
+TVM_REGISTER_INST_KIND_TRAITS(TransformBlockLayoutTraits);
TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits);
} // namespace tir
diff --git a/src/tir/schedule/primitive/loop_transformation.cc
b/src/tir/schedule/primitive/loop_transformation.cc
index d64a72ed34..dbe6a3bbc0 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -250,25 +250,6 @@ class NotOnlyChildError : public ScheduleError {
For inner_;
};
-class LoopNotStartWithZeroError : public ScheduleError {
- public:
- explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod),
loop_(std::move(loop)) {}
-
- String FastErrorString() const final {
- return "ScheduleError: The primitive only supports loop starting with 0";
- }
-
- String DetailRenderTemplate() const final {
- return "The loop {0} does not start with 0, which is not supported";
- }
-
- IRModule mod() const final { return mod_; }
- Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
-
- IRModule mod_;
- For loop_;
-};
-
class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}
@@ -407,10 +388,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef&
loop_sref,
}
// Currently, loops not starting with 0 are not supported
arith::Analyzer analyzer;
- if (!analyzer.CanProve(loop->min == 0)) {
- throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
- }
- // Step 2. Replace all occurrences of the original loop var with new
variables
+ CheckLoopStartsWithZero(self, loop_sref, &analyzer);
+
int n = factors.size();
PrimExpr substitute_value = 0;
std::vector<Var> new_loop_vars;
@@ -482,9 +461,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>&
loop_srefs) {
}
outer_loop_sref = sref;
outer_loop = loop;
- if (!analyzer.CanProve(loop->min == 0)) {
- throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
- }
+ CheckLoopStartsWithZero(self, sref, &analyzer);
const VarNode* used_var = nullptr;
auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) {
if (outer_loop_vars.count(var)) {
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 8dc0c52111..fb884ce77f 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -233,6 +233,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
return self->TransformLayout(block_rv, buffer_index,
static_cast<BufferIndexType>(buffer_index_type), index_map);
});
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
+ .set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetAxisSeparator")
.set_body_typed([](Schedule self, const BlockRV& block_rv, int
buffer_index,
int buffer_index_type, const Array<IntImm>&
axis_separators) {
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index 865b6f3784..8156480a45 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -442,6 +442,16 @@ void TracedScheduleNode::TransformLayout(const BlockRV&
block_rv, int buffer_ind
/*outputs=*/{}));
}
+void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const
IndexMap& index_map) {
+ ConcreteScheduleNode::TransformBlockLayout(block_rv, index_map);
+ static const InstructionKind& kind =
InstructionKind::Get("TransformBlockLayout");
+ trace_->Append(
+ /*inst=*/Instruction(/*kind=*/kind,
+ /*inputs=*/{block_rv},
+ /*attrs=*/{index_map},
+ /*outputs=*/{}));
+}
+
void TracedScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int
buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>&
axis_separators) {
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index 12c076d886..d1860be951 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -98,6 +98,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map) override;
+ void TransformBlockLayout(const BlockRV& block_rv, const IndexMap&
index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) final;
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index 6c4f3e1b7a..79802ecd65 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -280,5 +280,36 @@ Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule&
sch, const tir::Block
TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);
+/******** BlockBufferAccessSimplifier ********/
+void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>*
old_access_regions) {
+ auto fmutate = [this](const BufferRegion& buffer_region) {
+ std::vector<Range> new_buffer_region;
+ for (const auto& range : buffer_region->region) {
+
new_buffer_region.push_back(Range::FromMinExtent(analyzer_->Simplify(range->min),
+
analyzer_->Simplify(range->extent)));
+ }
+ return BufferRegion(buffer_region->buffer, new_buffer_region);
+ };
+ (*old_access_regions).MutateByApply(fmutate);
+}
+
+Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
+ Block block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
+ auto* n = block.CopyOnWrite();
+ SimplifyAccessRegion(&n->reads);
+ SimplifyAccessRegion(&n->writes);
+ return std::move(block);
+}
+
+Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) {
+ auto node =
Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
+ return VisitBufferAccess(std::move(node));
+}
+
+PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
+ auto node =
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
+ return VisitBufferAccess(std::move(node));
+}
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 52e27350d4..192d44d9e9 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -26,6 +26,7 @@
#include <unordered_map>
#include <utility>
+#include "../../arith/ir_mutator_with_analyzer.h"
#include "../ir/functor_common.h"
namespace tvm {
@@ -172,6 +173,44 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const
StmtSRef& leaf_block_
Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const
tir::BlockRV& block_rv,
const String& intrin_name);
+/******** Block mutation ********/
+
+/*!
+ * \brief Simplifier for indices of buffer access and block buffer access
regions.
+ */
+class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer {
+ public:
+ /*!
+ * \brief Simplify indices of buffer access and block buffer access regions
in the statement
+ * \param stmt The statement to be simplified
+ * \param analyzer The arithmetic analyzer
+ * \return The simplified statement
+ */
+ static Stmt Simplify(const Stmt& stmt, arith::Analyzer* analyzer) {
+ BlockBufferAccessSimplifier simplifier(analyzer);
+ return simplifier(stmt);
+ }
+
+ private:
+ explicit BlockBufferAccessSimplifier(arith::Analyzer* analyzer)
+ : IRMutatorWithAnalyzer(analyzer) {}
+
+ using IRMutatorWithAnalyzer::VisitExpr_;
+ using IRMutatorWithAnalyzer::VisitStmt_;
+
+ void SimplifyAccessRegion(Array<BufferRegion>* old_access_regions);
+ Stmt VisitStmt_(const BlockNode* op) final;
+ Stmt VisitStmt_(const BufferStoreNode* op) final;
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final;
+
+ template <typename Node>
+ Node VisitBufferAccess(Node node) {
+ node.CopyOnWrite()->indices.MutateByApply(
+ [this](const PrimExpr& expr) { return analyzer_->Simplify(expr); });
+ return node;
+ }
+};
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 699eaf1236..e184bc3f62 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -91,6 +91,83 @@ def two_elementwise_transformed_output_buffer(
C[vi // 16, vj // 16, vi % 16, vj % 16] = B[vi, vj] + 1.0
+@T.prim_func
+def elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128),
"float32"]) -> None:
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+
+
+@T.prim_func
+def elementwise_transformed(A: T.Buffer[(128, 128), "float32"], B:
T.Buffer[(128, 128), "float32"]) -> None:
+ for i in range(16384):
+ with T.block("B"):
+ vi, = T.axis.remap("S", [i])
+ B[vi // 128, vi % 128] = A[vi // 128, vi % 128] * 2.0
+
+
+@T.prim_func
+def conv2d_nhwc(
+ Input: T.Buffer[(1, 224, 224, 3), "float32"],
+ Weight: T.Buffer[(7, 7, 3, 64), "float32"],
+ Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
+) -> None:
+ PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+ for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+ with T.block("PadInput"):
+ i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+ ((((i1_1 >= 3) and (i1_1 < 227)) and (i2_1 >= 3)) and (i2_1 <
227)),
+ Input[i0_1, (i1_1 - 3), (i2_1 - 3), i3_1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 112, 112, 64, 7, 7, 3):
+ with T.block("conv2d_nhwc"):
+ n, h, w, co, rh, rw, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3,
i4, i5, i6])
+ with T.init():
+ Conv2d_nhwc[n, h, w, co] = T.float32(0)
+ Conv2d_nhwc[n, h, w, co] = Conv2d_nhwc[n, h, w, co] + (
+ PadInput[n, ((h * 2) + rh), ((w * 2) + rw), ((T.floordiv(co,
64) * 3) + rc)]
+ * Weight[rh, rw, rc, co]
+ )
+
+
+@T.prim_func
+def conv2d_nhwc_transformed(
+ Input: T.Buffer[(1, 224, 224, 3), "float32"],
+ Weight: T.Buffer[(7, 7, 3, 64), "float32"],
+ Conv2d_nhwc: T.Buffer[(1, 112, 112, 64), "float32"],
+) -> None:
+ PadInput = T.alloc_buffer([1, 230, 230, 3], dtype="float32")
+ for i0, i1, i2, i3 in T.grid(1, 230, 230, 3):
+ with T.block("PadInput"):
+ i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+ T.reads(Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1])
+ T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+ PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+ i1_1 >= 3 and i1_1 < 227 and i2_1 >= 3 and i2_1 < 227,
+ Input[i0_1, i1_1 - 3, i2_1 - 3, i3_1],
+ T.float32(0),
+ dtype="float32",
+ )
+ for ax0, ax_1, ax_2 in T.grid(12544, 64, 147):
+ with T.block("conv2d_nhwc"):
+ bv0, bv1, bv2 = T.axis.remap("SSR", [ax0, ax_1, ax_2])
+ T.reads(
+ PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2 %
21 // 3, bv2 % 3],
+ Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1],
+ )
+ T.writes(Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1])
+ with T.init():
+ Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = T.float32(0)
+ Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1] = (
+ Conv2d_nhwc[0, bv0 // 112, bv0 % 112, bv1]
+ + PadInput[0, bv0 // 112 * 2 + bv2 // 21, bv0 % 112 * 2 + bv2
% 21 // 3, bv2 % 3]
+ * Weight[bv2 // 21, bv2 % 21 // 3, bv2 % 3, bv1]
+ )
+
# pylint:
enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
# fmt: on
@@ -218,5 +295,41 @@ def test_var_args_sugar():
tvm.ir.assert_structural_equal(summation_3d_split, sch.mod["main"])
+def test_transform_block_layout_basic():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block = sch.get_block("B")
+ sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
+ tvm.ir.assert_structural_equal(elementwise_transformed, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=elementwise)
+
+
+def test_transform_block_layout_conv2d_nhwc():
+ sch = tir.Schedule(conv2d_nhwc, debug_mask="all")
+ block = sch.get_block("conv2d_nhwc")
+ sch.transform_block_layout(
+ block,
+ lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co, rh *
7 * 3 + rw * 3 + rc),
+ )
+ tvm.ir.assert_structural_equal(conv2d_nhwc_transformed, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=conv2d_nhwc)
+
+
+def test_transform_block_layout_fail_non_affine():
+ sch = tir.Schedule(elementwise, debug_mask="all")
+ block = sch.get_block("B")
+ with pytest.raises(tir.ScheduleError):
+ sch.transform_block_layout(block, lambda i, j: (i + j,))
+
+
+def test_transform_block_layout_fail_mixed_iter_type():
+ sch = tir.Schedule(conv2d_nhwc, debug_mask="all")
+ block = sch.get_block("conv2d_nhwc")
+ with pytest.raises(tir.ScheduleError):
+ sch.transform_block_layout(
+ block,
+ lambda n, h, w, co, rh, rw, rc: (n * 112 * 112 + h * 112 + w, co *
7 + rh, rw * 3 + rc),
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()