wrongtest-intellif commented on code in PR #15274:
URL: https://github.com/apache/tvm/pull/15274#discussion_r1271660324


##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -441,25 +441,29 @@ Array<BufferRegion> EvalSetRegions(const 
Array<BufferRegion>& regions,
  * \return The union regions
  */
 Array<BufferRegion> UnionRegions(const Array<BufferRegion>& regions) {
+  arith::Analyzer analyzer;
   typedef std::vector<Array<arith::IntSet>> ranges_t;
   std::unordered_map<Buffer, ranges_t, ObjectPtrHash, ObjectPtrEqual> 
intset_map;
+  Array<Buffer> buffer_order;
   for (const BufferRegion& buffer_region : regions) {
     const Buffer& buffer = buffer_region->buffer;
     if (intset_map.find(buffer) == intset_map.end()) {
       intset_map[buffer] = {buffer->shape.size(), Array<arith::IntSet>()};
+      buffer_order.push_back(buffer);
     }
     std::vector<Array<arith::IntSet>> dim_range(buffer->shape.size(), 
Array<arith::IntSet>());
     for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
       
intset_map[buffer][dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim]));
     }
   }
   Array<BufferRegion> results;
-  for (const auto& it : intset_map) {
-    const Buffer& buffer = it.first;
+  for (size_t i = 0; i < buffer_order.size(); ++i) {
+    auto it = intset_map.find(buffer_order[i]);
+    const Buffer& buffer = it->first;
     Array<Range> regions;
     for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
-      const arith::IntSet intset = arith::Union(it.second[dim]);
-      regions.push_back({intset.min(), intset.max() + 1});
+      const arith::IntSet intset = arith::Union(it->second[dim]);
+      regions.push_back({analyzer.Simplify(intset.min()), 
analyzer.Simplify(intset.max() + 1)});

Review Comment:
   Why do we need to simplify here? Could you provide an example?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -441,25 +441,29 @@ Array<BufferRegion> EvalSetRegions(const 
Array<BufferRegion>& regions,
  * \return The union regions
  */
 Array<BufferRegion> UnionRegions(const Array<BufferRegion>& regions) {
+  arith::Analyzer analyzer;
   typedef std::vector<Array<arith::IntSet>> ranges_t;
   std::unordered_map<Buffer, ranges_t, ObjectPtrHash, ObjectPtrEqual> 
intset_map;
+  Array<Buffer> buffer_order;
   for (const BufferRegion& buffer_region : regions) {
     const Buffer& buffer = buffer_region->buffer;
     if (intset_map.find(buffer) == intset_map.end()) {
       intset_map[buffer] = {buffer->shape.size(), Array<arith::IntSet>()};
+      buffer_order.push_back(buffer);
     }
     std::vector<Array<arith::IntSet>> dim_range(buffer->shape.size(), 
Array<arith::IntSet>());
     for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
       
intset_map[buffer][dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim]));
     }
   }
   Array<BufferRegion> results;
-  for (const auto& it : intset_map) {
-    const Buffer& buffer = it.first;
+  for (size_t i = 0; i < buffer_order.size(); ++i) {
+    auto it = intset_map.find(buffer_order[i]);

Review Comment:
   It seems that we do not need the `buffer_order` array; isn't it just 
`regions[i]->buffer`?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {

Review Comment:
   Is this a duplicate of the condition at line 595?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;

Review Comment:
   Could we use the same condition as when we set `in_lca = true`, i.e. `if (loop == 
lca_->stmt)`?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {
+            arith::Analyzer ana;
+            // According to outer_bindings info, check outer iter_vars
+            ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+            auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+            ObjectPtr<VarNode> new_ptr = 
make_object<VarNode>(*iter_var->var.get());
+            new_ptr->name_hint = "v" + outer_bind->name_hint;
+            auto outer_iter = 
IterVar(/*dom=*/RangeFromExtent(iter_var->dom->extent),
+                                      /*var=*/Var(new_ptr),
+                                      /*iter_type=*/iter_var->iter_type);
+            if (num_outer_iter_vars == 0) {
+              outer_iter_vars_->push_back(outer_iter);
+              block_var_subst_->Set(iter_var->var, outer_iter->var);
+            }
+          }
+        }
+        ++num_outer_iter_vars;
+        return;
+      }
+    }
+    StmtVisitor::VisitStmt_(block);
+  }
+
+  ScheduleState self_;
+  StmtSRef lca_;
+  Array<IterVar>* outer_iter_vars_;
+  Array<PrimExpr>* outer_bindings_;
+  Array<IterVar>* inner_iter_vars_;
+  Array<PrimExpr>* inner_bindings_;
+  Map<Var, PrimExpr>* block_var_subst_;
+  Array<PrimExpr> out_extent{nullptr};

Review Comment:
   Please rename `out_extent` to `outer_extent`.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {

Review Comment:
   Please use `size_t i` here.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {

Review Comment:
   Please add these comments to the class's code.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {

Review Comment:
   Here we assume that `iter_vars[i]` always refers to `loops[i]`; we should either prove that this is 
the case or check it explicitly.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {
+            arith::Analyzer ana;
+            // According to outer_bindings info, check outer iter_vars
+            ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+            auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+            ObjectPtr<VarNode> new_ptr = 
make_object<VarNode>(*iter_var->var.get());
+            new_ptr->name_hint = "v" + outer_bind->name_hint;
+            auto outer_iter = 
IterVar(/*dom=*/RangeFromExtent(iter_var->dom->extent),
+                                      /*var=*/Var(new_ptr),
+                                      /*iter_type=*/iter_var->iter_type);
+            if (num_outer_iter_vars == 0) {
+              outer_iter_vars_->push_back(outer_iter);
+              block_var_subst_->Set(iter_var->var, outer_iter->var);
+            }
+          }
+        }
+        ++num_outer_iter_vars;

Review Comment:
   With this implementation, `num_outer_iter_vars` effectively seems to be either 0 or 1 (it is only ever 
compared against 0). Is that expected?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {
+            arith::Analyzer ana;
+            // According to outer_bindings info, check outer iter_vars
+            ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+            auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+            ObjectPtr<VarNode> new_ptr = 
make_object<VarNode>(*iter_var->var.get());
+            new_ptr->name_hint = "v" + outer_bind->name_hint;
+            auto outer_iter = 
IterVar(/*dom=*/RangeFromExtent(iter_var->dom->extent),

Review Comment:
   Isn't `RangeFromExtent(iter_var->dom->extent)` just `iter_var->dom` (given that `dom->min` is 0)?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {
+            arith::Analyzer ana;
+            // According to outer_bindings info, check outer iter_vars
+            ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+            auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+            ObjectPtr<VarNode> new_ptr = 
make_object<VarNode>(*iter_var->var.get());
+            new_ptr->name_hint = "v" + outer_bind->name_hint;

Review Comment:
   We could use `Var::copy_with_suffix` here.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {
+            arith::Analyzer ana;
+            // According to outer_bindings info, check outer iter_vars
+            ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+            auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+            ObjectPtr<VarNode> new_ptr = 
make_object<VarNode>(*iter_var->var.get());
+            new_ptr->name_hint = "v" + outer_bind->name_hint;
+            auto outer_iter = 
IterVar(/*dom=*/RangeFromExtent(iter_var->dom->extent),
+                                      /*var=*/Var(new_ptr),
+                                      /*iter_type=*/iter_var->iter_type);
+            if (num_outer_iter_vars == 0) {
+              outer_iter_vars_->push_back(outer_iter);
+              block_var_subst_->Set(iter_var->var, outer_iter->var);
+            }
+          }
+        }
+        ++num_outer_iter_vars;
+        return;
+      }
+    }
+    StmtVisitor::VisitStmt_(block);
+  }
+
+  ScheduleState self_;
+  StmtSRef lca_;
+  Array<IterVar>* outer_iter_vars_;
+  Array<PrimExpr>* outer_bindings_;
+  Array<IterVar>* inner_iter_vars_;
+  Array<PrimExpr>* inner_bindings_;
+  Map<Var, PrimExpr>* block_var_subst_;
+  Array<PrimExpr> out_extent{nullptr};

Review Comment:
   The `nullptr` initializer looks strange here; is it intended?



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {

Review Comment:
   Nit: prefer `!block->iter_vars.empty()` over `block->iter_vars.size() > 0`.



##########
src/tir/schedule/primitive/blockize_tensorize.cc:
##########
@@ -565,58 +569,230 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
-BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
-                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
-                            bool preserve_unit_iters) {
+class CollectSubstInfo : public StmtVisitor {
+ public:
+  static void Collect(const ScheduleState& self, const StmtSRef& lca, const 
StmtSRef& block_sref,
+                      Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                      Map<Var, PrimExpr>* block_var_subst) {
+    CollectSubstInfo collector(self, lca, outer_iter_vars, outer_bindings, 
block_var_subst);
+    StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);
+    const BlockNode* root_block = scope_root->StmtAs<BlockNode>();
+    Block block = GetRef<Block>(root_block);
+    return collector(block);
+  }
+
+ private:
+  explicit CollectSubstInfo(const ScheduleState& self, const StmtSRef& lca,
+                            Array<IterVar>* outer_iter_vars, Array<PrimExpr>* 
outer_bindings,
+                            Map<Var, PrimExpr>* block_var_subst)
+      : self_(self),
+        lca_(lca),
+        outer_iter_vars_(outer_iter_vars),
+        outer_bindings_(outer_bindings),
+        block_var_subst_(block_var_subst) {}
+
+  void VisitStmt_(const ForNode* loop) final {
+    if (!in_lca) {
+      if (loop == lca_->stmt) {
+        in_lca = true;
+      }
+      outer_bindings_->push_back(loop->loop_var);
+      out_extent.push_back(loop->extent);
+      // traverse lca
+      ++num_travered;
+      StmtVisitor::VisitStmt(loop->body);
+      --num_travered;
+      if (!in_lca) {
+        outer_bindings_->pop_back();
+        out_extent.pop_back();
+      }
+      if (num_travered == 0) {
+        in_lca = false;
+      }
+    } else {
+      StmtVisitor::VisitStmt_(loop);
+    }
+  }
+
+  void VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt && block->name_hint == String("root")) {
+      // nothings need to substitute, so all output is nullptr.
+      return;
+    }
+    if (in_lca) {
+      if (block->iter_vars.size() > 0) {
+        // collect info
+        for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+          const IterVar& iter_var = block->iter_vars[i];
+          if (static_cast<unsigned int>(i) < out_extent.size()) {
+            arith::Analyzer ana;
+            // According to outer_bindings info, check outer iter_vars
+            ICHECK(ana.CanProveEqual(out_extent[i], iter_var->dom->extent));
+            auto outer_bind = Downcast<Var>((*outer_bindings_)[i]);
+            ObjectPtr<VarNode> new_ptr = 
make_object<VarNode>(*iter_var->var.get());
+            new_ptr->name_hint = "v" + outer_bind->name_hint;
+            auto outer_iter = 
IterVar(/*dom=*/RangeFromExtent(iter_var->dom->extent),
+                                      /*var=*/Var(new_ptr),
+                                      /*iter_type=*/iter_var->iter_type);
+            if (num_outer_iter_vars == 0) {
+              outer_iter_vars_->push_back(outer_iter);
+              block_var_subst_->Set(iter_var->var, outer_iter->var);
+            }
+          }
+        }
+        ++num_outer_iter_vars;
+        return;
+      }
+    }
+    StmtVisitor::VisitStmt_(block);
+  }
+
+  ScheduleState self_;
+  StmtSRef lca_;
+  Array<IterVar>* outer_iter_vars_;
+  Array<PrimExpr>* outer_bindings_;
+  Array<IterVar>* inner_iter_vars_;
+  Array<PrimExpr>* inner_bindings_;
+  Map<Var, PrimExpr>* block_var_subst_;
+  Array<PrimExpr> out_extent{nullptr};
+  bool in_lca = false;
+  int num_outer_iter_vars = 0;
+  int num_travered = 0;
+};
+
+class BlockizeBlocks : public StmtMutator {
+  Array<StmtSRef> blocks_;
+  StmtSRef lca_;
+  Map<Block, Block>* block_sref_reuse_;
+  BlockRealize* blockized_;
   Array<Stmt> seq_body;
-  PrimExpr outer_predicate{nullptr};
   Array<IterVar> outer_iter_vars{nullptr};
   Array<PrimExpr> outer_bindings{nullptr};
+  Array<IterVar> inner_iter_vars{nullptr};
+  Map<Var, PrimExpr> block_var_subst;
   Array<BufferRegion> read_regions;
   Array<BufferRegion> write_regions;
   std::string outer_block_name = "outer_";
   Map<Var, Var> loop_var_subst;
   arith::Analyzer analyzer;
-  for (const auto& block_sref : block_srefs) {
-    auto block_realize = GetBlockRealize(self, block_sref);
-    auto block = block_realize->block;
-    // Step 1: Derive subspace division
-    std::vector<const ForNode*> loops;
-    Array<Array<arith::IterMark>> division = SubspaceDivide(block_realize, 
block_sref, lca, &loops,
-                                                            &analyzer, 
preserve_unit_iters, true);
-    if (division.empty()) {
-      throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), 
block);
-    }
-    outer_predicate = division.back()[0]->extent;
-    PrimExpr inner_predicate = division.back()[1]->extent;
-    // Step 2. Derive block bindings for both outer and inner block.
-    Array<IterVar> inner_iter_vars;
-    Array<PrimExpr> inner_bindings;
-    Map<Var, PrimExpr> block_var_subst =                       //
-        DeriveBlockBinding(block->iter_vars, division,         //
-                           &outer_iter_vars, &outer_bindings,  //
-                           &inner_iter_vars, &inner_bindings,  //
-                           preserve_unit_iters, outer_iter_vars.defined());
-    // Step 3: Do var substitution to adjust to the new block bindings
+  Block tmp_in_block;
+  Map<Var, arith::IntSet> inner_iter_dom;
+  bool _first_in = false;

Review Comment:
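   It would be great to make the naming consistent for `_first_in`, `target_in` and
`IsBlockNode`. They currently use three different styles: a leading underscore, plain snake_case, and CamelCase. For member variables, this class already uses a trailing underscore (e.g. `blocks_`, `lca_`).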



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to