junrushao1994 commented on a change in pull request #8467:
URL: https://github.com/apache/tvm/pull/8467#discussion_r669941050



##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_fuse, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying fuse, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_i1_fused in tir.serial(0, 16384):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, tir.floordiv(i0_i1_fused, 128))
+                        tir.bind(vj, tir.floormod(i0_i1_fused, 128))
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # 
pylint: disable=no-member
+
+    def split(
+        self,
+        loop: LoopRV,
+        factors: List[Optional[ExprRV]],
+    ) -> List[LoopRV]:
+        """Split a loop into a list of consecutive loops. It requires:
+        1) The loop can't have annotation or thread binding.
+        2) The loop must start with 0.
+        Predicates may be added to ensure the total loop numbers keeps 
unchanged.
+        In `factors`, at most one of the factors can be None or -1,
+        which will be automatically inferred.

Review comment:
       add a blank line below 

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_fuse, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying fuse, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_i1_fused in tir.serial(0, 16384):

Review comment:
       add a comment here saying the two loops are fused

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -298,5 +298,35 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, 
const StmtSRef& parent
   throw;
 }
 
+Array<Stmt> GetChildren(const Stmt& stmt) {
+  /*! \note Nested SeqStmt is not allowed in schedule. */
+  Stmt body;
+  if (const auto* block = stmt.as<BlockNode>()) {
+    body = block->body;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    body = loop->body;
+  } else {
+    LOG(FATAL) << "The Stmt can only be a Block or a For";

Review comment:
       ```suggestion
       ICHECK(false) << "... only takes Block or For as the input argument";
   ```

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -298,5 +298,35 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, 
const StmtSRef& parent
   throw;
 }
 
+Array<Stmt> GetChildren(const Stmt& stmt) {
+  /*! \note Nested SeqStmt is not allowed in schedule. */
+  Stmt body;
+  if (const auto* block = stmt.as<BlockNode>()) {
+    body = block->body;
+  } else if (const auto* loop = stmt.as<ForNode>()) {
+    body = loop->body;
+  } else {
+    LOG(FATAL) << "The Stmt can only be a Block or a For";
+  }
+  if (const auto* seq = body.as<SeqStmtNode>()) {
+    Array<Stmt> ret;
+    for (const Stmt& child : seq->seq) {
+      ICHECK(!child->IsInstance<SeqStmtNode>()) << "Nested SeqStmt is not 
allowed in schedule.";
+      if (child->IsInstance<BlockRealizeNode>()) {
+        ret.push_back(child.as<BlockRealizeNode>()->block);

Review comment:
       don't cast twice
   
   ```suggestion
         if (const auto* realize = child.as<BlockRealizeNode>()) {
           ret.push_back(realize->block);
   ```

##########
File path: src/tir/schedule/primitive/fuse_split.cc
##########
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+namespace tvm {
+namespace tir {
+
+/*! \brief Append a new predicate to the each children of type BlockRealize 
(not recursively) */
+class PredicateUpdater : public StmtMutator {
+ public:
+  /*!
+   * \brief Constructor
+   * \param predicate The predicate to be apppend to BlockRealizeNode
+   */
+  explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana)
+      : predicate_(predicate) {
+    if (!ana->CanProve(predicate)) {
+      add_predicate_ = true;
+    }
+  }
+
+ private:
+  // For each direct child of type BlockRealizeNode, append the predicate
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    // We do not recursively do this
+    if (add_predicate_) {
+      ObjectPtr<BlockRealizeNode> n = CopyOnWrite(realize);
+      n->predicate = n->predicate && predicate_;
+      return BlockRealize(n);
+    } else {
+      return GetRef<BlockRealize>(realize);
+    }
+  }
+
+  /*! \brief The predicate to be added */
+  const PrimExpr& predicate_;
+  /*! \brief whether to add predicate */
+  bool add_predicate_;
+};
+/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */
+class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator {
+ public:
+  explicit 
IRSubstituteAndCollectOpaqueBlock(std::function<Optional<PrimExpr>(const Var&)> 
vmap,
+                                             Map<Block, Block>* opaque_blocks)
+      : vmap_(vmap), opaque_blocks_(opaque_blocks) {}
+
+ private:
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    Var var = GetRef<Var>(op);
+    Optional<PrimExpr> ret = vmap_(var);
+    if (ret.defined()) {
+      return ret.value();
+    } else {
+      return std::move(var);
+    }
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* op) final {
+    Stmt res = StmtMutator::VisitStmt_(op);
+    if (op->block->iter_vars.empty()) {
+      const BlockRealizeNode* realize = res.as<BlockRealizeNode>();

Review comment:
       use TVM_TYPE_AS to enforce an ICHECK here

##########
File path: src/tir/schedule/primitive/fuse_split.cc
##########
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+namespace tvm {
+namespace tir {
+
+/*! \brief Append a new predicate to the each children of type BlockRealize 
(not recursively) */
+class PredicateUpdater : public StmtMutator {
+ public:
+  /*!
+   * \brief Constructor
+   * \param predicate The predicate to be apppend to BlockRealizeNode
+   */
+  explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana)
+      : predicate_(predicate) {
+    if (!ana->CanProve(predicate)) {
+      add_predicate_ = true;
+    }

Review comment:
       IIUC the logic here is that we do nothing if the predicate is always 
true. If so, why do we still call the visitor?

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:

Review comment:
       To make it clearer that we are fusing loops, let's write the two loops 
out explicitly

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -515,7 +515,6 @@ class IterMapRewriter : public ExprMutator {
    */
   Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
     if (!is_zero(expr->base)) return NullOpt;
-    if (expr->args.size() == 1) return expr->args[0];

Review comment:
       Why remove this line?

##########
File path: src/tir/schedule/concrete_schedule.h
##########
@@ -143,17 +147,22 @@ inline For ConcreteScheduleNode::Get(const LoopRV& 
loop_rv) const {
 }
 
 inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
-  auto it = this->symbol_table_.find(expr_rv);
-  if (it == this->symbol_table_.end()) {
-    LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv;
-  }
-  const ObjectRef& obj = (*it).second;
-  const auto* expr_node = obj.as<PrimExprNode>();
-  if (expr_node == nullptr) {
-    LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: "
-               << (obj.defined() ? obj->GetTypeKey() : "None");
-  }
-  return GetRef<PrimExpr>(expr_node);
+  PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> 
Optional<PrimExpr> {
+    auto it = this->symbol_table_.find(var);
+    if (it == this->symbol_table_.end()) {
+      LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var;
+    }
+    const ObjectRef& obj = (*it).second;
+    const auto* int_imm = obj.as<IntImmNode>();

Review comment:
       Use `TVM_TYPE_AS` defined in utils.h

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=unused-import
 """The TensorIR schedule class"""
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

Review comment:
       Why did the file mode change from 644 to 755? Consider chmod-ing it back to 644.

##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -258,6 +258,34 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const 
BlockRV& block_rv) {
 }
 
 /******** Schedule: loops manipulation ********/
+
+LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
+  Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
+  StmtSRef fused_sref = tir::Fuse(state_, loop_srefs);
+  this->state_->DebugVerify();
+  return CreateRV<LoopRV>(fused_sref);
+  TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
+  throw;
+}
+
+Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv, const 
Array<ExprRV>& factor_rvs) {
+  TVM_TIR_SCHEDULE_BEGIN();

Review comment:
       ditto

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_fuse, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying fuse, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_i1_fused in tir.serial(0, 16384):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, tir.floordiv(i0_i1_fused, 128))
+                        tir.bind(vj, tir.floormod(i0_i1_fused, 128))
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # 
pylint: disable=no-member
+
+    def split(
+        self,
+        loop: LoopRV,
+        factors: List[Optional[ExprRV]],
+    ) -> List[LoopRV]:
+        """Split a loop into a list of consecutive loops. It requires:
+        1) The loop can't have annotation or thread binding.
+        2) The loop must start with 0.
+        Predicates may be added to ensure the total loop numbers keeps 
unchanged.
+        In `factors`, at most one of the factors can be None or -1,
+        which will be automatically inferred.
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be split
+
+        factors: List[Optional[ExprRV]]
+            The splitting factors
+
+        Returns
+        ----------
+        split_loops : List[LoopRV]
+            The new loops after split
+
+        Examples
+        --------
+
+        Before split, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_split(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:

Review comment:
       ditto

##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -258,6 +258,34 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const 
BlockRV& block_rv) {
 }
 
 /******** Schedule: loops manipulation ********/
+
+LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
+  TVM_TIR_SCHEDULE_BEGIN();
+  CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
+  Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
+  StmtSRef fused_sref = tir::Fuse(state_, loop_srefs);
+  this->state_->DebugVerify();
+  return CreateRV<LoopRV>(fused_sref);
+  TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
+  throw;

Review comment:
       Make sure the area that `TVM_TIR_SCHEDULE_BEGIN` encloses is minimal.
   
   ```suggestion
     CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 
loop(s)";
     Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
     StmtSRef result{nullptr};
     TVM_TIR_SCHEDULE_BEGIN();
     result = tir::Fuse(state_, loop_srefs);
     TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
     this->state_->DebugVerify();
     return CreateRV<LoopRV>(result);
   ```

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_fuse, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying fuse, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_i1_fused in tir.serial(0, 16384):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, tir.floordiv(i0_i1_fused, 128))
+                        tir.bind(vj, tir.floormod(i0_i1_fused, 128))
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # 
pylint: disable=no-member
+
+    def split(
+        self,
+        loop: LoopRV,
+        factors: List[Optional[ExprRV]],
+    ) -> List[LoopRV]:
+        """Split a loop into a list of consecutive loops. It requires:
+        1) The loop can't have annotation or thread binding.
+        2) The loop must start with 0.
+        Predicates may be added to ensure the total loop numbers keeps 
unchanged.
+        In `factors`, at most one of the factors can be None or -1,
+        which will be automatically inferred.
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be split
+
+        factors: List[Optional[ExprRV]]
+            The splitting factors
+
+        Returns
+        ----------
+        split_loops : List[LoopRV]
+            The new loops after split
+
+        Examples
+        --------
+
+        Before split, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_split(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_split, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.split(i, factors=[2, 64])
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying split, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_split(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_outer, i0_inner, i1 in tir.grid(2, 64, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, ((i0_outer*64) + i0_inner))
+                        tir.bind(vj, i1)
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        for i, factor in enumerate(factors):
+            if factor is None:
+                factors[i] = -1

Review comment:
       1) Why do we need to support `-1` given `None` is good enough?
   2) I suppose it is fine to pass it into the FFI without this explicit 
conversion because TVM will do the conversion for you

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_fuse, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying fuse, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_i1_fused in tir.serial(0, 16384):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, tir.floordiv(i0_i1_fused, 128))
+                        tir.bind(vj, tir.floormod(i0_i1_fused, 128))
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # 
pylint: disable=no-member
+
+    def split(
+        self,
+        loop: LoopRV,
+        factors: List[Optional[ExprRV]],

Review comment:
       Document what the potential inputs for factors are; right now I can think of:
   
   - None
   - Nonnegative constant integers
   - ExprRV
   
   So the type should be `List[Union[None, int, ExprRV]]` instead

##########
File path: src/arith/iter_affine_map.cc
##########
@@ -1086,6 +1085,22 @@ 
TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
   return NormalizeIterMapToExpr(expr);
 });
 
+Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, 
Range>& input_iters,
+                                const PrimExpr& input_pred, bool 
require_bijective) {
+  Analyzer analyzer;
+  Array<IterSumExpr> rewrite =
+      DetectIterMap(indices, input_iters, input_pred, require_bijective, 
&analyzer);
+  if (rewrite.empty()) {
+    return indices;
+  } else {

Review comment:
       You don't need the "else" here, since the "if" branch already returns.

##########
File path: src/tir/schedule/analysis.h
##########
@@ -142,6 +142,12 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
  * \return A list of leaf blocks
  */
 Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& 
parent_sref);
+/*!
+ * \brief Get the direct child Schedulable Stmt (Block and For)
+ * \param stmt the parent stmt.
+ * \return the list of child stmts
+ */
+Array<Stmt> GetChildren(const Stmt& stmt);

Review comment:
       GetDirectChildrenInSRefTree

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -298,5 +298,35 @@ Array<StmtSRef> GetChildBlocks(const ScheduleState& self, 
const StmtSRef& parent
   throw;
 }
 
+Array<Stmt> GetChildren(const Stmt& stmt) {
+  /*! \note Nested SeqStmt is not allowed in schedule. */

Review comment:
       This implementation will lead to potential issues when it comes to a 
non-stage pipeline case, where there might be statements other than 
`BlockRealize`/`Loop`. For example:
   
   ```python
   for ...: <= GetChildren on this loop
     if ...:
   ```
   
   Consider reimplementing it with a visitor that stops at BlockNode/ForNode

##########
File path: src/tir/schedule/primitive/fuse_split.cc
##########
@@ -0,0 +1,483 @@
+/*

Review comment:
       Rename `fuse_split.cc` to `loop_transformation.cc`

##########
File path: src/tir/schedule/primitive/fuse_split.cc
##########
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+namespace tvm {
+namespace tir {
+
+/*! \brief Append a new predicate to the each children of type BlockRealize 
(not recursively) */
+class PredicateUpdater : public StmtMutator {

Review comment:
       Pick a more informative name, like, `BlockPredicateAppender`

##########
File path: src/tir/schedule/primitive/fuse_split.cc
##########
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+namespace tvm {
+namespace tir {
+
+/*! \brief Append a new predicate to the each children of type BlockRealize 
(not recursively) */
+class PredicateUpdater : public StmtMutator {
+ public:
+  /*!
+   * \brief Constructor
+   * \param predicate The predicate to be apppend to BlockRealizeNode
+   */
+  explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana)

Review comment:
       ```suggestion
     explicit PredicateUpdater(const PrimExpr& to_append, arith::Analyzer* 
analyzer)
   ```

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
         return _ffi_api_schedule.ScheduleGetLoops(self, block)  # type: ignore 
# pylint: disable=no-member
 
     ########## Schedule: loops manipulation ##########
+    def fuse(self, *loops: List[LoopRV]) -> LoopRV:
+        """Fuse a list of consecutive loops into one. It requires:
+        1) The loops can't have annotations or thread bindings.
+        2) The (i+1)-th loop must be the only child of the i-th loop.
+        3) All loops must start with 0.
+
+        Parameters
+        ----------
+        *loops : List[LoopRV]
+            The loops to be fused
+
+        Returns
+        ----------
+        fused_loop : LoopRV
+            The new loop after fusion
+
+        Examples
+        --------
+
+        Before fuse, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do fuse:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_fuse, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.fuse(i, j)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying fuse, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_fuse(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_i1_fused in tir.serial(0, 16384):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, tir.floordiv(i0_i1_fused, 128))
+                        tir.bind(vj, tir.floormod(i0_i1_fused, 128))
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # 
pylint: disable=no-member
+
+    def split(
+        self,
+        loop: LoopRV,
+        factors: List[Optional[ExprRV]],
+    ) -> List[LoopRV]:
+        """Split a loop into a list of consecutive loops. It requires:
+        1) The loop can't have annotation or thread binding.
+        2) The loop must start with 0.
+        Predicates may be added to ensure the total number of loop iterations 
stays unchanged.
+        In `factors`, at most one of the factors can be None or -1,
+        which will be automatically inferred.
+
+        Parameters
+        ----------
+        loop : LoopRV
+            The loop to be split
+
+        factors: List[Optional[ExprRV]]
+            The splitting factors
+
+        Returns
+        ----------
+        split_loops : List[LoopRV]
+            The new loops after split
+
+        Examples
+        --------
+
+        Before split, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_split(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, (128, 128))
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do split:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_split, debug_mode=True)
+            i, j = sch.get_loops(sch.get_block("B"))
+            sch.split(i, factors=[2, 64])
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying split, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_split(a: ty.handle, b: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128))
+                B = tir.match_buffer(b, [128, 128])
+                for i0_outer, i0_inner, i1 in tir.grid(2, 64, 128):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, ((i0_outer*64) + i0_inner))
+                        tir.bind(vj, i1)
+                        tir.reads([A[vi, vj]])
+                        tir.writes([B[vi, vj]])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        """
+        for i, factor in enumerate(factors):
+            if factor is None:

Review comment:
       agree with cody




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to