This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2f863dda62 [TIR][Schedule] Improve blockize to support blockizing
multiple blocks (#14766)
2f863dda62 is described below
commit 2f863dda62c80ce16bbbb608a0bc7afe3589a864
Author: multiverstack <[email protected]>
AuthorDate: Tue May 16 17:42:14 2023 +0800
[TIR][Schedule] Improve blockize to support blockizing multiple blocks
(#14766)
* Improve blockize to support blockizing multiple blocks
* Adjust unit test to match simplified blockize result.
* Update doc
* Preserve unit iters in expr and revert test case change
* Apply review suggestion
---------
Co-authored-by: Min Chen <[email protected]>
---
include/tvm/tir/schedule/schedule.h | 7 +
python/tvm/tir/schedule/schedule.py | 12 +-
src/tir/schedule/concrete_schedule.cc | 9 +
src/tir/schedule/concrete_schedule.h | 1 +
src/tir/schedule/primitive.h | 10 +
src/tir/schedule/primitive/blockize_tensorize.cc | 287 +++++++++++++++++++--
src/tir/schedule/schedule.cc | 9 +-
src/tir/schedule/traced_schedule.cc | 11 +
src/tir/schedule/traced_schedule.h | 1 +
.../metaschedule_e2e/test_resnet50_int8.py | 2 +-
.../unittest/test_meta_schedule_trace_apply.py | 18 +-
.../python/unittest/test_tir_schedule_blockize.py | 49 ++++
12 files changed, 373 insertions(+), 43 deletions(-)
diff --git a/include/tvm/tir/schedule/schedule.h
b/include/tvm/tir/schedule/schedule.h
index 69f0520117..187d0a31d0 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -641,6 +641,13 @@ class ScheduleNode : public runtime::Object {
* \return the new block
*/
virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters =
true) = 0;
+ /*!
+ * \brief Convert specified blocks into a nested block.
+ * \param blocks the specified block to construct the new block
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in
block bindings
+ * \return the new block
+ */
+ virtual BlockRV Blockize(const Array<BlockRV>& blocks, bool
preserve_unit_iters = true) = 0;
/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrin.
* \param loop_rv The loop to be tensorized
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 7c7af998be..8ebc02ccbb 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2691,13 +2691,15 @@ class Schedule(Object):
########## Schedule: Blockize & Tensorize ##########
@type_checked
- def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) ->
BlockRV:
- """Convert the subtree rooted at a specific loop into a block.
+ def blockize(
+ self, target: Union[LoopRV, List[BlockRV]], preserve_unit_iters: bool
= True
+ ) -> BlockRV:
+ """Convert multiple blocks or the subtree rooted at a specific loop
into a block.
Parameters
----------
- loop : LoopRV
- The root of the subtree.
+ target : LoopRV or List[BlockRV]
+ The root of the subtree or the specified blocks.
preserve_unit_iters : bool
Whether or not to preserve unit iterators in block bindings
@@ -2764,7 +2766,7 @@ class Schedule(Object):
block are divisible by the subspace represented by the loops starting
at the given loop.
"""
- return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters) #
type: ignore # pylint: disable=no-member
+ return _ffi_api.ScheduleBlockize(self, target, preserve_unit_iters) #
type: ignore # pylint: disable=no-member
@type_checked
def tensorize(
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index 7192a48099..d485127242 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -791,6 +791,15 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV&
loop_rv, bool preserve_unit
return CreateRV<BlockRV>(result);
}
+BlockRV ConcreteScheduleNode::Blockize(const Array<BlockRV>& blocks, bool
preserve_unit_iters) {
+ StmtSRef result{nullptr};
+ TVM_TIR_SCHEDULE_BEGIN();
+ result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters);
+ this->state_->DebugVerify();
+ TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
+ return CreateRV<BlockRV>(result);
+}
+
void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String&
intrin,
bool preserve_unit_iters) {
TVM_TIR_SCHEDULE_BEGIN();
diff --git a/src/tir/schedule/concrete_schedule.h
b/src/tir/schedule/concrete_schedule.h
index 16065df3cd..73a0b314dd 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -153,6 +153,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String&
dtype) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
+ BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters)
override;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool
preserve_unit_iters) override;
void Tensorize(const LoopRV& loop_rv, const String& intrin, bool
preserve_unit_iters) override;
/******** Schedule: Annotation ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 78d1cab05c..7355d38db1 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -542,6 +542,16 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const
StmtSRef& block_sref, in
*/
TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool
preserve_unit_iters);
+/*!
+ * \brief Convert specific blocks into a nested block.
+ * \param self The state of the schedule
+ * \param blocks The target blocks to construct the new block
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in
block bindings
+ * \return The new block
+ */
+TVM_DLL StmtSRef Blockize(ScheduleState self, const Array<StmtSRef>& blocks,
+ bool preserve_unit_iters);
+
/*!
* \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
* \param self The state of the schedule
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc
b/src/tir/schedule/primitive/blockize_tensorize.cc
index 25694ed6fc..994a3a95fb 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -139,20 +139,26 @@ Array<Array<arith::IterMark>>
TrivialSubspaceDivision(const Array<IterVar>& iter
/*!
* \brief Subspace division. The space is divided into two subspaces:
+ * If loop_sref_as_outer is false:
* 1. The subspace represented by the outer loops above `loop_sref`
(exclusive).
* 2. The subspace represented by the inner loops below `loop_sref`
(inclusive).
+ * else:
+ * 1. The subspace represented by the outer loops above `loop_sref`
(inclusive).
+ * 2. The subspace represented by the inner loops below `loop_sref`
(exclusive).
* \param realize The inner block
* \param block_sref The sref to the inner block
* \param loop_sref The loop that is the root of the second subspace.
* \param loops The loops that represents the second part of the subspace.
* \param analyzer The arithmetic analyzer to use.
* \param preserve_unit_iters Whether or not to preserve unit iterators in
block bindings
+ * \param loop_sref_as_outer Whether loop_sref is divided into outer or inner
*/
Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
const StmtSRef& block_sref, //
const StmtSRef& loop_sref, //
std::vector<const ForNode*>*
loops,
- arith::Analyzer* analyzer, bool
preserve_unit_iters) {
+ arith::Analyzer* analyzer, bool
preserve_unit_iters,
+ bool loop_sref_as_outer = false) {
Array<Var> inner_vars;
Array<Var> outer_vars;
Map<Var, Range> loop_var_domain;
@@ -168,7 +174,7 @@ Array<Array<arith::IterMark>> SubspaceDivide(const
BlockRealize& realize,
outer_vars.push_back(loop->loop_var);
}
loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min,
loop->extent));
- if (sref == loop_sref.get()) {
+ if ((loop_sref_as_outer && sref->parent == loop_sref.get()) || sref ==
loop_sref.get()) {
inner = false;
}
}
@@ -201,12 +207,14 @@ Map<Var, PrimExpr> DeriveBlockBinding(const
Array<IterVar>& iter_vars,
Array<IterVar>* outer_iter_vars,
//
Array<PrimExpr>* outer_bindings,
//
Array<IterVar>* inner_iter_vars,
//
- Array<PrimExpr>* inner_bindings, bool
preserve_unit_iters) {
+ Array<PrimExpr>* inner_bindings,
//
+ bool preserve_unit_iters, bool
reuse_outer = false) {
using arith::IterMapExpr;
using arith::IterMapExprNode;
using arith::NormalizeIterMapToExpr;
Map<Var, PrimExpr> block_var_subst;
ICHECK_EQ(iter_vars.size() + 1, division.size());
+ arith::Analyzer ana;
for (int i = 0, n = iter_vars.size(); i < n; ++i) {
const IterVar& iter_var = iter_vars[i];
arith::IterMark outer_mark = division[i][0];
@@ -219,30 +227,43 @@ Map<Var, PrimExpr> DeriveBlockBinding(const
Array<IterVar>& iter_vars,
// The inner block will have binding: iter_inner -> inner_binding
// The iter in the original block will be substituted with base +
iter_inner where
// base == iter_outer * iter_inner_extent
- if (is_one(inner_mark->extent)) { // IsOuter
- // extract this iter var to outer block directly
+ // create iter var for the outer block
+ IterVar outer_iter;
+ if (reuse_outer) {
+ outer_iter = outer_iter_vars->operator[](i);
+ ICHECK(ana.CanProveEqual(outer_iter->dom->extent, outer_mark->extent));
+ ICHECK(
+ ana.CanProveEqual(outer_bindings->operator[](i),
NormalizeIterMapToExpr(outer_binding)));
+ } else {
+ outer_iter = IterVar(/*dom=*/RangeFromExtent(outer_mark->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_o"),
+ /*iter_type=*/iter_var->iter_type);
outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
- outer_iter_vars->push_back(iter_var);
- continue;
+ outer_iter_vars->push_back(outer_iter);
}
- // create iter var for the outer block
- IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
- /*var=*/iter_var->var.copy_with_suffix("_o"),
- /*iter_type=*/iter_var->iter_type);
- outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
- outer_iter_vars->push_back(outer_iter);
- // create iter var for the inner block
- IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
- /*var=*/iter_var->var.copy_with_suffix("_i"),
- /*iter_type=*/iter_var->iter_type);
- inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
- inner_iter_vars->push_back(inner_iter);
- // substitution
PrimExpr sub{nullptr};
- if (is_one(outer_mark->extent)) {
- sub = inner_iter->var;
+ if (is_one(inner_mark->extent)) {
+ // Skip inner var when extent is 1
+ // substitution
+ if (is_one(outer_mark->extent) && !preserve_unit_iters) {
+ // Simplify outer if not preserve_unit_iters
+ sub = make_zero(outer_mark->extent.dtype());
+ } else {
+ sub = outer_iter;
+ }
} else {
- sub = outer_iter * inner_mark->extent + inner_iter->var;
+ // create iter var for the inner block
+ IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_i"),
+ /*iter_type=*/iter_var->iter_type);
+ inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
+ inner_iter_vars->push_back(inner_iter);
+ // substitution
+ if (is_one(outer_mark->extent)) {
+ sub = inner_iter->var;
+ } else {
+ sub = outer_iter * inner_mark->extent + inner_iter->var;
+ }
}
block_var_subst.Set(iter_var->var, sub);
}
@@ -414,6 +435,37 @@ Array<BufferRegion> EvalSetRegions(const
Array<BufferRegion>& regions,
return results;
}
+/*!
+ * \brief Get the union of the given regions
+ * \param regions The input regions for the union.
+ * \return The union regions
+ */
+Array<BufferRegion> UnionRegions(const Array<BufferRegion>& regions) {
+ typedef std::vector<Array<arith::IntSet>> ranges_t;
+ std::unordered_map<Buffer, ranges_t, ObjectPtrHash, ObjectPtrEqual>
intset_map;
+ for (const BufferRegion& buffer_region : regions) {
+ const Buffer& buffer = buffer_region->buffer;
+ if (intset_map.find(buffer) == intset_map.end()) {
+ intset_map[buffer] = {buffer->shape.size(), Array<arith::IntSet>()};
+ }
+ std::vector<Array<arith::IntSet>> dim_range(buffer->shape.size(),
Array<arith::IntSet>());
+ for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
+
intset_map[buffer][dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim]));
+ }
+ }
+ Array<BufferRegion> results;
+ for (const auto& it : intset_map) {
+ const Buffer& buffer = it.first;
+ Array<Range> regions;
+ for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
+ const arith::IntSet intset = arith::Union(it.second[dim]);
+ regions.push_back({intset.min(), intset.max() + 1});
+ }
+ results.push_back(BufferRegion(buffer, regions));
+ }
+ return results;
+}
+
/*!
* \brief Create the loop nest on top of the given stmt.
* \param stmt The stmt to be wrapped.
@@ -513,6 +565,181 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef&
loop_sref, bool preserve_u
return result;
}
+BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>&
block_srefs,
+ const StmtSRef& lca, Map<Block, Block>*
block_sref_reuse,
+ bool preserve_unit_iters) {
+ Array<Stmt> seq_body;
+ PrimExpr outer_predicate{nullptr};
+ Array<IterVar> outer_iter_vars{nullptr};
+ Array<PrimExpr> outer_bindings{nullptr};
+ Array<BufferRegion> read_regions;
+ Array<BufferRegion> write_regions;
+ std::string outer_block_name = "outer_";
+ Map<Var, Var> loop_var_subst;
+ arith::Analyzer analyzer;
+ for (const auto& block_sref : block_srefs) {
+ auto block_realize = GetBlockRealize(self, block_sref);
+ auto block = block_realize->block;
+ // Step 1: Derive subspace division
+ std::vector<const ForNode*> loops;
+ Array<Array<arith::IterMark>> division = SubspaceDivide(block_realize,
block_sref, lca, &loops,
+ &analyzer,
preserve_unit_iters, true);
+ if (division.empty()) {
+ throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()),
block);
+ }
+ outer_predicate = division.back()[0]->extent;
+ PrimExpr inner_predicate = division.back()[1]->extent;
+ // Step 2. Derive block bindings for both outer and inner block.
+ Array<IterVar> inner_iter_vars;
+ Array<PrimExpr> inner_bindings;
+ Map<Var, PrimExpr> block_var_subst = //
+ DeriveBlockBinding(block->iter_vars, division, //
+ &outer_iter_vars, &outer_bindings, //
+ &inner_iter_vars, &inner_bindings, //
+ preserve_unit_iters, outer_iter_vars.defined());
+ // Step 3: Do var substitution to adjust to the new block bindings
+ for (size_t i = 0; i < outer_iter_vars.size(); ++i) {
+ if (outer_bindings[i].as<Var>()) {
+ loop_var_subst.Set(Downcast<Var>(outer_bindings[i]),
outer_iter_vars[i]->var);
+ }
+ }
+ Map<Var, arith::IntSet> inner_iter_dom;
+ for (const IterVar& iter : inner_iter_vars) {
+ Range dom = Substitute(iter->dom, loop_var_subst);
+ inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(dom));
+ analyzer.Bind(iter->var, dom);
+ }
+ Block block_subst =
+ Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse,
&analyzer));
+ auto reads = EvalSetRegions(block_subst->reads, inner_iter_dom);
+ auto writes = EvalSetRegions(block_subst->writes, inner_iter_dom);
+ read_regions.insert(read_regions.end(), reads.begin(), reads.end());
+ write_regions.insert(write_regions.end(), writes.begin(), writes.end());
+ outer_block_name += block_subst->name_hint + "_";
+ // Step 4: Generate the inner block. No reduction iter vars allowed for
the outer loops.
+ bool has_outer_reduction = false;
+ if (block_subst->init.defined()) {
+ for (const IterVar& iter_var : outer_iter_vars) {
+ if (iter_var->iter_type == kCommReduce) {
+ has_outer_reduction = true;
+ break;
+ }
+ }
+ }
+ ICHECK(has_outer_reduction == false)
+ << "No reduction iter vars allowed for the outer loops when blockize
multiple blocks";
+ BlockRealize inner_realize =
GenerateInner(/*is_write_reduction=*/has_outer_reduction,
+ /*iter_vars=*/inner_iter_vars,
+ /*iter_values*/ inner_bindings,
+ /*predicate=*/inner_predicate,
+ /*block=*/block_subst);
+ block_sref_reuse->Set(block, inner_realize->block);
+ Stmt stmt = inner_realize;
+ for (const ForNode* loop : loops) {
+ ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+ new_loop->body = std::move(stmt);
+ new_loop->extent = Substitute(new_loop->extent, loop_var_subst);
+ stmt = For(new_loop);
+ }
+ seq_body.push_back(stmt);
+ }
+ // Step 5: Generate the outer block.
+ return BlockRealize(
+ /*iter_values=*/std::move(outer_bindings),
+ /*predicate=*/std::move(outer_predicate),
+ /*block=*/
+ Block(/*iter_vars=*/std::move(outer_iter_vars),
+ /*reads=*/UnionRegions(read_regions),
+ /*writes=*/UnionRegions(write_regions),
+ /*name_hint=*/outer_block_name,
+ /*body=*/SeqStmt(seq_body),
+ /*init=*/Optional<Stmt>(NullOpt)));
+}
+
+class BlockizeRewriter : public StmtMutator {
+ public:
+ static Stmt Rewrite(const StmtSRef& lca, const Array<StmtSRef>& blocks,
+ const BlockRealize& blockized) {
+ BlockizeRewriter rewriter(lca, blocks, blockized);
+ return rewriter(GetRef<Stmt>(lca->stmt));
+ }
+
+ private:
+ explicit BlockizeRewriter(const StmtSRef& lca, const Array<StmtSRef>& blocks,
+ const BlockRealize& blockized)
+ : lca_(lca), blocks_(blocks), blockized_(blockized) {}
+
+ Stmt RewriteSeq(const Stmt& stmt) {
+ const SeqStmtNode* seq = stmt.as<SeqStmtNode>();
+ ICHECK(seq) << "Target blocks must not be nested with each other!";
+ int idx_start = -1;
+ int found_cnt = 0;
+ int last_found_idx = -1;
+ size_t cur_idx = 0;
+ Array<Stmt> new_seq;
+ for (const Stmt& it : seq->seq) {
+ target_in_ = false;
+ Stmt stmt = StmtMutator::VisitStmt(it);
+ if (target_in_) {
+ if (idx_start == -1) {
+ idx_start = cur_idx;
+ new_seq.push_back(blockized_);
+ } else {
+ ICHECK_EQ(last_found_idx, cur_idx - 1) << "Target blocks must be
consecutive!";
+ }
+ last_found_idx = cur_idx;
+ ++found_cnt;
+ } else {
+ new_seq.push_back(it);
+ }
+ ++cur_idx;
+ }
+ if (new_seq.size() == 1) return new_seq[0];
+ return SeqStmt(new_seq, seq->span);
+ }
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ if (loop == lca_->stmt) {
+ return For(loop->loop_var, loop->min, loop->extent, loop->kind,
RewriteSeq(loop->body),
+ loop->thread_binding, loop->annotations, loop->span);
+ }
+ return StmtMutator::VisitStmt_(loop);
+ }
+
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block == lca_->stmt) {
+ return Block(block->iter_vars, block->reads, block->writes,
block->name_hint,
+ RewriteSeq(block->body), block->init, block->alloc_buffers,
block->match_buffers,
+ block->annotations, block->span);
+ }
+ for (const StmtSRef& block_sref : blocks_) {
+ if (block_sref->stmt == block) {
+ target_in_ = true;
+ break;
+ }
+ }
+ return GetRef<Stmt>(block);
+ }
+
+ StmtSRef lca_;
+ Array<StmtSRef> blocks_;
+ BlockRealize blockized_;
+ bool target_in_ = false;
+};
+
+StmtSRef Blockize(ScheduleState self, const Array<StmtSRef>& blocks, bool
preserve_unit_iters) {
+ Map<Block, Block> block_sref_reuse;
+ auto lca = GetSRefLowestCommonAncestor(blocks);
+ BlockRealize blockized =
+ BlockizeBlocks(self, blocks, lca, &block_sref_reuse,
preserve_unit_iters);
+ auto new_root = BlockizeRewriter::Rewrite(lca, blocks, blockized);
+ self->Replace(lca, new_root, block_sref_reuse);
+ StmtSRef result = self->stmt2ref.at(blockized->block.get());
+ StmtSRef scope_root = tir::GetScopeRoot(self, result,
/*require_stage_pipeline=*/false);
+ self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
+ return result;
+}
+
void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin&
intrin,
bool preserve_unit_iters) {
// Step 1: Blockize the subtree rooted at the given loop if needed
@@ -636,13 +863,19 @@ struct BlockizeTraits : public
UnpackedInstTraits<BlockizeTraits> {
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumDecisions = 0;
- static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool
preserve_unit_iters) {
- return sch->Blockize(loop_rv, preserve_unit_iters.operator bool());
+ static BlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, Bool
preserve_unit_iters) {
+ if (auto loop = target.as<LoopRV>()) {
+ return sch->Blockize(loop.value(), preserve_unit_iters.operator bool());
+ } else if (auto blocks = target.as<Array<BlockRV>>()) {
+ return sch->Blockize(blocks.value(), preserve_unit_iters.operator
bool());
+ }
+ LOG(FATAL) << "TypeError: expect Loop or list of Blocks, but gets:" <<
target->GetTypeKey();
}
- static String UnpackedAsPython(Array<String> outputs, String loop_rv, Bool
preserve_unit_iters) {
+ static String UnpackedAsPython(Array<String> outputs, ObjectRef target,
+ Bool preserve_unit_iters) {
PythonAPICall py("blockize");
- py.Input("loop", loop_rv);
+ py.Input("target", target);
py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
py.SingleOutput(outputs);
return py.Str();
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 8663ac2b97..56d0d1efa9 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -228,7 +228,14 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
.set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
/******** (FFI) Blockize & Tensorize ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
- .set_body_method<Schedule>(&ScheduleNode::Blockize);
+ .set_body_typed([](Schedule self, ObjectRef target, bool
preserve_unit_iters) {
+ if (auto loop_rv = target.as<LoopRV>()) {
+ return self->Blockize(loop_rv.value(), preserve_unit_iters);
+ } else if (auto blocks = target.as<Array<BlockRV>>()) {
+ return self->Blockize(blocks.value(), preserve_unit_iters);
+ }
+ LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey();
+ });
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
.set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool
preserve_unit_iters) {
if (auto block_rv = rv.as<BlockRV>()) {
diff --git a/src/tir/schedule/traced_schedule.cc
b/src/tir/schedule/traced_schedule.cc
index 4d820078e5..ceeeacb335 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -558,6 +558,17 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV&
loop_rv, bool preserve_unit_i
return new_block;
}
+BlockRV TracedScheduleNode::Blockize(const Array<BlockRV>& blocks, bool
preserve_unit_iters) {
+ BlockRV new_block = ConcreteScheduleNode::Blockize(blocks,
preserve_unit_iters);
+ static const InstructionKind& kind = InstructionKind::Get("Blockize");
+ trace_->Append(/*inst=*/Instruction(
+ /*kind=*/kind,
+ /*inputs=*/{blocks},
+ /*attrs=*/{Bool(preserve_unit_iters)},
+ /*outputs=*/{new_block}));
+ return new_block;
+}
+
void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
bool preserve_unit_iters) {
ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters);
diff --git a/src/tir/schedule/traced_schedule.h
b/src/tir/schedule/traced_schedule.h
index 16ec86f227..2d47ee9aff 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -111,6 +111,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String&
dtype) final;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
+ BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters)
final;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool
preserve_unit_iters) final;
void Tensorize(const LoopRV& loop_rv, const String& intrin, bool
preserve_unit_iters) final;
/******** Schedule: Annotation ********/
diff --git
a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index 030e47ac58..111448ea57 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -493,7 +493,7 @@ def _schedule_async_dma_conv2d():
sch.parallel(new_loops[4])
sch.unroll(new_loops[5])
# TODO(nverke): Add compute optimizations here.
- sch.blockize(loop=oc_i)
+ sch.blockize(target=oc_i)
sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN)
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 4d22c4ff88..78b2fdbf3d 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -2150,7 +2150,7 @@ def test_conv2d_int8_tensorcore():
l28, l29 = sch.split(loop=l21, factors=[None, 16],
preserve_unit_iters=True)
l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b1)
sch.reorder(l34, l36, l29, l27, l25)
- b38 = sch.blockize(loop=l29)
+ b38 = sch.blockize(target=l29)
sch.annotate(
block_or_loop=b38,
ann_key="meta_schedule.auto_tensorize",
@@ -2243,7 +2243,7 @@ def test_conv2d_int8_tensorcore():
l95, l96 = sch.split(loop=l91, factors=[None, 16],
preserve_unit_iters=True)
l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b86)
sch.reorder(l102, l96, l94)
- b104 = sch.blockize(loop=l96)
+ b104 = sch.blockize(target=l96)
sch.annotate(
block_or_loop=b104,
ann_key="meta_schedule.auto_tensorize",
@@ -2308,7 +2308,7 @@ def test_conv2d_int8_tensorcore():
l157,
) = sch.get_loops(block=b129)
sch.reorder(l156, l144, l142)
- b158 = sch.blockize(loop=l144)
+ b158 = sch.blockize(target=l144)
sch.annotate(
block_or_loop=b158,
ann_key="meta_schedule.auto_tensorize",
@@ -2351,7 +2351,7 @@ def test_conv2d_int8_tensorcore():
l191,
) = sch.get_loops(block=b159)
sch.reorder(l190, l176, l174)
- b192 = sch.blockize(loop=l176)
+ b192 = sch.blockize(target=l176)
sch.annotate(
block_or_loop=b192,
ann_key="meta_schedule.auto_tensorize",
@@ -2554,7 +2554,7 @@ def test_conv2d_int8_vnni():
l34, l35 = sch.split(loop=l26, factors=[None, 16],
preserve_unit_iters=True)
l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 =
sch.get_loops(block=b1)
sch.reorder(l42, l43, l44, l45, l46, l35, l33)
- b48 = sch.blockize(loop=l35)
+ b48 = sch.blockize(target=l35)
sch.annotate(block_or_loop=b48,
ann_key="meta_schedule.auto_tensorize", ann_val=VNNI_INTRIN)
l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 =
sch.get_loops(block=b48)
v59, v60, v61, v62 = sch.sample_perfect_tile(
@@ -3119,7 +3119,7 @@ def test_inline_order():
l22, l23 = sch.split(loop=l15, factors=[None, 16],
preserve_unit_iters=True)
l24, l25, l26, l27, l28, l29, l30, l31 = sch.get_loops(block=b1)
sch.reorder(l28, l30, l23, l21, l19)
- b32 = sch.blockize(loop=l23)
+ b32 = sch.blockize(target=l23)
sch.annotate(
block_or_loop=b32,
ann_key="meta_schedule.auto_tensorize",
@@ -3212,7 +3212,7 @@ def test_inline_order():
l89, l90 = sch.split(loop=l85, factors=[None, 16],
preserve_unit_iters=True)
l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b80)
sch.reorder(l96, l90, l88)
- b98 = sch.blockize(loop=l90)
+ b98 = sch.blockize(target=l90)
sch.annotate(
block_or_loop=b98,
ann_key="meta_schedule.auto_tensorize",
@@ -3277,7 +3277,7 @@ def test_inline_order():
l151,
) = sch.get_loops(block=b123)
sch.reorder(l150, l138, l136)
- b152 = sch.blockize(loop=l138)
+ b152 = sch.blockize(target=l138)
sch.annotate(
block_or_loop=b152,
ann_key="meta_schedule.auto_tensorize",
@@ -3320,7 +3320,7 @@ def test_inline_order():
l185,
) = sch.get_loops(block=b153)
sch.reorder(l184, l170, l168)
- b186 = sch.blockize(loop=l170)
+ b186 = sch.blockize(target=l170)
sch.annotate(
block_or_loop=b186,
ann_key="meta_schedule.auto_tensorize",
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py
b/tests/python/unittest/test_tir_schedule_blockize.py
index cd4ce663e5..d151e4b438 100644
--- a/tests/python/unittest/test_tir_schedule_blockize.py
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -305,5 +305,54 @@ def test_blockize_outer_int64_shape(preserve_unit_iters):
verify_trace_roundtrip(sch=s, mod=single_elementwise_int64)
+def test_blockize_blocks():
+ @T.prim_func
+ def blocks_func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128,
128), "float32")) -> None:
+ for m in T.serial(6):
+ for i, j in T.grid(3, 1):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ for i, j in T.grid(128, 64):
+ with T.block("C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi, vj + 64])
+ T.writes(B[vi, vj + 64])
+ B[vi, vj + 64] = A[vi, vj + 64] * 3.0
+
+ @T.prim_func
+ def after_blocks_blockize(
+ A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")
+ ) -> None:
+ for m in range(6):
+ with T.block("outer_B_C_"):
+ vi_o = T.axis.spatial(1, 0)
+ vj_o = T.axis.spatial(1, 0)
+ T.reads(A[0:128, 0:128])
+ T.writes(B[0:128, 0:128])
+ for i, j in T.grid(3, 1):
+ with T.block("B"):
+ vi_i = T.axis.spatial(3, i)
+ T.reads(A[vi_i, 0])
+ T.writes(B[vi_i, 0])
+ B[vi_i, 0] = A[vi_i, 0] * T.float32(2)
+ for i, j in T.grid(128, 64):
+ with T.block("C"):
+ vi_i, vj_i = T.axis.remap("SS", [i, j])
+ T.reads(A[vi_i, vj_i + 64])
+ T.writes(B[vi_i, vj_i + 64])
+ B[vi_i, vj_i + 64] = A[vi_i, vj_i + 64] * T.float32(3)
+
+ s = tir.Schedule(blocks_func, debug_mask="all")
+ blocks = [s.get_block("B"), s.get_block("C")]
+ s.blockize(blocks, preserve_unit_iters=False)
+ expected = after_blocks_blockize
+ tvm.ir.assert_structural_equal(s.mod["main"], expected)
+ verify_trace_roundtrip(sch=s, mod=blocks_func)
+
+
if __name__ == "__main__":
tvm.testing.main()