This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f863dda62 [TIR][Schedule] Improve blockize to support blockizing 
multiple blocks (#14766)
2f863dda62 is described below

commit 2f863dda62c80ce16bbbb608a0bc7afe3589a864
Author: multiverstack <[email protected]>
AuthorDate: Tue May 16 17:42:14 2023 +0800

    [TIR][Schedule] Improve blockize to support blockizing multiple blocks 
(#14766)
    
    * Improve blockize to support blockizing multiple blocks
    
    * Adjust unit test to match simplified blockize result.
    
    * Update doc
    
    * Preserve unit iters in expr and revert test case change
    
    * Apply review suggestion
    
    ---------
    
    Co-authored-by: Min Chen <[email protected]>
---
 include/tvm/tir/schedule/schedule.h                |   7 +
 python/tvm/tir/schedule/schedule.py                |  12 +-
 src/tir/schedule/concrete_schedule.cc              |   9 +
 src/tir/schedule/concrete_schedule.h               |   1 +
 src/tir/schedule/primitive.h                       |  10 +
 src/tir/schedule/primitive/blockize_tensorize.cc   | 287 +++++++++++++++++++--
 src/tir/schedule/schedule.cc                       |   9 +-
 src/tir/schedule/traced_schedule.cc                |  11 +
 src/tir/schedule/traced_schedule.h                 |   1 +
 .../metaschedule_e2e/test_resnet50_int8.py         |   2 +-
 .../unittest/test_meta_schedule_trace_apply.py     |  18 +-
 .../python/unittest/test_tir_schedule_blockize.py  |  49 ++++
 12 files changed, 373 insertions(+), 43 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h 
b/include/tvm/tir/schedule/schedule.h
index 69f0520117..187d0a31d0 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -641,6 +641,13 @@ class ScheduleNode : public runtime::Object {
    * \return the new block
    */
   virtual BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters = 
true) = 0;
+  /*!
+   * \brief Convert specified blocks into a nested block.
+   * \param blocks the specified block to construct the new block
+   * \param preserve_unit_iters Whether or not to preserve unit iterators in 
block bindings
+   * \return the new block
+   */
+  virtual BlockRV Blockize(const Array<BlockRV>& blocks, bool 
preserve_unit_iters = true) = 0;
   /*!
    * \brief Tensorize the computation enclosed by loop with the tensor intrin.
    * \param loop_rv The loop to be tensorized
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 7c7af998be..8ebc02ccbb 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2691,13 +2691,15 @@ class Schedule(Object):
     ########## Schedule: Blockize & Tensorize ##########
 
     @type_checked
-    def blockize(self, loop: LoopRV, preserve_unit_iters: bool = True) -> 
BlockRV:
-        """Convert the subtree rooted at a specific loop into a block.
+    def blockize(
+        self, target: Union[LoopRV, List[BlockRV]], preserve_unit_iters: bool 
= True
+    ) -> BlockRV:
+        """Convert multiple blocks or the subtree rooted at a specific loop 
into a block.
 
         Parameters
         ----------
-        loop : LoopRV
-            The root of the subtree.
+        target : LoopRV or List[BlockRV]
+            The root of the subtree or the specified blocks.
         preserve_unit_iters : bool
             Whether or not to preserve unit iterators in block bindings
 
@@ -2764,7 +2766,7 @@ class Schedule(Object):
         block are divisible by the subspace represented by the loops starting 
at the given loop.
         """
 
-        return _ffi_api.ScheduleBlockize(self, loop, preserve_unit_iters)  # 
type: ignore # pylint: disable=no-member
+        return _ffi_api.ScheduleBlockize(self, target, preserve_unit_iters)  # 
type: ignore # pylint: disable=no-member
 
     @type_checked
     def tensorize(
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index 7192a48099..d485127242 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -791,6 +791,15 @@ BlockRV ConcreteScheduleNode::Blockize(const LoopRV& 
loop_rv, bool preserve_unit
   return CreateRV<BlockRV>(result);
 }
 
+BlockRV ConcreteScheduleNode::Blockize(const Array<BlockRV>& blocks, bool 
preserve_unit_iters) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::Blockize(state_, this->GetSRefs(blocks), preserve_unit_iters);
+  this->state_->DebugVerify();
+  TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
+  return CreateRV<BlockRV>(result);
+}
+
 void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& 
intrin,
                                      bool preserve_unit_iters) {
   TVM_TIR_SCHEDULE_BEGIN();
diff --git a/src/tir/schedule/concrete_schedule.h 
b/src/tir/schedule/concrete_schedule.h
index 16065df3cd..73a0b314dd 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -153,6 +153,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& 
dtype) override;
   /******** Schedule: Blockize & Tensorize ********/
   BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
+  BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters) 
override;
   void Tensorize(const BlockRV& block_rv, const String& intrin, bool 
preserve_unit_iters) override;
   void Tensorize(const LoopRV& loop_rv, const String& intrin, bool 
preserve_unit_iters) override;
   /******** Schedule: Annotation ********/
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 78d1cab05c..7355d38db1 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -542,6 +542,16 @@ TVM_DLL void SetAxisSeparator(ScheduleState self, const 
StmtSRef& block_sref, in
  */
 TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool 
preserve_unit_iters);
 
+/*!
+ * \brief Convert specific blocks into a nested block.
+ * \param self The state of the schedule
+ * \param blocks The target blocks to construct the new block
+ * \param preserve_unit_iters Whether or not to preserve unit iterators in 
block bindings
+ * \return The new block
+ */
+TVM_DLL StmtSRef Blockize(ScheduleState self, const Array<StmtSRef>& blocks,
+                          bool preserve_unit_iters);
+
 /*!
  * \brief Tensorize the computation enclosed by loop with the tensor intrinsic.
  * \param self The state of the schedule
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc 
b/src/tir/schedule/primitive/blockize_tensorize.cc
index 25694ed6fc..994a3a95fb 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -139,20 +139,26 @@ Array<Array<arith::IterMark>> 
TrivialSubspaceDivision(const Array<IterVar>& iter
 
 /*!
  * \brief Subspace division. The space is divided into two subspaces:
+ * If loop_sref_as_outer is false:
  *  1. The subspace represented by the outer loops above `loop_sref` 
(exclusive).
  *  2. The subspace represented by the inner loops below `loop_sref` 
(inclusive).
+ * else:
+ *  1. The subspace represented by the outer loops above `loop_sref` 
(inclusive).
+ *  2. The subspace represented by the inner loops below `loop_sref` 
(exclusive).
  * \param realize The inner block
  * \param block_sref The sref to the inner block
  * \param loop_sref The loop that is the root of the second subspace.
  * \param loops The loops that represents the second part of the subspace.
  * \param analyzer The arithmetic analyzer to use.
  * \param preserve_unit_iters Whether or not to preserve unit iterators in 
block bindings
+ * \param loop_sref_as_outer Whether loop_sref is divided into outer or inner
  */
 Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
                                              const StmtSRef& block_sref,  //
                                              const StmtSRef& loop_sref,   //
                                              std::vector<const ForNode*>* 
loops,
-                                             arith::Analyzer* analyzer, bool 
preserve_unit_iters) {
+                                             arith::Analyzer* analyzer, bool 
preserve_unit_iters,
+                                             bool loop_sref_as_outer = false) {
   Array<Var> inner_vars;
   Array<Var> outer_vars;
   Map<Var, Range> loop_var_domain;
@@ -168,7 +174,7 @@ Array<Array<arith::IterMark>> SubspaceDivide(const 
BlockRealize& realize,
       outer_vars.push_back(loop->loop_var);
     }
     loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, 
loop->extent));
-    if (sref == loop_sref.get()) {
+    if ((loop_sref_as_outer && sref->parent == loop_sref.get()) || sref == 
loop_sref.get()) {
       inner = false;
     }
   }
@@ -201,12 +207,14 @@ Map<Var, PrimExpr> DeriveBlockBinding(const 
Array<IterVar>& iter_vars,
                                       Array<IterVar>* outer_iter_vars,         
       //
                                       Array<PrimExpr>* outer_bindings,         
       //
                                       Array<IterVar>* inner_iter_vars,         
       //
-                                      Array<PrimExpr>* inner_bindings, bool 
preserve_unit_iters) {
+                                      Array<PrimExpr>* inner_bindings,         
       //
+                                      bool preserve_unit_iters, bool 
reuse_outer = false) {
   using arith::IterMapExpr;
   using arith::IterMapExprNode;
   using arith::NormalizeIterMapToExpr;
   Map<Var, PrimExpr> block_var_subst;
   ICHECK_EQ(iter_vars.size() + 1, division.size());
+  arith::Analyzer ana;
   for (int i = 0, n = iter_vars.size(); i < n; ++i) {
     const IterVar& iter_var = iter_vars[i];
     arith::IterMark outer_mark = division[i][0];
@@ -219,30 +227,43 @@ Map<Var, PrimExpr> DeriveBlockBinding(const 
Array<IterVar>& iter_vars,
     // The inner block will have binding: iter_inner -> inner_binding
     // The iter in the original block will be substituted with base + 
iter_inner where
     // base == iter_outer * iter_inner_extent
-    if (is_one(inner_mark->extent)) {  // IsOuter
-      // extract this iter var to outer block directly
+    // create iter var for the outer block
+    IterVar outer_iter;
+    if (reuse_outer) {
+      outer_iter = outer_iter_vars->operator[](i);
+      ICHECK(ana.CanProveEqual(outer_iter->dom->extent, outer_mark->extent));
+      ICHECK(
+          ana.CanProveEqual(outer_bindings->operator[](i), 
NormalizeIterMapToExpr(outer_binding)));
+    } else {
+      outer_iter = IterVar(/*dom=*/RangeFromExtent(outer_mark->extent),
+                           /*var=*/iter_var->var.copy_with_suffix("_o"),
+                           /*iter_type=*/iter_var->iter_type);
       outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
-      outer_iter_vars->push_back(iter_var);
-      continue;
+      outer_iter_vars->push_back(outer_iter);
     }
-    // create iter var for the outer block
-    IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
-                       /*var=*/iter_var->var.copy_with_suffix("_o"),
-                       /*iter_type=*/iter_var->iter_type);
-    outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
-    outer_iter_vars->push_back(outer_iter);
-    // create iter var for the inner block
-    IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
-                       /*var=*/iter_var->var.copy_with_suffix("_i"),
-                       /*iter_type=*/iter_var->iter_type);
-    inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
-    inner_iter_vars->push_back(inner_iter);
-    // substitution
     PrimExpr sub{nullptr};
-    if (is_one(outer_mark->extent)) {
-      sub = inner_iter->var;
+    if (is_one(inner_mark->extent)) {
+      // Skip inner var when extent is 1
+      // substitution
+      if (is_one(outer_mark->extent) && !preserve_unit_iters) {
+        // Simplify outer if not preserve_unit_iters
+        sub = make_zero(outer_mark->extent.dtype());
+      } else {
+        sub = outer_iter;
+      }
     } else {
-      sub = outer_iter * inner_mark->extent + inner_iter->var;
+      // create iter var for the inner block
+      IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
+                         /*var=*/iter_var->var.copy_with_suffix("_i"),
+                         /*iter_type=*/iter_var->iter_type);
+      inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
+      inner_iter_vars->push_back(inner_iter);
+      // substitution
+      if (is_one(outer_mark->extent)) {
+        sub = inner_iter->var;
+      } else {
+        sub = outer_iter * inner_mark->extent + inner_iter->var;
+      }
     }
     block_var_subst.Set(iter_var->var, sub);
   }
@@ -414,6 +435,37 @@ Array<BufferRegion> EvalSetRegions(const 
Array<BufferRegion>& regions,
   return results;
 }
 
+/*!
+ * \brief Get the union of the given regions
+ * \param regions The input regions for the union.
+ * \return The union regions
+ */
+Array<BufferRegion> UnionRegions(const Array<BufferRegion>& regions) {
+  typedef std::vector<Array<arith::IntSet>> ranges_t;
+  std::unordered_map<Buffer, ranges_t, ObjectPtrHash, ObjectPtrEqual> 
intset_map;
+  for (const BufferRegion& buffer_region : regions) {
+    const Buffer& buffer = buffer_region->buffer;
+    if (intset_map.find(buffer) == intset_map.end()) {
+      intset_map[buffer] = {buffer->shape.size(), Array<arith::IntSet>()};
+    }
+    std::vector<Array<arith::IntSet>> dim_range(buffer->shape.size(), 
Array<arith::IntSet>());
+    for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
+      
intset_map[buffer][dim].push_back(arith::IntSet::FromRange(buffer_region->region[dim]));
+    }
+  }
+  Array<BufferRegion> results;
+  for (const auto& it : intset_map) {
+    const Buffer& buffer = it.first;
+    Array<Range> regions;
+    for (size_t dim = 0; dim < buffer->shape.size(); ++dim) {
+      const arith::IntSet intset = arith::Union(it.second[dim]);
+      regions.push_back({intset.min(), intset.max() + 1});
+    }
+    results.push_back(BufferRegion(buffer, regions));
+  }
+  return results;
+}
+
 /*!
  * \brief Create the loop nest on top of the given stmt.
  * \param stmt The stmt to be wrapped.
@@ -513,6 +565,181 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& 
loop_sref, bool preserve_u
   return result;
 }
 
+BlockRealize BlockizeBlocks(const ScheduleState& self, const Array<StmtSRef>& 
block_srefs,
+                            const StmtSRef& lca, Map<Block, Block>* 
block_sref_reuse,
+                            bool preserve_unit_iters) {
+  Array<Stmt> seq_body;
+  PrimExpr outer_predicate{nullptr};
+  Array<IterVar> outer_iter_vars{nullptr};
+  Array<PrimExpr> outer_bindings{nullptr};
+  Array<BufferRegion> read_regions;
+  Array<BufferRegion> write_regions;
+  std::string outer_block_name = "outer_";
+  Map<Var, Var> loop_var_subst;
+  arith::Analyzer analyzer;
+  for (const auto& block_sref : block_srefs) {
+    auto block_realize = GetBlockRealize(self, block_sref);
+    auto block = block_realize->block;
+    // Step 1: Derive subspace division
+    std::vector<const ForNode*> loops;
+    Array<Array<arith::IterMark>> division = SubspaceDivide(block_realize, 
block_sref, lca, &loops,
+                                                            &analyzer, 
preserve_unit_iters, true);
+    if (division.empty()) {
+      throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), 
block);
+    }
+    outer_predicate = division.back()[0]->extent;
+    PrimExpr inner_predicate = division.back()[1]->extent;
+    // Step 2. Derive block bindings for both outer and inner block.
+    Array<IterVar> inner_iter_vars;
+    Array<PrimExpr> inner_bindings;
+    Map<Var, PrimExpr> block_var_subst =                       //
+        DeriveBlockBinding(block->iter_vars, division,         //
+                           &outer_iter_vars, &outer_bindings,  //
+                           &inner_iter_vars, &inner_bindings,  //
+                           preserve_unit_iters, outer_iter_vars.defined());
+    // Step 3: Do var substitution to adjust to the new block bindings
+    for (size_t i = 0; i < outer_iter_vars.size(); ++i) {
+      if (outer_bindings[i].as<Var>()) {
+        loop_var_subst.Set(Downcast<Var>(outer_bindings[i]), 
outer_iter_vars[i]->var);
+      }
+    }
+    Map<Var, arith::IntSet> inner_iter_dom;
+    for (const IterVar& iter : inner_iter_vars) {
+      Range dom = Substitute(iter->dom, loop_var_subst);
+      inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(dom));
+      analyzer.Bind(iter->var, dom);
+    }
+    Block block_subst =
+        Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse, 
&analyzer));
+    auto reads = EvalSetRegions(block_subst->reads, inner_iter_dom);
+    auto writes = EvalSetRegions(block_subst->writes, inner_iter_dom);
+    read_regions.insert(read_regions.end(), reads.begin(), reads.end());
+    write_regions.insert(write_regions.end(), writes.begin(), writes.end());
+    outer_block_name += block_subst->name_hint + "_";
+    // Step 4: Generate the inner block. No reduction iter vars allowed for 
the outer loops.
+    bool has_outer_reduction = false;
+    if (block_subst->init.defined()) {
+      for (const IterVar& iter_var : outer_iter_vars) {
+        if (iter_var->iter_type == kCommReduce) {
+          has_outer_reduction = true;
+          break;
+        }
+      }
+    }
+    ICHECK(has_outer_reduction == false)
+        << "No reduction iter vars allowed for the outer loops when blockize 
multiple blocks";
+    BlockRealize inner_realize = 
GenerateInner(/*is_write_reduction=*/has_outer_reduction,
+                                               /*iter_vars=*/inner_iter_vars,
+                                               /*iter_values*/ inner_bindings,
+                                               /*predicate=*/inner_predicate,
+                                               /*block=*/block_subst);
+    block_sref_reuse->Set(block, inner_realize->block);
+    Stmt stmt = inner_realize;
+    for (const ForNode* loop : loops) {
+      ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
+      new_loop->body = std::move(stmt);
+      new_loop->extent = Substitute(new_loop->extent, loop_var_subst);
+      stmt = For(new_loop);
+    }
+    seq_body.push_back(stmt);
+  }
+  // Step 5: Generate the outer block.
+  return BlockRealize(
+      /*iter_values=*/std::move(outer_bindings),
+      /*predicate=*/std::move(outer_predicate),
+      /*block=*/
+      Block(/*iter_vars=*/std::move(outer_iter_vars),
+            /*reads=*/UnionRegions(read_regions),
+            /*writes=*/UnionRegions(write_regions),
+            /*name_hint=*/outer_block_name,
+            /*body=*/SeqStmt(seq_body),
+            /*init=*/Optional<Stmt>(NullOpt)));
+}
+
+class BlockizeRewriter : public StmtMutator {
+ public:
+  static Stmt Rewrite(const StmtSRef& lca, const Array<StmtSRef>& blocks,
+                      const BlockRealize& blockized) {
+    BlockizeRewriter rewriter(lca, blocks, blockized);
+    return rewriter(GetRef<Stmt>(lca->stmt));
+  }
+
+ private:
+  explicit BlockizeRewriter(const StmtSRef& lca, const Array<StmtSRef>& blocks,
+                            const BlockRealize& blockized)
+      : lca_(lca), blocks_(blocks), blockized_(blockized) {}
+
+  Stmt RewriteSeq(const Stmt& stmt) {
+    const SeqStmtNode* seq = stmt.as<SeqStmtNode>();
+    ICHECK(seq) << "Target blocks must not be nested with each other!";
+    int idx_start = -1;
+    int found_cnt = 0;
+    int last_found_idx = -1;
+    size_t cur_idx = 0;
+    Array<Stmt> new_seq;
+    for (const Stmt& it : seq->seq) {
+      target_in_ = false;
+      Stmt stmt = StmtMutator::VisitStmt(it);
+      if (target_in_) {
+        if (idx_start == -1) {
+          idx_start = cur_idx;
+          new_seq.push_back(blockized_);
+        } else {
+          ICHECK_EQ(last_found_idx, cur_idx - 1) << "Target blocks must be 
consecutive!";
+        }
+        last_found_idx = cur_idx;
+        ++found_cnt;
+      } else {
+        new_seq.push_back(it);
+      }
+      ++cur_idx;
+    }
+    if (new_seq.size() == 1) return new_seq[0];
+    return SeqStmt(new_seq, seq->span);
+  }
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    if (loop == lca_->stmt) {
+      return For(loop->loop_var, loop->min, loop->extent, loop->kind, 
RewriteSeq(loop->body),
+                 loop->thread_binding, loop->annotations, loop->span);
+    }
+    return StmtMutator::VisitStmt_(loop);
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == lca_->stmt) {
+      return Block(block->iter_vars, block->reads, block->writes, 
block->name_hint,
+                   RewriteSeq(block->body), block->init, block->alloc_buffers, 
block->match_buffers,
+                   block->annotations, block->span);
+    }
+    for (const StmtSRef& block_sref : blocks_) {
+      if (block_sref->stmt == block) {
+        target_in_ = true;
+        break;
+      }
+    }
+    return GetRef<Stmt>(block);
+  }
+
+  StmtSRef lca_;
+  Array<StmtSRef> blocks_;
+  BlockRealize blockized_;
+  bool target_in_ = false;
+};
+
+StmtSRef Blockize(ScheduleState self, const Array<StmtSRef>& blocks, bool 
preserve_unit_iters) {
+  Map<Block, Block> block_sref_reuse;
+  auto lca = GetSRefLowestCommonAncestor(blocks);
+  BlockRealize blockized =
+      BlockizeBlocks(self, blocks, lca, &block_sref_reuse, 
preserve_unit_iters);
+  auto new_root = BlockizeRewriter::Rewrite(lca, blocks, blockized);
+  self->Replace(lca, new_root, block_sref_reuse);
+  StmtSRef result = self->stmt2ref.at(blockized->block.get());
+  StmtSRef scope_root = tir::GetScopeRoot(self, result, 
/*require_stage_pipeline=*/false);
+  self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
+  return result;
+}
+
 void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& 
intrin,
                bool preserve_unit_iters) {
   // Step 1: Blockize the subtree rooted at the given loop if needed
@@ -636,13 +863,19 @@ struct BlockizeTraits : public 
UnpackedInstTraits<BlockizeTraits> {
   static constexpr size_t kNumAttrs = 1;
   static constexpr size_t kNumDecisions = 0;
 
-  static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool 
preserve_unit_iters) {
-    return sch->Blockize(loop_rv, preserve_unit_iters.operator bool());
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, ObjectRef target, Bool 
preserve_unit_iters) {
+    if (auto loop = target.as<LoopRV>()) {
+      return sch->Blockize(loop.value(), preserve_unit_iters.operator bool());
+    } else if (auto blocks = target.as<Array<BlockRV>>()) {
+      return sch->Blockize(blocks.value(), preserve_unit_iters.operator 
bool());
+    }
+    LOG(FATAL) << "TypeError: expect Loop or list of Blocks, but gets:" << 
target->GetTypeKey();
   }
 
-  static String UnpackedAsPython(Array<String> outputs, String loop_rv, Bool 
preserve_unit_iters) {
+  static String UnpackedAsPython(Array<String> outputs, ObjectRef target,
+                                 Bool preserve_unit_iters) {
     PythonAPICall py("blockize");
-    py.Input("loop", loop_rv);
+    py.Input("target", target);
     py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
     py.SingleOutput(outputs);
     return py.Str();
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 8663ac2b97..56d0d1efa9 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -228,7 +228,14 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
     .set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
 /******** (FFI) Blockize & Tensorize ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
-    .set_body_method<Schedule>(&ScheduleNode::Blockize);
+    .set_body_typed([](Schedule self, ObjectRef target, bool 
preserve_unit_iters) {
+      if (auto loop_rv = target.as<LoopRV>()) {
+        return self->Blockize(loop_rv.value(), preserve_unit_iters);
+      } else if (auto blocks = target.as<Array<BlockRV>>()) {
+        return self->Blockize(blocks.value(), preserve_unit_iters);
+      }
+      LOG(FATAL) << "Unsupported target type: " << target->GetTypeKey();
+    });
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTensorize")
     .set_body_typed([](Schedule self, ObjectRef rv, String intrin, bool 
preserve_unit_iters) {
       if (auto block_rv = rv.as<BlockRV>()) {
diff --git a/src/tir/schedule/traced_schedule.cc 
b/src/tir/schedule/traced_schedule.cc
index 4d820078e5..ceeeacb335 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -558,6 +558,17 @@ BlockRV TracedScheduleNode::Blockize(const LoopRV& 
loop_rv, bool preserve_unit_i
   return new_block;
 }
 
+BlockRV TracedScheduleNode::Blockize(const Array<BlockRV>& blocks, bool 
preserve_unit_iters) {
+  BlockRV new_block = ConcreteScheduleNode::Blockize(blocks, 
preserve_unit_iters);
+  static const InstructionKind& kind = InstructionKind::Get("Blockize");
+  trace_->Append(/*inst=*/Instruction(
+      /*kind=*/kind,
+      /*inputs=*/{blocks},
+      /*attrs=*/{Bool(preserve_unit_iters)},
+      /*outputs=*/{new_block}));
+  return new_block;
+}
+
 void TracedScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin,
                                    bool preserve_unit_iters) {
   ConcreteScheduleNode::Tensorize(loop_rv, intrin, preserve_unit_iters);
diff --git a/src/tir/schedule/traced_schedule.h 
b/src/tir/schedule/traced_schedule.h
index 16ec86f227..2d47ee9aff 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -111,6 +111,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& 
dtype) final;
   /******** Schedule: Blockize & Tensorize ********/
   BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
+  BlockRV Blockize(const Array<BlockRV>& blocks, bool preserve_unit_iters) 
final;
   void Tensorize(const BlockRV& block_rv, const String& intrin, bool 
preserve_unit_iters) final;
   void Tensorize(const LoopRV& loop_rv, const String& intrin, bool 
preserve_unit_iters) final;
   /******** Schedule: Annotation ********/
diff --git 
a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py 
b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
index 030e47ac58..111448ea57 100644
--- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
+++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py
@@ -493,7 +493,7 @@ def _schedule_async_dma_conv2d():
         sch.parallel(new_loops[4])
         sch.unroll(new_loops[5])
         # TODO(nverke): Add compute optimizations here.
-        sch.blockize(loop=oc_i)
+        sch.blockize(target=oc_i)
 
         sch.tensorize(oc_i, VRMPY_u8i8i32_VTCM_INTRIN)
 
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py 
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 4d22c4ff88..78b2fdbf3d 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -2150,7 +2150,7 @@ def test_conv2d_int8_tensorcore():
         l28, l29 = sch.split(loop=l21, factors=[None, 16], 
preserve_unit_iters=True)
         l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b1)
         sch.reorder(l34, l36, l29, l27, l25)
-        b38 = sch.blockize(loop=l29)
+        b38 = sch.blockize(target=l29)
         sch.annotate(
             block_or_loop=b38,
             ann_key="meta_schedule.auto_tensorize",
@@ -2243,7 +2243,7 @@ def test_conv2d_int8_tensorcore():
         l95, l96 = sch.split(loop=l91, factors=[None, 16], 
preserve_unit_iters=True)
         l97, l98, l99, l100, l101, l102, l103 = sch.get_loops(block=b86)
         sch.reorder(l102, l96, l94)
-        b104 = sch.blockize(loop=l96)
+        b104 = sch.blockize(target=l96)
         sch.annotate(
             block_or_loop=b104,
             ann_key="meta_schedule.auto_tensorize",
@@ -2308,7 +2308,7 @@ def test_conv2d_int8_tensorcore():
             l157,
         ) = sch.get_loops(block=b129)
         sch.reorder(l156, l144, l142)
-        b158 = sch.blockize(loop=l144)
+        b158 = sch.blockize(target=l144)
         sch.annotate(
             block_or_loop=b158,
             ann_key="meta_schedule.auto_tensorize",
@@ -2351,7 +2351,7 @@ def test_conv2d_int8_tensorcore():
             l191,
         ) = sch.get_loops(block=b159)
         sch.reorder(l190, l176, l174)
-        b192 = sch.blockize(loop=l176)
+        b192 = sch.blockize(target=l176)
         sch.annotate(
             block_or_loop=b192,
             ann_key="meta_schedule.auto_tensorize",
@@ -2554,7 +2554,7 @@ def test_conv2d_int8_vnni():
         l34, l35 = sch.split(loop=l26, factors=[None, 16], 
preserve_unit_iters=True)
         l36, l37, l38, l39, l40, l41, l42, l43, l44, l45, l46, l47 = 
sch.get_loops(block=b1)
         sch.reorder(l42, l43, l44, l45, l46, l35, l33)
-        b48 = sch.blockize(loop=l35)
+        b48 = sch.blockize(target=l35)
         sch.annotate(block_or_loop=b48, 
ann_key="meta_schedule.auto_tensorize", ann_val=VNNI_INTRIN)
         l49, l50, l51, l52, l53, l54, l55, l56, l57, l58 = 
sch.get_loops(block=b48)
         v59, v60, v61, v62 = sch.sample_perfect_tile(
@@ -3119,7 +3119,7 @@ def test_inline_order():
         l22, l23 = sch.split(loop=l15, factors=[None, 16], 
preserve_unit_iters=True)
         l24, l25, l26, l27, l28, l29, l30, l31 = sch.get_loops(block=b1)
         sch.reorder(l28, l30, l23, l21, l19)
-        b32 = sch.blockize(loop=l23)
+        b32 = sch.blockize(target=l23)
         sch.annotate(
             block_or_loop=b32,
             ann_key="meta_schedule.auto_tensorize",
@@ -3212,7 +3212,7 @@ def test_inline_order():
         l89, l90 = sch.split(loop=l85, factors=[None, 16], 
preserve_unit_iters=True)
         l91, l92, l93, l94, l95, l96, l97 = sch.get_loops(block=b80)
         sch.reorder(l96, l90, l88)
-        b98 = sch.blockize(loop=l90)
+        b98 = sch.blockize(target=l90)
         sch.annotate(
             block_or_loop=b98,
             ann_key="meta_schedule.auto_tensorize",
@@ -3277,7 +3277,7 @@ def test_inline_order():
             l151,
         ) = sch.get_loops(block=b123)
         sch.reorder(l150, l138, l136)
-        b152 = sch.blockize(loop=l138)
+        b152 = sch.blockize(target=l138)
         sch.annotate(
             block_or_loop=b152,
             ann_key="meta_schedule.auto_tensorize",
@@ -3320,7 +3320,7 @@ def test_inline_order():
             l185,
         ) = sch.get_loops(block=b153)
         sch.reorder(l184, l170, l168)
-        b186 = sch.blockize(loop=l170)
+        b186 = sch.blockize(target=l170)
         sch.annotate(
             block_or_loop=b186,
             ann_key="meta_schedule.auto_tensorize",
diff --git a/tests/python/unittest/test_tir_schedule_blockize.py 
b/tests/python/unittest/test_tir_schedule_blockize.py
index cd4ce663e5..d151e4b438 100644
--- a/tests/python/unittest/test_tir_schedule_blockize.py
+++ b/tests/python/unittest/test_tir_schedule_blockize.py
@@ -305,5 +305,54 @@ def test_blockize_outer_int64_shape(preserve_unit_iters):
     verify_trace_roundtrip(sch=s, mod=single_elementwise_int64)
 
 
+def test_blockize_blocks():
+    @T.prim_func
+    def blocks_func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 
128), "float32")) -> None:
+        for m in T.serial(6):
+            for i, j in T.grid(3, 1):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads(A[vi, vj])
+                    T.writes(B[vi, vj])
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+            for i, j in T.grid(128, 64):
+                with T.block("C"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads(A[vi, vj + 64])
+                    T.writes(B[vi, vj + 64])
+                    B[vi, vj + 64] = A[vi, vj + 64] * 3.0
+
+    @T.prim_func
+    def after_blocks_blockize(
+        A: T.Buffer((128, 128), "float32"), B: T.Buffer((128, 128), "float32")
+    ) -> None:
+        for m in range(6):
+            with T.block("outer_B_C_"):
+                vi_o = T.axis.spatial(1, 0)
+                vj_o = T.axis.spatial(1, 0)
+                T.reads(A[0:128, 0:128])
+                T.writes(B[0:128, 0:128])
+                for i, j in T.grid(3, 1):
+                    with T.block("B"):
+                        vi_i = T.axis.spatial(3, i)
+                        T.reads(A[vi_i, 0])
+                        T.writes(B[vi_i, 0])
+                        B[vi_i, 0] = A[vi_i, 0] * T.float32(2)
+                for i, j in T.grid(128, 64):
+                    with T.block("C"):
+                        vi_i, vj_i = T.axis.remap("SS", [i, j])
+                        T.reads(A[vi_i, vj_i + 64])
+                        T.writes(B[vi_i, vj_i + 64])
+                        B[vi_i, vj_i + 64] = A[vi_i, vj_i + 64] * T.float32(3)
+
+    s = tir.Schedule(blocks_func, debug_mask="all")
+    blocks = [s.get_block("B"), s.get_block("C")]
+    s.blockize(blocks, preserve_unit_iters=False)
+    expected = after_blocks_blockize
+    tvm.ir.assert_structural_equal(s.mod["main"], expected)
+    verify_trace_roundtrip(sch=s, mod=blocks_func)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to