Hzfengsy commented on code in PR #14398:
URL: https://github.com/apache/tvm/pull/14398#discussion_r1153731668
##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
+ /*!
+ * \brief Merge a list of loops into one. The loops under their LCA requires:
+ * 1) Under the same scope.
+ * 2) Can't have annotations or thread bindings
+ * 3) Start with 0 and have same domain.
+ * 4) The inner loop must be the only child of the outer loop.
Review Comment:
It's unclear what `the inner loop` refers to here.
##########
src/tir/schedule/primitive.h:
##########
@@ -161,6 +161,19 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self,
const StmtSRef& block_sr
*/
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors, bool
preserve_unit_iters);
+
+/*!
+ * \brief Merge a list of loops into one. The loops under their LCA requires:
+ * 1) Under the same scope.
+ * 2) Can't have annotations or thread bindings
+ * 3) Start with 0 and have same domain.
+ * 4) The inner loop must be the only child of the outer loop.
Review Comment:
ditto `inner loop`
##########
tests/python/unittest/test_tir_schedule_merge.py:
##########
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import te, tir
+from tvm.script import tir as T
+from tvm.tir.expr import IntImm
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
+@T.prim_func
+def elementwise(a: T.handle, c: T.handle, d: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128))
+ C = T.match_buffer(c, (128, 128))
+ D = T.match_buffer(d, (64, 64))
+ B = T.alloc_buffer((128, 128))
+ for i, j in T.grid(128, 128):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+ for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+ with T.block("C"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(B[vi, vj])
+ T.writes(C[vi, vj])
+ C[vi, vj] = B[vi, vj] + T.float32(1)
+ for i_0, j_0, i_1, j_1 in T.grid(8, 8, 8, 8):
+ with T.block("D"):
+ vi = T.axis.spatial(64, i_0 * 8 + i_1)
+ vj = T.axis.spatial(64, j_0 * 8 + j_1)
+ T.reads(B[vi, vj])
+ T.writes(D[vi, vj])
+ D[vi, vj] = B[vi, vj] + T.float32(2)
+
+
+@T.prim_func
+def elementwise_different_loops_extent(a: T.handle, c: T.handle) -> None:
Review Comment:
Could you please move the PrimFunc under the test case that uses it?
##########
src/tir/schedule/primitive/loop_transformation.cc:
##########
@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef&
loop_sref, const Array
return result_srefs;
}
+class LoopReconstructor : private StmtMutator {
+ public:
+ explicit LoopReconstructor(Block scope_root,
+ const std::vector<std::vector<const ForNode*>>&
loops)
+ : scope_root_(scope_root), loops_(loops) {}
+
+ using StmtMutator::operator();
+
+ /*!
+ * \brief Create the new nest loops induced by the given loops
+ */
+ void MakeNewLoop() {
+ Array<Var> new_loop_vars;
+ Array<PrimExpr> new_loop_extents;
+ Array<Stmt> new_stmts;
+ for (size_t i = 0; i < loops_.size(); i++) {
+ Map<Var, PrimExpr> var_map;
+ for (size_t j = 0; j < loops_[i].size(); j++) {
+ if (i == 0) {
+ Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
+ new_loop_vars.push_back(merged_var);
+ new_loop_extents.push_back(loops_[i][j]->extent);
+ }
+ var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
+ }
+ auto new_stmt = Substitute(loops_[i][0]->body, var_map);
+ new_stmts.push_back(new_stmt);
+ this->need_remove_loop_.push_back(loops_[i].back());
+ }
+ auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0],
ForKind::kSerial,
+ SeqStmt(std::move(new_stmts)));
+ this->new_inner_loop_ = new_loop;
+ for (size_t i = 1; i < new_loop_vars.size(); ++i) {
+ const Var& loop_var = new_loop_vars[i];
+ const PrimExpr& loop_extent = new_loop_extents[i];
+ new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial,
new_loop);
+ }
+ this->new_outer_loop_ = new_loop;
+ }
+
+ private:
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block != scope_root_.get()) {
+ return GetRef<Block>(block);
+ }
+ return StmtMutator::VisitStmt_(block);
+ }
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ if (loop == need_remove_loop_.back()) {
+ return new_outer_loop_;
+ } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(),
loop)) {
+ return Evaluate(0);
+ }
+ return StmtMutator::VisitStmt_(loop);
+ }
+
+ Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
+ Array<Stmt> filtered;
+ for (Stmt stmt : ret->seq) {
+ if (!is_no_op(stmt)) {
+ filtered.push_back(std::move(stmt));
+ }
+ }
+ ret = SeqStmt(filtered);
+ if (ret->size() == 0) {
+ return Evaluate(0);
+ } else if (ret->size() == 1) {
+ return ret->seq[0];
+ } else {
+ return std::move(ret);
+ }
+ }
+
+ public:
+ /*! \brief The root block of the block scope */
+ Block scope_root_;
+ /*! \brief The given loops to be merge */
+ const std::vector<std::vector<const ForNode*>>& loops_;
+ /*! \brief The outermost new loop to replace the original loop */
+ For new_outer_loop_{nullptr};
+ /*! \brief The innermost new loop to replace the original loop */
+ For new_inner_loop_{nullptr};
+ /*! \brief The loops to be removed */
+ std::vector<const ForNode*> need_remove_loop_;
Review Comment:
```suggestion
std::vector<For> need_remove_loop_;
```
##########
src/tir/schedule/primitive/loop_transformation.cc:
##########
@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef&
loop_sref, const Array
return result_srefs;
}
+class LoopReconstructor : private StmtMutator {
+ public:
+ explicit LoopReconstructor(Block scope_root,
+ const std::vector<std::vector<const ForNode*>>&
loops)
+ : scope_root_(scope_root), loops_(loops) {}
+
+ using StmtMutator::operator();
+
+ /*!
+ * \brief Create the new nest loops induced by the given loops
+ */
+ void MakeNewLoop() {
+ Array<Var> new_loop_vars;
+ Array<PrimExpr> new_loop_extents;
+ Array<Stmt> new_stmts;
+ for (size_t i = 0; i < loops_.size(); i++) {
+ Map<Var, PrimExpr> var_map;
+ for (size_t j = 0; j < loops_[i].size(); j++) {
+ if (i == 0) {
+ Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
+ new_loop_vars.push_back(merged_var);
+ new_loop_extents.push_back(loops_[i][j]->extent);
+ }
+ var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
+ }
+ auto new_stmt = Substitute(loops_[i][0]->body, var_map);
+ new_stmts.push_back(new_stmt);
+ this->need_remove_loop_.push_back(loops_[i].back());
+ }
+ auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0],
ForKind::kSerial,
+ SeqStmt(std::move(new_stmts)));
+ this->new_inner_loop_ = new_loop;
+ for (size_t i = 1; i < new_loop_vars.size(); ++i) {
+ const Var& loop_var = new_loop_vars[i];
+ const PrimExpr& loop_extent = new_loop_extents[i];
+ new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial,
new_loop);
+ }
+ this->new_outer_loop_ = new_loop;
+ }
+
+ private:
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block != scope_root_.get()) {
+ return GetRef<Block>(block);
+ }
+ return StmtMutator::VisitStmt_(block);
+ }
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ if (loop == need_remove_loop_.back()) {
+ return new_outer_loop_;
+ } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(),
loop)) {
+ return Evaluate(0);
+ }
+ return StmtMutator::VisitStmt_(loop);
+ }
+
+ Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
+ Array<Stmt> filtered;
+ for (Stmt stmt : ret->seq) {
+ if (!is_no_op(stmt)) {
+ filtered.push_back(std::move(stmt));
+ }
+ }
+ ret = SeqStmt(filtered);
+ if (ret->size() == 0) {
+ return Evaluate(0);
+ } else if (ret->size() == 1) {
+ return ret->seq[0];
+ } else {
+ return std::move(ret);
+ }
+ }
+
+ public:
+ /*! \brief The root block of the block scope */
+ Block scope_root_;
+ /*! \brief The given loops to be merge */
+ const std::vector<std::vector<const ForNode*>>& loops_;
+ /*! \brief The outermost new loop to replace the original loop */
+ For new_outer_loop_{nullptr};
+ /*! \brief The innermost new loop to replace the original loop */
+ For new_inner_loop_{nullptr};
+ /*! \brief The loops to be removed */
+ std::vector<const ForNode*> need_remove_loop_;
+};
+
+StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+ // Invariance
+ // - The total repeat number has not changed for each direct child block.
+ // - The execution order has not changed. (The block executes with the same
+ // args and the same order with before.)
+ arith::Analyzer analyzer;
+ StmtSRef scope_root_sref;
+ StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
+ std::vector<std::vector<const ForNode*>> lca_nest_loops;
+ // Step 1. check correctness
+ std::vector<const ForNode*> nest_loop_loops;
+ std::vector<PrimExpr> nest_loop_extents;
+ for (size_t i = 0; i < loop_srefs.size(); i++) {
+ const StmtSRef& sref = loop_srefs[i];
+ auto scope_root_sref_ = GetScopeRoot(self, sref,
/*require_stage_pipeline=*/false);
Review Comment:
Need to check that all loops to be merged are in the same scope.
##########
tests/python/unittest/test_tir_schedule_merge.py:
##########
Review Comment:
Need another failure test for the case where the loops are not at the same level:
```python
for i0, i1 in T.grid(...):
A[i] = ...
for j0, j1 in T.grid(...):
B[i] =
s.merge(i0, j1) # raise error
```
##########
src/tir/schedule/concrete_schedule.cc:
##########
@@ -356,6 +356,17 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const
BlockRV& block_rv) {
/******** Schedule: Transform loops ********/
+LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
+ CHECK(!loop_rvs.empty()) << "ValueError: 'merge' requires at least 1
loop(s)";
Review Comment:
I recommend requiring at least 2 loops, as "merging" 1 loop is trivial.
##########
tests/python/unittest/test_tir_schedule_merge.py:
##########
Review Comment:
Need a failure test with dependencies.
```python
for i0 in range(...):
A[i] = ...
for i1 in range(...):
B[i] = 1
for i2 in range(...):
C[i] = B[i]
s.merge(i0, i2) # raise error as B is not computed
```
##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
+ /*!
+ * \brief Merge a list of loops into one. The loops under their LCA requires:
+ * 1) Under the same scope.
+ * 2) Can't have annotations or thread bindings
+ * 3) Start with 0 and have same domain.
Review Comment:
```suggestion
* 3) Start with 0 and have same extent.
```
##########
src/tir/schedule/primitive/loop_transformation.cc:
##########
@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef&
loop_sref, const Array
return result_srefs;
}
+class LoopReconstructor : private StmtMutator {
+ public:
+ explicit LoopReconstructor(Block scope_root,
+ const std::vector<std::vector<const ForNode*>>&
loops)
+ : scope_root_(scope_root), loops_(loops) {}
+
+ using StmtMutator::operator();
+
+ /*!
+ * \brief Create the new nest loops induced by the given loops
+ */
+ void MakeNewLoop() {
+ Array<Var> new_loop_vars;
+ Array<PrimExpr> new_loop_extents;
+ Array<Stmt> new_stmts;
+ for (size_t i = 0; i < loops_.size(); i++) {
+ Map<Var, PrimExpr> var_map;
+ for (size_t j = 0; j < loops_[i].size(); j++) {
+ if (i == 0) {
+ Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
+ new_loop_vars.push_back(merged_var);
+ new_loop_extents.push_back(loops_[i][j]->extent);
+ }
+ var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
+ }
+ auto new_stmt = Substitute(loops_[i][0]->body, var_map);
+ new_stmts.push_back(new_stmt);
+ this->need_remove_loop_.push_back(loops_[i].back());
+ }
+ auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0],
ForKind::kSerial,
+ SeqStmt(std::move(new_stmts)));
+ this->new_inner_loop_ = new_loop;
+ for (size_t i = 1; i < new_loop_vars.size(); ++i) {
+ const Var& loop_var = new_loop_vars[i];
+ const PrimExpr& loop_extent = new_loop_extents[i];
+ new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial,
new_loop);
+ }
+ this->new_outer_loop_ = new_loop;
+ }
+
+ private:
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block != scope_root_.get()) {
+ return GetRef<Block>(block);
+ }
+ return StmtMutator::VisitStmt_(block);
+ }
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ if (loop == need_remove_loop_.back()) {
+ return new_outer_loop_;
+ } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(),
loop)) {
+ return Evaluate(0);
+ }
+ return StmtMutator::VisitStmt_(loop);
+ }
+
+ Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
+ Array<Stmt> filtered;
+ for (Stmt stmt : ret->seq) {
+ if (!is_no_op(stmt)) {
+ filtered.push_back(std::move(stmt));
+ }
+ }
+ ret = SeqStmt(filtered);
+ if (ret->size() == 0) {
+ return Evaluate(0);
+ } else if (ret->size() == 1) {
+ return ret->seq[0];
+ } else {
+ return std::move(ret);
+ }
+ }
+
+ public:
+ /*! \brief The root block of the block scope */
+ Block scope_root_;
+ /*! \brief The given loops to be merge */
+ const std::vector<std::vector<const ForNode*>>& loops_;
+ /*! \brief The outermost new loop to replace the original loop */
+ For new_outer_loop_{nullptr};
+ /*! \brief The innermost new loop to replace the original loop */
+ For new_inner_loop_{nullptr};
+ /*! \brief The loops to be removed */
+ std::vector<const ForNode*> need_remove_loop_;
+};
+
+StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
+ // Invariance
+ // - The total repeat number has not changed for each direct child block.
+ // - The execution order has not changed. (The block executes with the same
+ // args and the same order with before.)
+ arith::Analyzer analyzer;
+ StmtSRef scope_root_sref;
+ StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
+ std::vector<std::vector<const ForNode*>> lca_nest_loops;
Review Comment:
use `For` instead of `ForNode` if possible
##########
tests/python/unittest/test_tir_schedule_merge.py:
##########
Review Comment:
Need a failure test for loops that are under different scopes (different parent blocks).
##########
src/tir/schedule/primitive/loop_transformation.cc:
##########
@@ -451,6 +451,163 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef&
loop_sref, const Array
return result_srefs;
}
+class LoopReconstructor : private StmtMutator {
+ public:
+ explicit LoopReconstructor(Block scope_root,
+ const std::vector<std::vector<const ForNode*>>&
loops)
+ : scope_root_(scope_root), loops_(loops) {}
+
+ using StmtMutator::operator();
+
+ /*!
+ * \brief Create the new nest loops induced by the given loops
+ */
+ void MakeNewLoop() {
+ Array<Var> new_loop_vars;
+ Array<PrimExpr> new_loop_extents;
+ Array<Stmt> new_stmts;
+ for (size_t i = 0; i < loops_.size(); i++) {
+ Map<Var, PrimExpr> var_map;
+ for (size_t j = 0; j < loops_[i].size(); j++) {
+ if (i == 0) {
+ Var merged_var = loops_[i][j]->loop_var.copy_with_suffix("_m");
+ new_loop_vars.push_back(merged_var);
+ new_loop_extents.push_back(loops_[i][j]->extent);
+ }
+ var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
+ }
+ auto new_stmt = Substitute(loops_[i][0]->body, var_map);
+ new_stmts.push_back(new_stmt);
+ this->need_remove_loop_.push_back(loops_[i].back());
+ }
+ auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0],
ForKind::kSerial,
+ SeqStmt(std::move(new_stmts)));
+ this->new_inner_loop_ = new_loop;
+ for (size_t i = 1; i < new_loop_vars.size(); ++i) {
+ const Var& loop_var = new_loop_vars[i];
+ const PrimExpr& loop_extent = new_loop_extents[i];
+ new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial,
new_loop);
+ }
+ this->new_outer_loop_ = new_loop;
+ }
+
+ private:
+ Stmt VisitStmt_(const BlockNode* block) final {
+ if (block != scope_root_.get()) {
+ return GetRef<Block>(block);
+ }
+ return StmtMutator::VisitStmt_(block);
+ }
+
+ Stmt VisitStmt_(const ForNode* loop) final {
+ if (loop == need_remove_loop_.back()) {
+ return new_outer_loop_;
+ } else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(),
loop)) {
+ return Evaluate(0);
+ }
+ return StmtMutator::VisitStmt_(loop);
+ }
+
+ Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
+ auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
+ Array<Stmt> filtered;
+ for (Stmt stmt : ret->seq) {
+ if (!is_no_op(stmt)) {
+ filtered.push_back(std::move(stmt));
+ }
+ }
+ ret = SeqStmt(filtered);
+ if (ret->size() == 0) {
+ return Evaluate(0);
+ } else if (ret->size() == 1) {
+ return ret->seq[0];
+ } else {
+ return std::move(ret);
+ }
+ }
+
+ public:
+ /*! \brief The root block of the block scope */
+ Block scope_root_;
+ /*! \brief The given loops to be merge */
+ const std::vector<std::vector<const ForNode*>>& loops_;
Review Comment:
As it is a mutator rather than a visitor, let's avoid using weak pointers,
e.g. `const ForNode*`.
```suggestion
const std::vector<std::vector<For>>& loops_;
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]