junrushao1994 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r716920036



##########
File path: src/tir/schedule/state.cc
##########
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in 
SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in 
SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { 
srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */

Review comment:
       Correct me if I'm wrong, but 
`ScheduleStateNode::UpdateSubtreeBlockInfo(stmt)` calls 
`BlockInfoCollector::Collect`, where `stmt` can be the root block of a 
subtree, but its binding might not be affine. However, in `MakeBlockInfo`, we 
would assume `is_root_block == True` and make this binding affine - is that 
correct?

##########
File path: include/tvm/tir/schedule/state.h
##########
@@ -142,6 +142,8 @@ class ScheduleStateNode : public Object {
   /******** Property of blocks ********/
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*! \brief Recalculate the BlockInfo recursively under stmt */
+  TVM_DLL void UpdateSubtreeBlockInfo(const Stmt& stmt);

Review comment:
       Do we have any additional constraints on this API? For example:
   - `stmt` must be a `Block`?
   - All loops/blocks in the subtree must not have any sref

##########
File path: src/tir/schedule/state.cc
##########
@@ -421,6 +389,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {

Review comment:
       I just checked. Actually, `StateCreator`, including its contents such as 
`PushSRef` and `VisitStmt`, is not used anymore, so let's remove them all and 
inline the `Create` method into `ScheduleState::ScheduleState`.

##########
File path: src/tir/schedule/state.cc
##########
@@ -421,6 +389,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {

Review comment:
       Update: I'm not sure I'm 100% correct after a second check...Let me know 
:-)

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }

Review comment:
       nitpick
   
   ```suggestion
     if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
       throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
     }
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;

Review comment:
       nit
   
   ```suggestion
     std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
     block_var_map.reserve(block->iter_vars.size());
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),

Review comment:
       Do we want to rename these block vars, or is simply using the same name 
good enough?
    
   ```suggestion
                            /*var=*/iter_var->var.copy_with_suffix(""),
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction 
block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; 
})) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block 
realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, 
discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {

Review comment:
       nit: we don't need a const reference to an integer here, because a 
direct copy is just as fast
   
   ```suggestion
     for (int i : chosen_loops) {
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction 
block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; 
})) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block 
realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, 
discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {

Review comment:
       nit: we don't need a const reference to an integer here, because a 
direct copy is at least as fast and may be faster
   
   ```suggestion
     for (int i : chosen_loops) {
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction 
block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; 
})) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block 
realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, 
discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,

Review comment:
       Do we need to handle the case where `min` and `extent` contain 
other loop variables?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {  // Error: a reduce block var binding uses a loop located above the target loop
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {  // throws LoopHeightError on the first violation found
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {  // assumes `loops` is ordered outermost first -- TODO confirm against GetLoops
+        // Stop at the target loop: only the loops visited before it are checked
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduce block var must not use the loop var of such a loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));  // reports the target loop, not `higher_loop`
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {  // short, fixed message that names no IR location
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {  // {0}/{1} presumably index LocationsOfInterest() (loop, block) -- confirm in ScheduleError
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;  // IRModule supplied by the code that throws this error
+  For loop_;  // the target loop (built from loop_sref in the checker above)
+  Block block_;  // the block whose reduce bindings were checked
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {  // rebuilds `pred`, keeping only the `lhs < rhs` terms whose lhs uses no discarded loop var
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;  // constant-true predicate: nothing to filter
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };  // true iff `var` is one of the discarded loop vars
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {  // each iteration peels the right-most `lhs < rhs` term off the `&&` chain
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());  // note: only lhs is inspected, rhs is not
+      pred = rest.Eval();  // continue with what is left of the conjunction
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;  // a bare comparison is the last term
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";  // any other predicate shape is rejected
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction 
block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; 
})) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block 
realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, 
discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,

Review comment:
       Do we need to handle the case where `min` and `extent` contain some 
other loop variables? Are they handled in line 287?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);  // make_pair binds this by reference, so it is read after replacer(...) has run
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {  // places the init body immediately before the target loop
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);  // mutate the loop body first so the reduction block inside is rewritten
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {  // rewrites the old reduction block into the "_update" block
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;  // the init statement now lives in the separate decomposed body
+      new_reduction_block_ = Block(p_new_block);  // remembered so Replace() can return it
+      return new_reduction_block_;  // returned as-is; the block's own body is not visited
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {  // re-visits every child and flattens the resulting sequence
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);  // presumably merges the SeqStmt produced for the target loop into this one
+  }
+
+ private:
+  For target_loop_;  // loop in front of which the init body is inserted
+  Stmt decomposed_body_;  // the init body to insert
+  Block old_reduction_block_;  // reduction block to be rewritten
+  Block new_reduction_block_;  // set by VisitStmt_(BlockNode) when the old block is found
+};
+
+class LoopPositionError : public ScheduleError {  // Error: the loop given to decompose_reduction is not an ancestor of the block
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {  // short, fixed message that names no IR location
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {  // {0}/{1} presumably index LocationsOfInterest() (loop, block) -- confirm in ScheduleError
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;  // IRModule supplied by the code that throws this error
+  For loop_;  // the loop that was passed to decompose_reduction
+  Block block_;  // the block that was passed to decompose_reduction
+};
+
+class LoopHeightError : public ScheduleError {  // Error: a reduce block var binding uses a loop located above the target loop
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {  // throws LoopHeightError on the first violation found
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {  // assumes `loops` is ordered outermost first -- TODO confirm against GetLoops
+        // Stop at the target loop: only the loops visited before it are checked
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduce block var must not use the loop var of such a loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));  // reports the target loop, not `higher_loop`
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {  // short, fixed message that names no IR location
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {  // {0}/{1} presumably index LocationsOfInterest() (loop, block) -- confirm in ScheduleError
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;  // IRModule supplied by the code that throws this error
+  For loop_;  // the target loop (built from loop_sref in the checker above)
+  Block block_;  // the block whose reduce bindings were checked
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {  // rebuilds `pred`, keeping only the `lhs < rhs` terms whose lhs uses no discarded loop var
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;  // constant-true predicate: nothing to filter
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };  // true iff `var` is one of the discarded loop vars
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {  // each iteration peels the right-most `lhs < rhs` term off the `&&` chain
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());  // note: only lhs is inspected, rhs is not
+      pred = rest.Eval();  // continue with what is left of the conjunction
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;  // a bare comparison is the last term
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";  // any other predicate shape is rejected
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction 
block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; 
})) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block 
realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, 
discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, 
scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;

Review comment:
       ```suggestion
     Block new_scope_root{nullptr};
     Block new_reduction_block{nullptr};
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);  // make_pair binds this by reference, so it is read after replacer(...) has run
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {  // places the init body immediately before the target loop
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);  // mutate the loop body first so the reduction block inside is rewritten
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {  // rewrites the old reduction block into the "_update" block
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;  // the init statement now lives in the separate decomposed body
+      new_reduction_block_ = Block(p_new_block);  // remembered so Replace() can return it
+      return new_reduction_block_;  // returned as-is; the block's own body is not visited
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {  // re-visits every child and flattens the resulting sequence
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);  // presumably merges the SeqStmt produced for the target loop into this one
+  }
+
+ private:
+  For target_loop_;  // loop in front of which the init body is inserted
+  Stmt decomposed_body_;  // the init body to insert
+  Block old_reduction_block_;  // reduction block to be rewritten
+  Block new_reduction_block_;  // set by VisitStmt_(BlockNode) when the old block is found
+};
+
+class LoopPositionError : public ScheduleError {  // Error: the loop given to decompose_reduction is not an ancestor of the block
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {  // short, fixed message that names no IR location
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {  // {0}/{1} presumably index LocationsOfInterest() (loop, block) -- confirm in ScheduleError
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;  // IRModule supplied by the code that throws this error
+  For loop_;  // the loop that was passed to decompose_reduction
+  Block block_;  // the block that was passed to decompose_reduction
+};
+
+class LoopHeightError : public ScheduleError {  // Error: a reduce block var binding uses a loop located above the target loop
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {  // throws LoopHeightError on the first violation found
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {  // assumes `loops` is ordered outermost first -- TODO confirm against GetLoops
+        // Stop at the target loop: only the loops visited before it are checked
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduce block var must not use the loop var of such a loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));  // reports the target loop, not `higher_loop`
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {  // short, fixed message that names no IR location
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {  // {0}/{1} presumably index LocationsOfInterest() (loop, block) -- confirm in ScheduleError
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;  // IRModule supplied by the code that throws this error
+  For loop_;  // the target loop (built from loop_sref in the checker above)
+  Block block_;  // the block whose reduce bindings were checked
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {  // rebuilds `pred`, keeping only the `lhs < rhs` terms whose lhs uses no discarded loop var
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;  // constant-true predicate: nothing to filter
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };  // true iff `var` is one of the discarded loop vars
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {  // each iteration peels the right-most `lhs < rhs` term off the `&&` chain
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());  // note: only lhs is inspected, rhs is not
+      pred = rest.Eval();  // continue with what is left of the conjunction
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;  // a bare comparison is the last term
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";  // any other predicate shape is rejected
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction 
block.
+  //         If the loop is used in the init block binding, then it is chosen.
+  //         Otherwise, it is discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; 
})) {
+        continue;
+      }
+      // The loop is related to init block bindings;
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block 
realize
+  //         We discard predicate that is related to discarded loops
+  init_realize->predicate = RemakePredicate(init_realize->predicate, 
discarded_loops);
+  // Step 5. Create new loops above init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (const int& i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, 
scope_root_sref);
+  Block new_scope_root;
+  Block new_reduction_block;

Review comment:
       nit
   
   ```suggestion
     Block new_scope_root{nullptr};
     Block new_reduction_block{nullptr};
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> 
block_var_map;
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix("_init"),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }

Review comment:
       Do we make the assumption here about the write region?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   *  Check
+   *    - block is a reduction block
+   *    - loop is not lower than all the loops related to reduce block var
+   *  Mutate
+   *    - generate loops related to data par block vars
+   *    - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  const auto& it = std::find(loops.begin(), loops.end(), loop_sref);
+  if (it == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), 
GetRef<Block>(block));
+  }
+  // Cond 1. Check block is reduction
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          
/*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block var of 
type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, 
loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->predicate = realize->predicate;

Review comment:
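       nit: don't set the predicate here, because it is reassigned later on line 271; the `RemakePredicate` call there would then take `realize->predicate` as its input instead of `init_realize->predicate`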

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };

Review comment:
       nit:
   
   ```suggestion
     auto f = [&](const VarNode* var) { return discarded_loops.count(var); };
   ```

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;

Review comment:
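       nit: swap the two statements below, so that `new_pred` is only constructed when it is actually needed. The early return then has to yield `Bool(true)` directly, because `new_pred` is not declared yet at that point.
   
   ```suggestion
     if (is_one(pred)) return Bool(true);
     PrimExpr new_pred = Bool(true);
   ```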

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be an 
ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is 
required to be be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // loop_var of a higher loop shouldn't contain loop var
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return 
var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) 
{}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expect the loop to be higher 
than all the loops "
+           "related to reduce block var";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expect the loop {0} to be higher 
than all the loops "
+          "related to reduce block var of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; 
}
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const 
VarNode*>& discarded_loops) {
+  PrimExpr new_pred = Bool(true);
+  if (is_one(pred)) return new_pred;
+  auto f = [&](const VarNode* var) { return discarded_loops.find(var) != 
discarded_loops.end(); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < 
rhs.Eval());

Review comment:
       Do we have any checks on the rhs?

##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,283 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains decomposed init 
body
+ * and replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decompose_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block 
old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), 
std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt 
decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }

Review comment:
       Why do we need special handling of SeqStmt btw?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to