MasterJH5574 commented on a change in pull request #8943:
URL: https://github.com/apache/tvm/pull/8943#discussion_r703530318



##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -305,6 +305,38 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
+  /*!
+   * \brief Move a producer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the producer block could 
cover those regions read
+   * by the consumers. It requires:
+   * 1) The scope block has stage-pipeline property
+   * 2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+   * i.e. all the blocks in the scope's subtree must be either complete block 
or reduction block
+   * 3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`

Review comment:
       I think moving condition 3 above condition 1 would read more naturally: 
condition 3 first states that "`block` and `loop` are under the same scope", 
and condition 1 can then state that "the scope block should have the 
stage-pipeline property". What do you think?

##########
File path: include/tvm/arith/int_set.h
##########
@@ -121,6 +121,13 @@ class IntSet : public ObjectRef {
    * \return The result set containing the indices in the vector.
    */
   static IntSet Vector(PrimExpr vec);
+  /*!
+   * \brief Construct a set representing a range [min, min + extent).
+   * \param min The minimum of the range range
+   * \param extent The extent of the range.
+   * \return constructed set.

Review comment:
       ```suggestion
      * \return The constructed set.
   ```

##########
File path: src/arith/int_set.cc
##########
@@ -607,6 +607,13 @@ inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, 
PrimExpr rhs) {
   return is_zero(analyzer->Simplify(lhs - rhs));
 }
 
+IntSet IntSet::FromMinExtent(PrimExpr min, PrimExpr extent) {
+  if (is_one(extent)) {
+    return IntSet::SinglePoint(min);
+  }
+  return IntervalSet(min, extent + min - 1);

Review comment:
       Is there any specific reason for writing `extent + min - 1` instead of 
`min + extent - 1`?

##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -305,6 +305,38 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
+  /*!
+   * \brief Move a producer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the producer block could 
cover those regions read
+   * by the consumers. It requires:
+   * 1) The scope block has stage-pipeline property
+   * 2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+   * i.e. all the blocks in the scope's subtree must be either complete block 
or reduction block
+   * 3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`
+   * 4) The block is not an output block,
+   * i.e. the buffer regions written by the block are allocated under the 
current scope
+   * 5) All the consumers of the block are under the given loop
+   * \param block_rv The block to be moved
+   * \param loop_rv The loop where the block to be moved under
+   * \param preserve_unit_loops Whether to keep the trivial loops whose 
extents are 1
+   */
+  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
+                         bool preserve_unit_loops) = 0;
+  /*!
+   * \brief Move a consumer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the consumer block could 
cover those regions read
+   * by the consumers. It requires:

Review comment:
       ```suggestion
      * block so that the buffer region generated by the consumer block could 
cover those regions written
      * by the producers. It requires:
   ```
   BTW, after applying this suggestion, line 327 will be 102 characters long, 
which exceeds the line-length limit, so we should break the line earlier.

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -927,6 +927,181 @@ def after_cache_write(a: ty.handle, b: ty.handle) -> None:
 
     ########## Schedule: Compute location ##########
 
+    def compute_at(
+        self,
+        block: BlockRV,
+        loop: LoopRV,
+        preserve_unit_loops: bool = False,
+    ) -> None:
+        """Compute-At. Move a producer block under the specific loop, and 
regenerate the loops
+        induced by the block so that the buffer region generated by the 
producer block could cover
+        those regions read by the consumers. It requires:
+
+        1) The scope block has stage-pipeline property
+
+        2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+        i.e. all the blocks in the scope's subtree must be either complete 
block or reduction block
+
+        3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`
+
+        4) The block is not an output block,
+        i.e. the buffer regions written by the block are allocated under the 
current scope
+
+        5) All the consumers of the block are under the given loop
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be moved
+
+        loop: LoopRV
+            The loop where the block to be moved under
+
+        preserve_unit_loops: bool
+            Whether to keep the trivial loops whose extents are 1
+
+        Examples
+        --------
+
+        Before compute-at, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do compute-inline:

Review comment:
       🤐
   ```suggestion
           Create the schedule and do compute-at:
   ```

##########
File path: src/relay/transforms/fold_scale_axis.cc
##########
@@ -243,7 +243,9 @@ class ForwardPrep : private MixedModeVisitor {
     }
   }
   // Visitor pattern override.
-  void VisitExpr_(const LetNode* op) {
+  void VisitExpr_(const TupleGetItemNode* op) final { 
MixedModeVisitor::VisitExpr_(op); }

Review comment:
       What is this new `TupleGetItemNode` override for? :eyes:

##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -305,6 +305,38 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
+  /*!
+   * \brief Move a producer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the producer block could 
cover those regions read
+   * by the consumers. It requires:
+   * 1) The scope block has stage-pipeline property
+   * 2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+   * i.e. all the blocks in the scope's subtree must be either complete block 
or reduction block
+   * 3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`
+   * 4) The block is not an output block,
+   * i.e. the buffer regions written by the block are allocated under the 
current scope
+   * 5) All the consumers of the block are under the given loop
+   * \param block_rv The block to be moved
+   * \param loop_rv The loop where the block to be moved under
+   * \param preserve_unit_loops Whether to keep the trivial loops whose 
extents are 1
+   */
+  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
+                         bool preserve_unit_loops) = 0;
+  /*!
+   * \brief Move a consumer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the consumer block could 
cover those regions read
+   * by the consumers. It requires:
+   * 1) The scope block has stage-pipeline property
+   * 2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+   * i.e. all the blocks in the scope's subtree must be either complete block 
or reduction block
+   * 3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`

Review comment:
       Same comments as for ComputeAt.

##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -305,6 +305,38 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
+  /*!
+   * \brief Move a producer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the producer block could 
cover those regions read
+   * by the consumers. It requires:
+   * 1) The scope block has stage-pipeline property
+   * 2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+   * i.e. all the blocks in the scope's subtree must be either complete block 
or reduction block

Review comment:
       I think we should differentiate the "scope" in line 312 from the "scope" 
in line 314. Specifically, in line 312 the "scope" is the scope block of 
`block` and `loop`, while in line 314 the "scope" is exactly the input block.

##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -927,6 +927,181 @@ def after_cache_write(a: ty.handle, b: ty.handle) -> None:
 
     ########## Schedule: Compute location ##########
 
+    def compute_at(
+        self,
+        block: BlockRV,
+        loop: LoopRV,
+        preserve_unit_loops: bool = False,
+    ) -> None:
+        """Compute-At. Move a producer block under the specific loop, and 
regenerate the loops
+        induced by the block so that the buffer region generated by the 
producer block could cover
+        those regions read by the consumers. It requires:
+
+        1) The scope block has stage-pipeline property
+
+        2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+        i.e. all the blocks in the scope's subtree must be either complete 
block or reduction block
+
+        3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`
+
+        4) The block is not an output block,
+        i.e. the buffer regions written by the block are allocated under the 
current scope
+
+        5) All the consumers of the block are under the given loop
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be moved
+
+        loop: LoopRV
+            The loop where the block to be moved under
+
+        preserve_unit_loops: bool
+            Whether to keep the trivial loops whose extents are 1
+
+        Examples
+        --------
+
+        Before compute-at, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do compute-inline:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_compute_at)
+            block = sch.get_block("B")
+            loop, _ = sch.get_loops(sch.get_block("C"))
+            sch.compute_at(block, loop, preserve_unit_loops=False)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying compute-at, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                for i in tir.serial(0, 128):
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "B") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            B[vi, vj] = A[vi, vj] * 2.0
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "C") as [vi, vj]:
+                            tir.bind(vi, i)
+                            tir.bind(vj, j)
+                            C[vi, vj] = B[vi, vj] + 1.0
+
+        """
+        _ffi_api.ScheduleComputeAt(  # type: ignore # pylint: disable=no-member
+            self,
+            block,
+            loop,
+            preserve_unit_loops,
+        )
+
+    def reverse_compute_at(
+        self,
+        block: BlockRV,
+        loop: LoopRV,
+        preserve_unit_loops: bool = False,
+    ) -> None:
+        """Reverse-Compute-At. Move a consumer block under the specific loop, 
and regenerate the
+        loops induced by the block so that the buffer region generated by the 
consumer block could
+        cover those regions read by the consumers. It requires:
+
+        1) The scope block has stage-pipeline property
+
+        2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+        i.e. all the blocks in the scope's subtree must be either complete 
block or reduction block
+
+        3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`
+
+        4) All the producers of the block are under the given loop
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block to be moved
+
+        loop: LoopRV
+            The loop where the block to be moved under
+
+        preserve_unit_loops: bool
+            Whether to keep the trivial loops whose extents are 1
+
+        Examples
+        --------
+
+        Before reverse-compute-at, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, (128, 128), "float32")
+                B = tir.alloc_buffer((128, 128), "float32")
+                C = tir.match_buffer(c, (128, 128), "float32")
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    B[vi, vj] = A[vi, vj] * 2.0
+                with tir.block([128, 128], "C") as [vi, vj]:
+                    C[vi, vj] = B[vi, vj] + 1.0
+
+        Create the schedule and do compute-inline:

Review comment:
       ```suggestion
           Create the schedule and do reverse-compute-at:
   ```

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a region.
+ * \param region The region.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
+  NDIntSet result;
+  result.reserve(region.size());
+  for (const Range& range : region) {
+    result.push_back(arith::IntSet::FromRange(range));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a shape.
+ * \param shape The shape which is an array of the length of each dimension.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  PrimExpr zero = Integer(0);
+  NDIntSet result;
+  result.reserve(shape.size());
+  for (const PrimExpr& extent : shape) {
+    result.push_back(arith::IntSet::FromMinExtent(zero, extent));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a point.
+ * \param indices The N-dimensional indices representing the point.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
+  NDIntSet result;
+  result.reserve(indices.size());
+  for (const PrimExpr& index : indices) {
+    result.push_back(arith::IntSet::SinglePoint(index));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create a union set of two sets, possibly relaxed. The RHS set will 
be combined into the
+ *        LHS set.
+ * \param lhs The first N-dimensional integer set
+ * \param rhs The second N-dimensional integer set
+ */
+inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
+  ICHECK_EQ(lhs->size(), rhs.size());
+  int ndim = rhs.size();
+  for (int i = 0; i < ndim; ++i) {
+    arith::IntSet& int_set = lhs->at(i);
+    int_set = arith::Union({int_set, rhs.at(i)});
+  }
+}
+
+/*!
+ * \brief Union a list of N-dimensional integer sets
+ * \param nd_int_sets The N-dimensional integer sets to be merged.
+ * \return The result of the union
+ */
+inline NDIntSet NDIntSetUnion(const std::vector<NDIntSet>& nd_int_sets) {

Review comment:
       Same question again: why choose `std::vector` here? :eyes:

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;

Review comment:
       * Q1. What's the reason for using `std::vector` instead of `Array`?
   * Q2. Since we are already inside `namespace tvm`, can we drop the redundant `tvm::` qualifier as below?
   ```suggestion
   using NDIntSet = std::vector<arith::IntSet>;
   ```

##########
File path: src/tir/schedule/analysis.h
##########
@@ -128,18 +138,36 @@ void CheckReductionBlock(const ScheduleState& self, const 
StmtSRef& block_sref,
                          const StmtSRef& scope_root_sref);
 
 /*!
- * \brief Check whether a subtree on SRef tree has compact data flow, and 
throw an exception if the
- * subtree does not have compact data flow
- * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef 
has "compact data
- * flow" property if:
- * - the scope root of the input subtree root has stage-pipeline property, and
- * - all its child blocks on SRef tree are complete blocks or reduction blocks.
+ * \brief Check if the block is a complete block or a reduction block under 
the scope
  * \param self The schedule state
- * \param subtree_root_sref The root of the subtree to be checked in the SRef 
tree
- * \throw ScheduleError If the subtree does not have compact data flow
- * \sa IsCompleteBlock, IsReductionBlock
+ * \param block_sref The sref of the block to be checked
+ * \param scope_root_sref The scope root of the block
+ * \throw ScheduleError If the block is not a reduction block
+ */
+void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& 
block_sref,
+                                   const StmtSRef& scope_root_sref);
+
+/*!
+ * \brief Check if the block is an output block, i.e. the buffer regions 
written by the block are
+ * allocated under the current scope

Review comment:
       I can't follow the description here. From reading the implementation, 
`IsOutputBlock` returns true if the input block writes to some buffer that 
isn't allocated by the scope block, but that doesn't seem to match the 
description here.

##########
File path: src/tir/schedule/analysis.h
##########
@@ -128,18 +138,36 @@ void CheckReductionBlock(const ScheduleState& self, const 
StmtSRef& block_sref,
                          const StmtSRef& scope_root_sref);
 
 /*!
- * \brief Check whether a subtree on SRef tree has compact data flow, and 
throw an exception if the
- * subtree does not have compact data flow
- * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef 
has "compact data
- * flow" property if:
- * - the scope root of the input subtree root has stage-pipeline property, and
- * - all its child blocks on SRef tree are complete blocks or reduction blocks.
+ * \brief Check if the block is a complete block or a reduction block under 
the scope
  * \param self The schedule state
- * \param subtree_root_sref The root of the subtree to be checked in the SRef 
tree
- * \throw ScheduleError If the subtree does not have compact data flow
- * \sa IsCompleteBlock, IsReductionBlock
+ * \param block_sref The sref of the block to be checked
+ * \param scope_root_sref The scope root of the block
+ * \throw ScheduleError If the block is not a reduction block

Review comment:
       ```suggestion
    * \throw ScheduleError If the block is neither a complete block nor a 
reduction block
   ```

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a region.
+ * \param region The region.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
+  NDIntSet result;
+  result.reserve(region.size());
+  for (const Range& range : region) {
+    result.push_back(arith::IntSet::FromRange(range));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a shape.
+ * \param shape The shape which is an array of the length of each dimension.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  PrimExpr zero = Integer(0);
+  NDIntSet result;
+  result.reserve(shape.size());
+  for (const PrimExpr& extent : shape) {
+    result.push_back(arith::IntSet::FromMinExtent(zero, extent));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a point.
+ * \param indices The N-dimensional indices representing the point.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
+  NDIntSet result;
+  result.reserve(indices.size());
+  for (const PrimExpr& index : indices) {
+    result.push_back(arith::IntSet::SinglePoint(index));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create a union set of two sets, possibly relaxed. The RHS set will 
be combined into the
+ *        LHS set.
+ * \param lhs The first N-dimensional integer set
+ * \param rhs The second N-dimensional integer set
+ */
+inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
+  ICHECK_EQ(lhs->size(), rhs.size());
+  int ndim = rhs.size();
+  for (int i = 0; i < ndim; ++i) {
+    arith::IntSet& int_set = lhs->at(i);
+    int_set = arith::Union({int_set, rhs.at(i)});
+  }
+}
+
+/*!
+ * \brief Union a list of N-dimensional integer sets
+ * \param nd_int_sets The N-dimensional integer sets to be merged.
+ * \return The result of the union
+ */
+inline NDIntSet NDIntSetUnion(const std::vector<NDIntSet>& nd_int_sets) {
+  ICHECK(!nd_int_sets.empty());
+  int n = nd_int_sets.size();
+  if (n == 1) {
+    return nd_int_sets[0];
+  }
+  int ndim = nd_int_sets[0].size();
+  for (int i = 1; i < n; ++i) {
+    ICHECK_EQ(nd_int_sets[i].size(), ndim);
+  }
+  NDIntSet result;
+  result.reserve(ndim);
+  Array<arith::IntSet> int_sets(n, arith::IntSet{nullptr});
+  for (int dim = 0; dim < ndim; ++dim) {
+    for (int i = 0; i < n; ++i) {
+      int_sets.Set(i, nd_int_sets[i][dim]);
+    }
+    result.push_back(arith::Union(int_sets));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create an empty N-dimensional integer set.

Review comment:
       ```suggestion
    * \brief Create an empty N-dimensional integer set with a specific number of dimensions.
   ```

##########
File path: src/tir/schedule/primitive/cache_read_write.cc
##########
@@ -160,6 +160,21 @@ Block MakeCacheStage(const BufferRegion& cache_region, 
CacheStageInfo* info,
   return block;
 }
 
+/*!
+ * \brief Recalculate the `affine_binding` flag of the scope block info.
+ * \param block_sref The sref to the interested scope block.
+ */
+bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& 
block_sref) {
+  if (block_sref->parent == nullptr) {
+    return true;
+  }

Review comment:
       I've had this question for a long time: if a block has no `IterVar`s 
(i.e., it is an opaque block), does it have affine bindings?

##########
File path: src/tir/schedule/primitive.h
##########
@@ -160,6 +160,42 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const 
StmtSRef& block_sref, int r
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, 
int write_buffer_index,
                             const String& storage_scope);
 /******** Schedule: Compute location ********/
+/*!
+ * \brief Move a producer block under the specific loop, and regenerate the 
loops induced by the
+ * block so that the buffer region generated by the producer block could cover 
those regions read
+ * by the consumers. It requires:
+ * 1) The scope block has stage-pipeline property
+ * 2) The given block's subtree of the scope block satisfies compact dataflow 
condition.
+ * i.e. all the blocks in the scope's subtree must be either complete block or 
reduction block
+ * 3) `block` and `loop` are under the same scope, `loop` is not the ancestor 
of `block`
+ * 4) The block is not an output block,
+ * i.e. the buffer regions written by the block are allocated under the 
current scope

Review comment:
       Same comments as in schedule.h.

##########
File path: src/tir/schedule/primitive/cache_read_write.cc
##########
@@ -160,6 +160,21 @@ Block MakeCacheStage(const BufferRegion& cache_region, 
CacheStageInfo* info,
   return block;
 }
 
+/*!
+ * \brief Recalculate the `affine_binding` flag of the scope block info.

Review comment:
       ```suggestion
    * \brief Recalculate the `affine_binding` flag of the input block's 
BlockInfo.
   ```

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a region.
+ * \param region The region.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
+  NDIntSet result;
+  result.reserve(region.size());
+  for (const Range& range : region) {
+    result.push_back(arith::IntSet::FromRange(range));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a shape.
+ * \param shape The shape which is an array of the length of each dimension.
+ * \return constructed set.

Review comment:
       ```suggestion
    * \return The constructed set.
   ```

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -174,6 +212,18 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block
   return 0;
 }
 
+static const char* kCompleteBlockDefinition = R"(Definition of a complete block:

Review comment:
       What does "k" mean here 👀?

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a region.
+ * \param region The region.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
+  NDIntSet result;
+  result.reserve(region.size());
+  for (const Range& range : region) {
+    result.push_back(arith::IntSet::FromRange(range));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a shape.
+ * \param shape The shape which is an array of the length of each dimension.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  PrimExpr zero = Integer(0);
+  NDIntSet result;
+  result.reserve(shape.size());
+  for (const PrimExpr& extent : shape) {
+    result.push_back(arith::IntSet::FromMinExtent(zero, extent));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a point.
+ * \param indices The N-dimensional indices representing the point.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
+  NDIntSet result;
+  result.reserve(indices.size());
+  for (const PrimExpr& index : indices) {
+    result.push_back(arith::IntSet::SinglePoint(index));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create a union set of two sets, possibly relaxed. The RHS set will be combined into the
+ *        LHS set.
+ * \param lhs The first N-dimensional integer set
+ * \param rhs The second N-dimensional integer set
+ */
+inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
+  ICHECK_EQ(lhs->size(), rhs.size());
+  int ndim = rhs.size();
+  for (int i = 0; i < ndim; ++i) {
+    arith::IntSet& int_set = lhs->at(i);
+    int_set = arith::Union({int_set, rhs.at(i)});
+  }
+}
+
+/*!
+ * \brief Union a list of N-dimensional integer sets
+ * \param nd_int_sets The N-dimensional integer sets to be merged.
+ * \return The result of the union
+ */
+inline NDIntSet NDIntSetUnion(const std::vector<NDIntSet>& nd_int_sets) {
+  ICHECK(!nd_int_sets.empty());
+  int n = nd_int_sets.size();
+  if (n == 1) {
+    return nd_int_sets[0];
+  }
+  int ndim = nd_int_sets[0].size();
+  for (int i = 1; i < n; ++i) {
+    ICHECK_EQ(nd_int_sets[i].size(), ndim);
+  }
+  NDIntSet result;
+  result.reserve(ndim);
+  Array<arith::IntSet> int_sets(n, arith::IntSet{nullptr});
+  for (int dim = 0; dim < ndim; ++dim) {
+    for (int i = 0; i < n; ++i) {
+      int_sets.Set(i, nd_int_sets[i][dim]);
+    }
+    result.push_back(arith::Union(int_sets));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create an empty N-dimensional integer set.
+ * \param ndim The number of dimensions.
+ * \return constructed set.

Review comment:
       ```suggestion
    * \return The constructed set.
   ```

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a region.
+ * \param region The region.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
+  NDIntSet result;
+  result.reserve(region.size());
+  for (const Range& range : region) {
+    result.push_back(arith::IntSet::FromRange(range));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a shape.
+ * \param shape The shape which is an array of the length of each dimension.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  PrimExpr zero = Integer(0);

Review comment:
       I wonder whether extracting `Integer(0)` into a local variable reduces the number of times zero is constructed?

##########
File path: src/support/nd_int_set.h
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SUPPORT_ND_INT_SET_H_
+#define TVM_SUPPORT_ND_INT_SET_H_
+
+#include <tvm/arith/int_set.h>
+#include <tvm/ir/expr.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*! \brief An N-dimensional integer set representing a rectangle region */
+using NDIntSet = std::vector<tvm::arith::IntSet>;
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a region.
+ * \param region The region.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
+  NDIntSet result;
+  result.reserve(region.size());
+  for (const Range& range : region) {
+    result.push_back(arith::IntSet::FromRange(range));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a shape.
+ * \param shape The shape which is an array of the length of each dimension.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  PrimExpr zero = Integer(0);
+  NDIntSet result;
+  result.reserve(shape.size());
+  for (const PrimExpr& extent : shape) {
+    result.push_back(arith::IntSet::FromMinExtent(zero, extent));
+  }
+  return result;
+}
+
+/*!
+ * \brief Construct an N-dimensional integer set representing a point.
+ * \param indices The N-dimensional indices representing the point.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
+  NDIntSet result;
+  result.reserve(indices.size());
+  for (const PrimExpr& index : indices) {
+    result.push_back(arith::IntSet::SinglePoint(index));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create a union set of two sets, possibly relaxed. The RHS set will be combined into the
+ *        LHS set.
+ * \param lhs The first N-dimensional integer set
+ * \param rhs The second N-dimensional integer set
+ */
+inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
+  ICHECK_EQ(lhs->size(), rhs.size());
+  int ndim = rhs.size();
+  for (int i = 0; i < ndim; ++i) {
+    arith::IntSet& int_set = lhs->at(i);
+    int_set = arith::Union({int_set, rhs.at(i)});
+  }
+}
+
+/*!
+ * \brief Union a list of N-dimensional integer sets
+ * \param nd_int_sets The N-dimensional integer sets to be merged.
+ * \return The result of the union
+ */
+inline NDIntSet NDIntSetUnion(const std::vector<NDIntSet>& nd_int_sets) {
+  ICHECK(!nd_int_sets.empty());
+  int n = nd_int_sets.size();
+  if (n == 1) {
+    return nd_int_sets[0];
+  }
+  int ndim = nd_int_sets[0].size();
+  for (int i = 1; i < n; ++i) {
+    ICHECK_EQ(nd_int_sets[i].size(), ndim);
+  }
+  NDIntSet result;
+  result.reserve(ndim);
+  Array<arith::IntSet> int_sets(n, arith::IntSet{nullptr});
+  for (int dim = 0; dim < ndim; ++dim) {
+    for (int i = 0; i < n; ++i) {
+      int_sets.Set(i, nd_int_sets[i][dim]);
+    }
+    result.push_back(arith::Union(int_sets));
+  }
+  return result;
+}
+
+/*!
+ * \brief Create an empty N-dimensional integer set.
+ * \param ndim The number of dimensions.
+ * \return constructed set.
+ */
+inline NDIntSet NDIntSetEmpty(int ndim) {

Review comment:
       `EmptyNDIntSet` or `NDIntSetEmpty`: which name sounds better? 🤔




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to