MasterJH5574 commented on a change in pull request #9041:
URL: https://github.com/apache/tvm/pull/9041#discussion_r711994987



##########
File path: src/tir/schedule/primitive/reduction.cc
##########
@@ -21,6 +21,289 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief Check whether `element` appears in `list`, comparing entries with `same_as`
+ * (object identity) rather than structural equality.
+ * \param list The srefs to search
+ * \param element The sref to look for
+ * \return true if some entry of `list` is the same object as `element`, false otherwise
+ */
+bool ListContainsElement(const Array<StmtSRef>& list, const StmtSRef& element) {
+  const size_t num_items = list.size();
+  for (size_t idx = 0; idx < num_items; ++idx) {
+    if (list[idx].same_as(element)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/*!
+ * \brief A helper class to create a new scope that contains the decomposed init body and the
+ * rewritten reduction block.
+ *
+ * The mutator inserts the decomposed init body immediately before the target loop. It also
+ * replaces the original reduction block with a copy whose name gets an "_update" suffix and
+ * whose init statement is removed.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface to users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop before which the decomposed init body is inserted
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block to be decomposed
+   * \return A pair of (the new block scope, the updated reduction block). The second element
+   * stays default-constructed if `old_reduction_block` is never reached while visiting
+   * `old_scope_root`.
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    // Run the mutation as its own statement so that `new_reduction_block_` is guaranteed to be
+    // populated before it is read. Relying on argument-evaluation details is fragile.
+    Block new_scope_root = Downcast<Block>(replacer(std::move(old_scope_root)));
+    return std::make_pair(std::move(new_scope_root), std::move(replacer.new_reduction_block_));
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop != target_loop_.get()) {
+      return mutated_stmt;
+    }
+    // Place the init body right before the (already mutated) target loop.
+    return SeqStmt({decomposed_body_, mutated_stmt});
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block != old_reduction_block_.get()) {
+      return StmtMutator::VisitStmt_(block);
+    }
+    // The reduction block itself is rewritten without traversing its subtree: rename it and
+    // drop its init, since the init now lives in `decomposed_body_`.
+    ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+    p_new_block->name_hint = p_new_block->name_hint + "_update";
+    p_new_block->init = NullOpt;
+    new_reduction_block_ = Block(p_new_block);
+    return new_reduction_block_;
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    // Flatten the result so that a SeqStmt produced by the ForNode visitor inside an existing
+    // SeqStmt does not end up nested.
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+  /*! \brief The loop before which `decomposed_body_` is inserted */
+  For target_loop_;
+  /*! \brief The decomposed init body to insert */
+  Stmt decomposed_body_;
+  /*! \brief The reduction block to be replaced */
+  Block old_reduction_block_;
+  /*! \brief The replacement for `old_reduction_block_`, set when that block is visited */
+  Block new_reduction_block_;
+};
+
+/*!
+ * \brief Error raised by decompose_reduction when the input loop is not an ancestor of the
+ * given block.
+ */
+class LoopPositionError : public ScheduleError {
+ public:
+  /*!
+   * \param mod The IRModule reported through mod()
+   * \param loop The loop passed to decompose_reduction
+   * \param block The block that the loop was required to enclose
+   */
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expects the loop to be an ancestor of block";
+  }
+
+  String DetailRenderTemplate() const final {
+    // NOTE: {0} and {1} presumably refer to LocationsOfInterest() in order (loop_, block_).
+    return "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+           "ancestor of block {1}.";
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  /*! \brief The IRModule reported by mod() */
+  IRModule mod_;
+  /*! \brief The loop that is not an ancestor of `block_` */
+  For loop_;
+  /*! \brief The block the loop was checked against */
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const 
BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }

Review comment:
       There's a helper function in `analysis.h` that collects the variables contained in the bindings of the data-parallel block iters and of the reduction block iters, respectively:
   https://github.com/apache/tvm/blob/44b644c6a37266c6c49eaa7e8c87c7809b882da5/src/tir/schedule/analysis.h#L215-L226
   We can use it here to remove some repetitive code.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to