junrushao1994 commented on code in PR #12070:
URL: https://github.com/apache/tvm/pull/12070#discussion_r919832050
##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -122,531 +136,462 @@ Array<Array<arith::IterMark>>
TrivialSubspaceDivision(const Array<IterVar>& iter
}
/*!
- * \brief Generate the blockized init block.
- * \param block The original block with init.
- * \param inner_block_realize The block realize of the inner block after
blockize.
- * \param inner_loops The inner loops after blockize.
- * \return The subtree of the init block and its outer loops.
+ * \brief Subspace division. The space is divided into two subspaces:
+ * 1. The subspace represented by the outer loops above `loop_sref` (exclusive).
+ * 2. The subspace represented by the inner loops below `loop_sref` (inclusive).
+ * \param realize The inner block
+ * \param block_sref The sref to the inner block
+ * \param loop_sref The loop that is the root of the second subspace.
+ * \param loops The loops that represent the second part of the subspace.
+ * \param analyzer The arithmetic analyzer to use.
*/
-Stmt GenerateBlockizedInit(const Block& block, const BlockRealize&
inner_block_realize,
- const std::vector<const ForNode*>& inner_loops) {
- Array<IterVar> init_block_iters;
- Array<PrimExpr> init_bindings;
- const Block& inner_block = inner_block_realize->block;
-
- // Step 1: Collect data-parallel block iters
- for (size_t i = 0; i < inner_block->iter_vars.size(); i++) {
- const IterVar& iter_var = inner_block->iter_vars[i];
- const PrimExpr& binding = inner_block_realize->iter_values[i];
- if (iter_var->iter_type == IterVarType::kDataPar &&
- UsesVar(block->init.value(),
- [tgt_var = iter_var->var.get()](const VarNode* var) { return
var == tgt_var; })) {
- init_block_iters.push_back(iter_var);
- init_bindings.push_back(binding);
+Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
+ const StmtSRef& block_sref, //
+ const StmtSRef& loop_sref, //
+ std::vector<const ForNode*>*
loops,
+ arith::Analyzer* analyzer) {
+ Array<Var> inner_vars;
+ Array<Var> outer_vars;
+ Map<Var, Range> loop_var_domain;
+ bool inner = true;
+ for (StmtSRefNode* sref = block_sref->parent; //
+ sref && sref->stmt->IsInstance<ForNode>(); //
+ sref = sref->parent) {
+ const ForNode* loop = static_cast<const ForNode*>(sref->stmt);
+ if (inner) {
+ loops->push_back(loop);
+ inner_vars.push_back(loop->loop_var);
+ } else {
+ outer_vars.push_back(loop->loop_var);
}
- }
-
- // Step 2: Collect loops related to iters of the init block
- std::vector<const ForNode*> init_loops;
- for (const ForNode* inner_loop : inner_loops) {
- for (const PrimExpr& init_binding : init_bindings) {
- if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const
VarNode* var) {
- return var == tgt_var;
- })) {
- init_loops.push_back(inner_loop);
- break;
- }
+ loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min,
loop->extent));
+ if (sref == loop_sref.get()) {
+ inner = false;
}
}
-
- // Step 3: Create new block iters for the init block
- Map<Var, PrimExpr> subst_map;
- for (size_t i = 0; i < init_block_iters.size(); i++) {
- IterVar new_iter_var = init_block_iters[i];
- Var old_var = new_iter_var->var;
- Var new_var = old_var.copy_with_suffix("_init");
- new_iter_var.CopyOnWrite()->var = new_var;
- subst_map.Set(old_var, new_var);
- init_block_iters.Set(i, std::move(new_iter_var));
- }
-
- // Step 4: Generate loop nests and the init block
- Stmt new_init = BlockRealize(
- /*iter_values=*/init_bindings,
- /*predicate=*/inner_block_realize->predicate,
- /*block=*/
- Block{/*iter_vars=*/init_block_iters,
- /*reads=*/{},
- /*writes=*/block->writes,
- /*name_hint=*/block->name_hint + "_init",
- /*body=*/block->init.value(),
- /*init=*/NullOpt});
-
- // Step 5: Generate the parent loops for the init block
- for (const ForNode* init_loop : init_loops) {
- ObjectPtr<ForNode> new_loop = make_object<ForNode>(*init_loop);
- new_loop->loop_var = init_loop->loop_var.copy_with_suffix("");
- subst_map.Set(init_loop->loop_var, new_loop->loop_var);
- new_loop->body = std::move(new_init);
- new_init = For(new_loop);
+ Array<Array<arith::IterMark>> result =
+ arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars,
realize->predicate,
+ arith::IterMapLevel::Surjective, analyzer);
+ if (!result.empty()) {
+ return result;
}
-
- // Step 6: Substitute with new loop variables and block iters to prevent
duplication of
- // variables in the outer block.
- new_init = Substitute(new_init, subst_map);
-
- return new_init;
+ return TrivialSubspaceDivision(realize->block->iter_vars,
+ realize->iter_values, //
+ realize->predicate, //
+ outer_vars, inner_vars);
}
/*!
- * \brief A helper to collect the parent loops of the block. The loops are
divided into two groups,
- * 'outer_loops', and 'inner_loops', by a specified loop as the separator.
'outer_loops' are the
- * ancestor loops of the separator loop. 'inner_loops' include the separator
loop itself, and its
- * successor loops. It is possible that 'outer_loops' is empty.
+ * \brief Derive the block bindings for both inner and outer block
+ * \param iter_vars The original block iterators to the inner block
+ * \param division The subspace division.
+ * \param outer_iter_vars The outer block iterators.
+ * \param outer_bindings The outer block bindings.
+ * \param inner_iter_vars The inner block iterators.
+ * \param inner_bindings The inner block bindings.
+ * \return A substitution plan to the iterators in the original inner block.
*/
-class LoopSubspaceCollector {
- public:
- /*!
- * \brief Collect the parent loops of the block and store the result in the
corresponding fields.
- * \param block_sref The sref to the target block.
- * \param loop_sref The sref to the separator loop. The loop itself is
counted as an inner loop.
- */
- void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
- bool inner = true;
- for (StmtSRefNode* current_sref = block_sref->parent;
- current_sref && current_sref->stmt->IsInstance<ForNode>();
- current_sref = current_sref->parent) {
- const auto* current_loop = current_sref->StmtAs<ForNode>();
- ICHECK(current_loop);
- if (inner) {
- inner_loops.push_back(current_loop);
- inner_loop_vars.push_back(current_loop->loop_var);
- } else {
- outer_loops.push_back(current_loop);
- outer_loop_vars.push_back(current_loop->loop_var);
- }
- loop_var_domain.Set(current_loop->loop_var,
- Range::FromMinExtent(current_loop->min,
current_loop->extent));
- if (current_sref == loop_sref.get()) inner = false;
+Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,
//
+ const Array<Array<arith::IterMark>>&
division, //
+ Array<IterVar>* outer_iter_vars,
//
+ Array<PrimExpr>* outer_bindings,
//
+ Array<IterVar>* inner_iter_vars,
//
+ Array<PrimExpr>* inner_bindings) {
+ using arith::IterMapExpr;
+ using arith::IterMapExprNode;
+ using arith::NormalizeIterMapToExpr;
+ Map<Var, PrimExpr> block_var_subst;
+ ICHECK_EQ(iter_vars.size() + 1, division.size());
+ for (int i = 0, n = iter_vars.size(); i < n; ++i) {
+ const IterVar& iter_var = iter_vars[i];
+ arith::IterMark outer_mark = division[i][0];
+ arith::IterMark inner_mark = division[i][1];
+ IterMapExpr outer_binding = Downcast<IterMapExpr>(outer_mark->source);
+ IterMapExpr inner_binding = Downcast<IterMapExpr>(inner_mark->source);
+ // After computing the subspace division, bindings[i] can be written as
+ // outer_binding * inner_binding->extent + inner_binding
+ // The outer block will have binding: iter_outer -> outer_binding
+ // The inner block will have binding: iter_inner -> inner_binding
+ // The iter in the original block will be substituted with base + iter_inner where
+ // base == iter_outer * iter_inner_extent
+ if (is_one(inner_mark->extent)) { // IsOuter
+ // extract this iter var to outer block directly
+ outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+ outer_iter_vars->push_back(iter_var);
+ continue;
}
+ // create iter var for the outer block
+ IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_o"),
+ /*iter_type=*/iter_var->iter_type);
+ outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
+ outer_iter_vars->push_back(outer_iter);
+ // create iter var for the inner block
+ IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
+ /*var=*/iter_var->var.copy_with_suffix("_i"),
+ /*iter_type=*/iter_var->iter_type);
+ inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
+ inner_iter_vars->push_back(inner_iter);
+ // substitution
+ PrimExpr sub{nullptr};
+ if (is_one(outer_mark->extent)) {
+ sub = inner_iter->var;
+ } else {
+ sub = outer_iter * inner_mark->extent + inner_iter->var;
+ }
+ block_var_subst.Set(iter_var->var, sub);
}
- /*! \brief Outer loops which are ancestors of the separator. */
- std::vector<const ForNode*> outer_loops;
- /*! \brief Inner loops which are the separator itself or its successors. */
- std::vector<const ForNode*> inner_loops;
- /*! \brief Loop variables of the outer loops. */
- Array<Var> outer_loop_vars;
- /*! \brief Loop variables of the inner loops. */
- Array<Var> inner_loop_vars;
- /*! \brief Domain of the loop variables. */
- Map<Var, Range> loop_var_domain;
-};
+ return block_var_subst;
+}
/*!
- * \brief Check the bindings of the block iters can be divided by a subspace
collected by the
- * collector.
- * \param mod The current IR module.
- * \param block_realize The block realize to be checked.
- * \param collector The collector which has collected the loops of the block.
- * \param analyzer The arithmetic analyzer.
- * \return The result of the subspace division.
- * \throws ScheduleError If the bindings are not divisible by the subspace.
+ * \brief Generate the inner block for blockization
+ * \param is_write_reduction Whether the write regions of the inner block are actually reduction.
+ * \param iter_vars IterVars used in the inner block.
+ * \param iter_values IterVar bindings used in the inner block.
+ * \param predicate The predicate of the inner block.
+ * \param block The inner block as a template to be created from. This method will modify its
+ * `iter_vars`, `init` and `reads` fields.
+ * \return The inner block created.
*/
-Array<Array<arith::IterMark>> CheckSubspaceDivisible(const IRModule& mod,
- const BlockRealize&
block_realize,
- const
LoopSubspaceCollector& collector,
- arith::Analyzer*
analyzer) {
- const Block& block = block_realize->block;
-
- Array<Array<arith::IterMark>> division = arith::SubspaceDivide(
- block_realize->iter_values, collector.loop_var_domain,
collector.inner_loop_vars,
- block_realize->predicate, arith::IterMapLevel::Surjective, analyzer);
-
- if (division.empty()) {
- // If we can't do perfect subspace division, check if it is a trivial case
of subspace division.
- // In this case, we can still blockize.
- division = TrivialSubspaceDivision(block->iter_vars,
block_realize->iter_values,
- collector.outer_loop_vars,
collector.inner_loop_vars,
- block_realize->predicate);
- }
- if (division.empty()) {
- throw SubspaceNotDivisibleError(mod,
GetRef<For>(collector.inner_loops.back()), block);
+BlockRealize GenerateInner(bool is_write_reduction,
+ const Array<IterVar>& iter_vars, //
+ const Array<PrimExpr>& iter_values, //
+ const PrimExpr& predicate, //
+ Block block) {
+ BlockNode* n = block.CopyOnWrite();
+ n->iter_vars = iter_vars;
+ n->init = NullOpt;
+ if (is_write_reduction) {
+ Array<BufferRegion> reads;
+ reads.reserve(block->writes.size() + block->reads.size());
+ reads.insert(reads.end(), block->writes.begin(), block->writes.end());
+ reads.insert(reads.end(), block->reads.begin(), block->reads.end());
+ n->reads = std::move(reads);
}
- return division;
+ return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate,
+ /*block=*/block);
}
/*!
- * \brief The binding extractor to compute the bindings of the outer and the
inner blocks after
- * blockize.
+ * \brief Generate the init stmt for the outer block
+ * \param block The original block with init.
+ * \param inner_realize The block realize of the inner block after blockize.
+ * \param loops The inner loops after blockize.
+ * \return The subtree of the init block and its outer loops.
*/
-class BlockizedBindingExtractor {
- public:
- /*!
- * \brief Extract bindings for blockize.
- * \param iter_vars The iter vars of the original inner block.
- * \param division The result of the subspace division.
- */
- void ExtractBindings(const Array<IterVar>& iter_vars,
- const Array<Array<arith::IterMark>>& division,
arith::Analyzer* analyzer) {
- ICHECK_EQ(iter_vars.size() + 1, division.size());
- for (size_t i = 0; i < iter_vars.size(); ++i) {
- const IterVar& iter_var = iter_vars[i];
- arith::IterMark outer_mark = division[i][0];
- arith::IterMark inner_mark = division[i][1];
- const auto* outer_binding =
- TVM_TYPE_AS(outer_binding, outer_mark->source,
arith::IterMapExprNode);
- const auto* inner_binding =
- TVM_TYPE_AS(inner_binding, inner_mark->source,
arith::IterMapExprNode);
-
- // After computing the subspace division, bindings[i] can be written as
- // outer_binding * inner_binding->extent + inner_binding
- // The outer block will have binding: iter_outer -> outer_binding
- // The inner block will have binding: iter_inner -> inner_binding
- // The iter in the original block will be substituted with base +
iter_inner where
- // base == iter_outer * iter_inner_extent
-
- if (is_one(division[i][1]->extent)) { // IsOuter
- // extract this iter var to outer block directly
- outer_bindings.push_back(
-
arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
- outer_iter_vars.push_back(iter_var);
- } else {
- // create iter var for the outer block
- const IterVar outer_var(/*dom=*/Range::FromMinExtent(0,
division[i][0]->extent),
- /*var=*/iter_var->var.copy_with_suffix("_o"),
- /*iter_type=*/iter_var->iter_type);
- outer_bindings.push_back(
-
arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(outer_binding)));
- outer_iter_vars.push_back(outer_var);
- PrimExpr base = is_one(division[i][0]->extent) ? 0 : outer_var *
division[i][1]->extent;
- // create iter var for the inner block
- IterVar new_iter(Range::FromMinExtent(0, division[i][1]->extent),
Var(iter_var->var),
- iter_var->iter_type, iter_var->thread_tag,
iter_var->span);
- inner_iter_dom_map.Set(new_iter->var,
arith::IntSet::FromRange(new_iter->dom));
- analyzer->Bind(new_iter->var, new_iter->dom);
- inner_iter_vars.push_back(new_iter);
- inner_bindings.push_back(
-
arith::NormalizeIterMapToExpr(GetRef<arith::IterMapExpr>(inner_binding)));
- inner_iter_subst_map.Set(iter_var->var, base + new_iter->var);
+Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize&
inner_realize,
+ const std::vector<const ForNode*>& loops, String
block_name) {
+ const Block& inner_block = inner_realize->block;
+ Map<Var, PrimExpr> subst_map;
+ // Step 1: Create new block vars for the block inside the init stmt of outer block
+ // An iter is used in the block if
+ // 1) It is data parallel
+ // 2) It is used in the original init block
+ Array<IterVar> iter_vars;
+ Array<PrimExpr> iter_values;
+ ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size());
+ int n = inner_block->iter_vars.size();
+ iter_vars.reserve(n);
+ iter_values.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ const IterVar& old_iter_var = inner_block->iter_vars[i];
+ const PrimExpr& iter_value = inner_realize->iter_values[i];
+ if (old_iter_var->iter_type == IterVarType::kDataPar &&
+ UsesVar(block_init, old_iter_var->var)) {
+ ObjectPtr<IterVarNode> new_iter_var =
make_object<IterVarNode>(*old_iter_var.get());
+ new_iter_var->var = new_iter_var->var.copy_with_suffix("_init");
+ subst_map.Set(old_iter_var->var, new_iter_var->var);
+ iter_vars.push_back(IterVar(new_iter_var));
+ iter_values.push_back(iter_value);
+ }
+ }
+ // Step 2: Generate the block inside init stmt of outer block
+ Stmt stmt = BlockRealize(
+ /*iter_values=*/iter_values,
+ /*predicate=*/inner_realize->predicate,
+ /*block=*/
+ Block(/*iter_vars=*/iter_vars,
+ /*reads=*/{},
+ /*writes=*/inner_block->writes,
+ /*name_hint=*/block_name,
+ /*body=*/block_init,
+ /*init=*/NullOpt));
+ // Step 3. Create the loop nest on top of the block
+ for (const ForNode* loop : loops) {
+ bool is_init_loop = false;
+ for (const PrimExpr& init_binding : iter_values) {
+ if (UsesVar(init_binding, loop->loop_var)) {
+ is_init_loop = true;
+ break;
}
}
+ if (is_init_loop) {
+ ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+ new_loop->loop_var = loop->loop_var.copy_with_suffix("");
+ new_loop->body = std::move(stmt);
+ subst_map.Set(loop->loop_var, new_loop->loop_var);
+ stmt = For(new_loop);
+ }
}
- Map<Var, PrimExpr> inner_iter_subst_map;
- /*! \brief Iters of the outer block. */
- Array<IterVar> outer_iter_vars;
- /*! \brief Iters of the outer block. */
- Array<IterVar> inner_iter_vars;
- /*! \brief Binding values of the outer block. */
- Array<PrimExpr> outer_bindings;
- /*! \brief Binding values of the inner block. */
- Array<PrimExpr> inner_bindings;
- /*! \brief The domain of the inner block iters. */
- Map<Var, arith::IntSet> inner_iter_dom_map;
-};
+ // Step 4: Substitute the iter vars and loop vars
+ return Substitute(stmt, subst_map);
+}
/*!
- * \brief Replacer for the inner block after blockize. Inner block iters will
be replaced with
- * base + inner_iter and the expressions after substituion will be simplified
if possible.
+ * \brief Substitute variables in the stmt, do simplification and track block substitution
+ * \param stmt The stmt to be substituted.
+ * \param sub The substitution map.
+ * \param block_sref_reuse The block substitutions that happen during the substitution.
+ * \param analyzer The analyzer for arithmetic simplification.
+ * \return The substituted stmt.
*/
-class InnerIterReplacer : public StmtExprMutator {
- public:
- /*!
- * \brief The constructor
- * \param subst_map The substitution map of the inner block iters.
- * \param analyzer The arithmetic analyzer.
- * \param block_sref_reuse The map to save the block reuse information.
- */
- InnerIterReplacer(Map<Var, PrimExpr> subst_map, arith::Analyzer* analyzer,
- Map<Block, Block>* block_sref_reuse)
- : subst_map_(std::move(subst_map)),
- analyzer_(analyzer),
- block_sref_reuse_(block_sref_reuse) {}
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- auto it = subst_map_.find(GetRef<Var>(op));
- if (it != subst_map_.end()) {
- return (*it).second;
+Stmt Substitute(const Stmt& stmt, const Map<Var, PrimExpr>& sub,
+ Map<Block, Block>* block_sref_reuse, arith::Analyzer*
analyzer) {
+ struct Replacer : public StmtExprMutator {
+ explicit Replacer(const Map<Var, PrimExpr>& sub, Map<Block, Block>*
block_sref_reuse,
+ arith::Analyzer* analyzer)
+ : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer)
{}
+
+ PrimExpr VisitExpr(const PrimExpr& op) final {
+ PrimExpr result = StmtExprMutator::VisitExpr(op);
+ if (!result.same_as(op)) {
+ return analyzer_->Simplify(result);
+ }
+ return result;
}
- return StmtExprMutator::VisitExpr_(op);
- }
- PrimExpr VisitExpr(const PrimExpr& op) final {
- PrimExpr result = StmtExprMutator::VisitExpr(op);
- if (!result.same_as(op)) {
- return analyzer_->Simplify(result);
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ if (Optional<PrimExpr> e = sub_.Get(GetRef<Var>(op))) {
+ return e.value();
+ }
+ return StmtExprMutator::VisitExpr_(op);
}
- return result;
- }
- Stmt VisitStmt_(const BlockNode* op) final {
- Stmt result = StmtExprMutator::VisitStmt_(op);
- if (!result.same_as(GetRef<Stmt>(op))) {
- block_sref_reuse_->Set(GetRef<Block>(op), Downcast<Block>(result));
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block src = GetRef<Block>(op);
+ Block tgt = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+ if (!src.same_as(tgt)) {
+ block_sref_reuse_->Set(src, tgt);
+ }
+ return tgt;
}
- return result;
- }
- private:
- Map<Var, PrimExpr> subst_map_;
- arith::Analyzer* analyzer_;
- Map<Block, Block>* block_sref_reuse_;
-};
+ const Map<Var, PrimExpr>& sub_;
+ Map<Block, Block>* block_sref_reuse_;
+ arith::Analyzer* analyzer_;
+ };
+ return Replacer(sub, block_sref_reuse, analyzer)(stmt);
+}
/*!
- * \brief Compute the access region of the outer block by relaxing the inner
loops.
- * \param buffer_region The original buffer region.
- * \param The range of the inner loops.
- * \return The new buffer region.
+ * \brief Relax the variables for the given regions
+ * \param regions The regions to be relaxed.
+ * \param dom_map The variables to be relaxed
+ * \return The relaxed regions
*/
-BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region,
- const Map<Var, arith::IntSet>&
inner_iter_relaxed_range) {
- Array<Range> new_region;
- new_region.reserve(buffer_region->region.size());
- Array<arith::IntSet> relaxed_int_set =
- arith::EvalSet(buffer_region->region, inner_iter_relaxed_range);
- ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size());
- for (size_t i = 0; i < buffer_region->region.size(); i++) {
- Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]);
- new_region.push_back(relaxed_int_set[i].CoverRange(max_range));
+Array<BufferRegion> EvalSetRegions(const Array<BufferRegion>& regions,
+ const Map<Var, arith::IntSet>& dom_map) {
+ Array<BufferRegion> results;
+ results.reserve(regions.size());
+ for (const BufferRegion& buffer_region : regions) {
+ const Buffer& buffer = buffer_region->buffer;
+ Array<arith::IntSet> relaxed = arith::EvalSet(buffer_region->region,
dom_map);
+ ICHECK_EQ(relaxed.size(), buffer->shape.size());
+ int ndim = buffer->shape.size();
+ Array<Range> new_region;
+ new_region.reserve(ndim);
+ for (int i = 0; i < ndim; ++i) {
+
new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i])));
+ }
+ results.push_back(BufferRegion(buffer, new_region));
}
- return BufferRegion(buffer_region->buffer, std::move(new_region));
+ return results;
}
/*!
- * \brief Generate the outer block after blockize.
- * \param extractor The binding extractor which has extracted the blockized
bindings.
- * \param block The original inner block.
- * \param inner_block_realize The block realize of the inner block after
blockize.
- * \param inner_loops The inner loops after blockize.
- * \param predicate The outer predicate of the subspace division.
- * \return The block realize of the outer block after blockize.
+ * \brief Create the loop nest on top of the given stmt.
+ * \param stmt The stmt to be wrapped.
+ * \param loops The loop nests
+ * \return The wrapped stmt.
*/
-BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor&
extractor,
- const Block& block, BlockRealize
inner_block_realize,
- const std::vector<const ForNode*>&
inner_loops,
- PrimExpr predicate) {
- // Step 1: Generate the init block if needed
- Optional<Stmt> new_init = NullOpt;
- if (block->init.defined()) {
- new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops);
- }
-
- // Step 2: Compute the access regions of the outer block by relaxing the
inner loops
- Array<BufferRegion> new_reads = block->reads;
- Array<BufferRegion> new_writes = block->writes;
-
- auto f_mutate = [&](const BufferRegion& buffer_region) {
- return RelaxBlockizedInnerIters(buffer_region,
extractor.inner_iter_dom_map);
- };
- new_reads.MutateByApply(f_mutate);
- new_writes.MutateByApply(f_mutate);
-
- // Step 3: Generate the body of the outer block. The body of the outer block
is the inner block
- // realize and its surrounding loops.
- Stmt outer_block_body = inner_block_realize;
- for (const ForNode* loop : inner_loops) {
+Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) {
+ for (const ForNode* loop : loops) {
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
- new_loop->body = std::move(outer_block_body);
- outer_block_body = For(new_loop);
+ new_loop->body = std::move(stmt);
+ stmt = For(new_loop);
}
-
- // Step 4: Generate the outer block and block realize.
- return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings),
- /*predicate=*/std::move(predicate),
- /*block=*/
-
Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), //
- /*reads=*/std::move(new_reads),
//
- /*writes=*/std::move(new_writes),
//
- /*name_hint=*/block->name_hint + "_o",
//
- /*body=*/std::move(outer_block_body),
//
- /*init=*/std::move(new_init)));
+ return stmt;
}
-StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
+ Map<Block, Block>* block_sref_reuse,
arith::Analyzer* analyzer) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
- arith::Analyzer analyzer;
-
- // Step 1: Check the loop has a single child BlockRealize on the sref tree.
+ // Step 1: Check and get the only block under `loop`.
BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self,
loop_sref);
Block block = block_realize->block;
StmtSRef block_sref = self->stmt2ref.at(block.get());
-
- // Step 2: Collect loops inside and outside loop_sref.
- LoopSubspaceCollector collector;
- collector.Collect(block_sref, loop_sref);
-
- // Step 3: Calculate subspace division for the inner loops.
+ // Step 2: Derive subspace division
+ std::vector<const ForNode*> loops;
Array<Array<arith::IterMark>> division =
- CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer);
-
- // Step 4: Generate bindings for the outer block and the inner block based
on the result of
- // the subspace division.
- BlockizedBindingExtractor extractor;
- extractor.ExtractBindings(block->iter_vars, division, &analyzer);
- const PrimExpr& outer_pred = division.back()[0]->extent;
- const PrimExpr& inner_pred = division.back()[1]->extent;
-
- // Step 5: Substitute the iter vars in the original block with the inner
iters after the subspace
- // division
- Map<Block, Block> block_sref_reuse;
- InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map),
&analyzer,
- &block_sref_reuse);
- Block new_block = Downcast<Block>(replacer(block));
-
- // Step 6: Generate the inner block.
- bool outer_reduction = false; // whether there are outer reduction iter
vars.
- for (const IterVar& iter_var : extractor.outer_iter_vars) {
- if (iter_var->iter_type == kCommReduce) {
- outer_reduction = true;
- }
+ SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer);
+ if (division.empty()) {
+ throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()),
block);
}
- BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite();
- inner_block_realize->iter_values = extractor.inner_bindings;
- inner_block_realize->predicate = inner_pred;
- inner_block_realize->block = new_block;
- BlockNode* inner_block = inner_block_realize->block.CopyOnWrite();
- inner_block->iter_vars = extractor.inner_iter_vars;
- inner_block->init = NullOpt;
- /* Add write regions to read regions if
- * 1. there are outer reduction iter vars.
- * 2. the init block is defined for current block.
- */
- if (outer_reduction && block->init.defined()) {
- Array<BufferRegion> new_reads;
- for (const BufferRegion& write_access : inner_block->writes) {
- new_reads.push_back(write_access);
- }
- for (const BufferRegion& read_access : inner_block->reads) {
- new_reads.push_back(read_access);
+ PrimExpr outer_predicate = division.back()[0]->extent;
+ PrimExpr inner_predicate = division.back()[1]->extent;
+ // Step 3. Derive block bindings for both outer and inner block.
+ Array<IterVar> outer_iter_vars;
+ Array<IterVar> inner_iter_vars;
+ Array<PrimExpr> outer_bindings;
+ Array<PrimExpr> inner_bindings;
+ Map<Var, PrimExpr> block_var_subst = //
+ DeriveBlockBinding(block->iter_vars, division, //
+ &outer_iter_vars, &outer_bindings, //
+ &inner_iter_vars, &inner_bindings);
+ // Step 4: Do var substitution to adjust to the new block bindings
+ Map<Var, arith::IntSet> inner_iter_dom;
+ for (const IterVar& iter : inner_iter_vars) {
+ inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom));
+ analyzer->Bind(iter->var, iter->dom);
+ }
+ Block block_subst =
+ Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse,
analyzer));
+ // Step 5: Generate the inner block. The write regions of the inner blocks will be reduction if
+ // 1. The original block has init stmt.
+ // 2. There are outer reduction iter vars.
+ bool has_outer_reduction = false;
+ if (block_subst->init.defined()) {
+ for (const IterVar& iter_var : outer_iter_vars) {
+ if (iter_var->iter_type == kCommReduce) {
+ has_outer_reduction = true;
+ break;
+ }
}
- inner_block->reads = std::move(new_reads);
}
- block_sref_reuse.Set(block, inner_block_realize->block);
-
+ BlockRealize inner_realize =
GenerateInner(/*is_write_reduction=*/has_outer_reduction,
+ /*iter_vars=*/inner_iter_vars,
+ /*iter_values*/ inner_bindings,
+ /*predicate=*/inner_predicate,
+ /*block=*/block_subst);
+ block_sref_reuse->Set(block, inner_realize->block);
// Step 6: Generate the outer block.
- BlockRealize outer_realize =
- GenerateBlockizedOuterBlock(extractor, new_block,
GetRef<BlockRealize>(inner_block_realize),
- collector.inner_loops, outer_pred);
- // Step 7: Do the actual replacement
- self->Replace(loop_sref, outer_realize, block_sref_reuse);
-
- // Step 8: Update the cached flags
- StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get());
- StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref,
/*require_stage_pipeline=*/false);
+ return BlockRealize(
+ /*iter_values=*/std::move(outer_bindings),
+ /*predicate=*/std::move(outer_predicate),
+ /*block=*/
+ Block(/*iter_vars=*/std::move(outer_iter_vars),
+ /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom),
+ /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom),
+ /*name_hint=*/block_subst->name_hint + "_o",
+ /*body=*/MakeLoopNest(inner_realize, loops),
+ /*init=*/
+ block_subst->init.defined() //
+ ? GenerateOuterInit(block_subst->init.value(), inner_realize,
loops,
+ block_subst->name_hint + "_init")
+ : Optional<Stmt>(NullOpt)));
+}
+
+StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) {
+ arith::Analyzer analyzer;
+ Map<Block, Block> block_sref_reuse;
+ BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse,
&analyzer);
+ self->Replace(loop_sref, blockized, block_sref_reuse);
+ StmtSRef result = self->stmt2ref.at(blockized->block.get());
+ StmtSRef scope_root = tir::GetScopeRoot(self, result,
/*require_stage_pipeline=*/false);
bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
self->block_info[scope_root].affine_binding = scope_block_affine_binding;
- return outer_block_sref;
-}
-
-/*!
- * \brief Update the map from the buffers in the desc to the impl of the tensor
- * intrinsic.
- * \param intrinsic The tensor intrinsic.
- * \param buffer_map The map to be updated.
- */
-void RemapTensorIntrinBuffers(
- const TensorIntrin& intrinsic,
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>*
buffer_map) {
- ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size());
- for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) {
- const Var& lhs_var = intrinsic->desc->params[i];
- const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var];
- const Var& rhs_var = intrinsic->impl->params[i];
- const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var];
- (*buffer_map)[rhs_buffer] = lhs_buffer;
- }
+ return result;
}
-void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref,
- const TensorIntrin& intrinsic) {
- /*!
- * Check:
- * - Check buffer binding, including type, alignment, shape and etc.
- * - Check the sub AST is equal to the desc function.
- *
- * Mutate:
- * - Blockize the sub AST (please refer blockize for details)
- * - Bind buffers
- * - Mutate the impl of the tensor intrinsic by replacing its buffers with
new
- * buffers created via match buffer region.
- * - Replace the sub tree with the mutated function.
- */
- const BlockRealize& desc_block_realize =
Downcast<BlockRealize>(intrinsic->desc->body);
- const BlockRealize& impl_block_realize =
Downcast<BlockRealize>(intrinsic->impl->body);
- Block impl_block = impl_block_realize->block;
-
+void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin&
intrin) {
// Step 1: Blockize the subtree rooted at the given loop if needed
- StmtSRef block_sref{nullptr};
- if (block_or_loop_sref->StmtAs<ForNode>()) {
- block_sref = Blockize(self, block_or_loop_sref);
+ BlockRealize block_realize{nullptr};
+ Optional<Block> old_block = NullOpt;
+ if (sref->stmt->IsInstance<BlockNode>()) {
+ block_realize = GetBlockRealize(self, sref);
+ old_block = block_realize->block;
+ } else if (sref->stmt->IsInstance<ForNode>()) {
+ arith::Analyzer analyzer;
+ Map<Block, Block> block_sref_reuse;
+ block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer);
} else {
- ICHECK(block_or_loop_sref->StmtAs<BlockNode>());
- block_sref = block_or_loop_sref;
+ LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: "
+ << GetRef<Stmt>(sref->stmt);
+ throw;
}
- const BlockRealize& block_realize = GetBlockRealize(self, block_sref);
-
- // Step 2: Compare the block with the desc of the tensor intrinsic, find the
correspondence
- // between buffers in the block and the desc.
+ PrimFunc intrin_desc = intrin->desc;
+ PrimFunc intrin_impl = DeepCopy(intrin->impl);
+ // Step 2: Structural pattern matching
TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
- comparator.VisitStmt(block_realize, desc_block_realize);
-
- // Step 3: Find the correspondence between buffers in the current AST and
the impl of
- // the tensor intrinsic
- // Step 3.1: Map from intrinsic func buffer to desc func buffer
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
intrin_buffer_map;
- RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map);
- // Step 3.2: Map form intrinsic func buffer to current AST buffer
- std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map;
- for (const auto& pair : intrin_buffer_map) {
- auto it = comparator.rhs_buffer_map_.find(pair.second);
- ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second;
- buffer_map[pair.first] = it->second;
+ comparator.VisitStmt(block_realize, intrin_desc->body);
+ // Step 3: Prepare necessary mapping
+ // 1) Buffer mapping from intrin impl buffers to intrin desc buffers.
+ // 2) Buffer mapping from intrin impl buffers to AST buffers.
+ // 3) Mapping impl buffers to their accessed regions.
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2desc;
+ ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size());
+ for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) {
+ const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]];
+ const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]];
+ impl2desc[impl] = desc;
}
-
- // Step 4: Create MatchBufferRegion for the params of the impl function of
the tensor
- // intrin to make them subregions of the buffer in the original IR.
- std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual>
buffer_region_map;
+ std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2ast;
+ for (const auto& pair : impl2desc) {
+ const Buffer& impl = pair.first;
+ const Buffer& desc = pair.second;
+ ICHECK(comparator.rhs_buffer_map_.count(desc));
+ impl2ast[impl] = comparator.rhs_buffer_map_[desc];
+ }
+ std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual>
impl2region;
+ Block impl_block = Downcast<BlockRealize>(intrin_impl->body)->block;
for (const BufferRegion& read : impl_block->reads) {
- buffer_region_map.emplace(read->buffer, read->region);
+ impl2region.emplace(read->buffer, read->region);
}
for (const BufferRegion& write : impl_block->writes) {
- buffer_region_map.emplace(write->buffer, write->region);
+ impl2region.emplace(write->buffer, write->region);
}
+ // Step 4: Create MatchBufferRegion for the params of the impl function of
the tensor
+ // intrin to make them subregions of the buffer in the original IR.
Array<MatchBufferRegion> match_buffer_regions;
- match_buffer_regions.reserve(intrinsic->impl->params.size());
- for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) {
- const auto& param = intrinsic->impl->params[i];
- const auto& buffer = intrinsic->impl->buffer_map.at(param);
- const auto& source = buffer_map.at(buffer);
- // add the detected base indices to each buffer access region of the
tensor intrinsic
- Region old_region = buffer_region_map.at(buffer);
- const auto& indices_base = comparator.buffer_indices_.at(source);
+ match_buffer_regions.reserve(intrin_impl->params.size());
+ for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) {
+ const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]);
+ const Buffer& ast = impl2ast.at(impl);
Review Comment:
sounds good! will go with `cur`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org