Biubiubiu12 commented on code in PR #15274:
URL: https://github.com/apache/tvm/pull/15274#discussion_r1271819281
##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef&
loop_sref, bool preserve_u
return result;
}
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>&
block_srefs,
- const StmtSRef& lca, Map<Block, Block>*
block_sref_reuse,
- bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+ static void Collect(const ScheduleState& self, const StmtSRef& lca, const
StmtSRef& block_sref,
+ Array<IterVar>* outer_iter_vars, Array<PrimExpr>*
outer_bindings,
+ Map<Var, PrimExpr>* block_var_subst) {
+ CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings,
block_var_subst);
+ StmtSRef scope_root = tir::GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
+ const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+ Block block = GetRef<Block>(root_block);
+ return collector(block);
+ }
+
+ private:
+ explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+ Array<IterVar>* outer_iter_vars, Array<PrimExpr>*
outer_bindings,
+ Map<Var, PrimExpr>* block_var_subst)
+ : self_(self),
+ lca_(lca),
+ outer_iter_vars_(outer_iter_vars),
+ outer_bindings_(outer_bindings),
+ block_var_subst_(block_var_subst) {}
+
+ void VisitStmt_(const ForNode* loop) final {
+ if (!in_lca) {
+ if (loop == lca_->stmt) {
+ in_lca = true;
+ }
+ outer_bindings_->push_back(loop->loop_var);
+ out_extent.push_back(loop->extent);
+ // traverse lca
+ ++num_travered;
+ StmtVisitor::VisitStmt(loop->body);
+ --num_travered;
+ if (!in_lca) {
+ outer_bindings_->pop_back();
+ out_extent.pop_back();
+ }
+ if (num_travered == 0) {
+ in_lca = false;
+ }
+ } else {
+ StmtVisitor::VisitStmt_(loop);
+ }
+ }
+
+ void VisitStmt_(const BlockNode* block) final {
+ if (block == lca_->stmt && block->name_hint == String("root")) {
+ // nothings need to substitute, so all output is nullptr.
+ return;
+ }
+ if (in_lca) {
+ if (block->iter_vars.size() > 0) {
+ // collect info
+ for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+ const IterVar& iter_var = block->iter_vars[i];
+ if (static_cast<unsigned int>(i) < out_extent.size()) {
+ arith::Analyzer ana;
+ // According to outer_bindings info, check outer iter_vars
+ ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+ auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+ ObjectPtr<VarNode> new_ptr =
make_object<VarNode>(*iter_var->var.get());
+ new_ptr->name_hint = "v" + outer_bind->name_hint;
+ auto outer_iter =
IterVar(/*dom=*/RangeFromExtent(iter_var->dom->extent),
+ /*var=*/Var(new_ptr),
+ /*iter_type=*/iter_var->iter_type);
+ if (num_outer_iter_vars == 0) {
+ outer_iter_vars_->push_back(outer_iter);
+ block_var_subst_->Set(iter_var->var, outer_iter->var);
+ }
+ }
+ }
+ ++num_outer_iter_vars;
Review Comment:
This code collects the `iter_vars` of the generated outer block (see L432 in
the test for an example). Because this information only needs to be collected
once, the `num_outer_iter_vars == 0` check guards the collection. The maximum
value of `num_outer_iter_vars` is the size of the input `Array<StmtSRef>& block_srefs`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]