MasterJH5574 commented on a change in pull request #8943:
URL: https://github.com/apache/tvm/pull/8943#discussion_r704094627



##########
File path: src/tir/schedule/primitive/cache_read_write.cc
##########
@@ -160,6 +160,21 @@ Block MakeCacheStage(const BufferRegion& cache_region, 
CacheStageInfo* info,
   return block;
 }
 
+/*!
+ * \brief Recalculate the `affine_binding` flag of a specific block
+ * \param block_sref The sref to the specific block
+ */
+bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& 
block_sref) {

Review comment:
       This recalculator will probably also be used by the DecomposeReduction primitive, so does it make sense to move it into `schedule/analysis.h`?

##########
File path: src/tir/schedule/analysis.h
##########
@@ -128,18 +138,36 @@ void CheckReductionBlock(const ScheduleState& self, const 
StmtSRef& block_sref,
                          const StmtSRef& scope_root_sref);
 
 /*!
- * \brief Check whether a subtree on SRef tree has compact data flow, and 
throw an exception if the
- * subtree does not have compact data flow
- * \details For a given StmtSRef, We say the subtree rooted from the StmtSRef 
has "compact data
- * flow" property if:
- * - the scope root of the input subtree root has stage-pipeline property, and
- * - all its child blocks on SRef tree are complete blocks or reduction blocks.
+ * \brief Check if the block is a complete block or a reduction block under 
the scope
  * \param self The schedule state
- * \param subtree_root_sref The root of the subtree to be checked in the SRef 
tree
- * \throw ScheduleError If the subtree does not have compact data flow
- * \sa IsCompleteBlock, IsReductionBlock
+ * \param block_sref The sref of the block to be checked
+ * \param scope_root_sref The scope root of the block
+ * \throw ScheduleError If the block is neither a complete block nor a 
reduction block
+ */
+void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& 
block_sref,
+                                   const StmtSRef& scope_root_sref);
+
+/*!
+ * \brief Check if the block is an output block, i.e. the block writes to at 
least a buffer that is
+ * not allocated under the current scope
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \param scope_root_sref The scope root of the block
+ * \return A boolean flag indicating if the block is an output block
+ */
+bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
+                   const StmtSRef& scope_root_sref);
+
+/*!
+ * \brief Check if the block is not an output block, i.e. all the buffers the 
block writes to
+ * are allocated under the current scope

Review comment:
       We should also update this brief doc to make it consistent with `IsOutputBlock`'s.

##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -439,6 +439,44 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& 
block_rv, int write_buff
 
 /******** Schedule: Compute location ********/
 
+void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& 
loop_rv,
+                                     bool preserve_unit_loops) {
+  static StmtSRef inline_mark = StmtSRef::InlineMark();
+  static StmtSRef root_mark = StmtSRef::RootMark();
+  StmtSRef loop_sref = this->GetSRef(loop_rv);

Review comment:
       Sorry, I couldn't find where `inline_mark` and `root_mark` are inserted into `symbol_table_`. Could you point it out?

##########
File path: src/tir/schedule/primitive/compute_at.cc
##########
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+using support::NDIntSet;
+
+/******** Error Classes ********/
+
+/*!
+ * \brief An error raised when not all required blocks are under the given 
loop.
+ * \tparam is_consumer Indicates if all the required blocks are consumers or 
producers
+ */
+template <bool is_consumer>
+class NotAllRequiredBlocksAreVisitedError : public ScheduleError {
+ public:
+  explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int 
num_not_visited,
+                                               Array<StmtSRef> required)

Review comment:
       Would it be better to make the parameter a const reference 🤔?
   ```suggestion
                                                  const Array<StmtSRef>& 
required)
   ```

##########
File path: src/tir/schedule/primitive/compute_at.cc
##########
@@ -0,0 +1,584 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+using support::NDIntSet;
+
+/******** Error Classes ********/
+
+/*!
+ * \brief An error raised when not all required blocks are under the given 
loop.
+ * \tparam is_consumer Indicates if all the required blocks are consumers or 
producers
+ */
+template <bool is_consumer>
+class NotAllRequiredBlocksAreVisitedError : public ScheduleError {
+ public:
+  explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int 
num_not_visited,
+                                               Array<StmtSRef> required)
+      : mod_(mod), num_not_visited_(num_not_visited) {
+    required_.reserve(required.size());
+    for (const StmtSRef& block_sref : required) {
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+      required_.push_back(GetRef<Block>(block));
+    }
+  }
+
+  String FastErrorString() const final {
+    return "ScheduleError: Not all required blocks are under the loop scope";
+  }
+
+  String DetailRenderTemplate() const final {
+    String relation = is_consumer ? "consumer(s)" : "producer(s)";
+    std::ostringstream os;
+    os << "The primitive requires all the " << relation
+       << " of the given block to be present under the target loop. However, 
there are "
+       << num_not_visited_ << " " << relation
+       << " not satisfying the constraint. List of the producers:";

Review comment:
       ```suggestion
          << " not satisfying the constraint. List of the " << relation << ":";
   ```

##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -305,6 +305,38 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
   /******** Schedule: Compute location ********/
+  /*!
+   * \brief Move a producer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the producer block could 
cover those regions read
+   * by the consumers. It requires:
+   * 1) The scope block has stage-pipeline property
+   * 2) The given block's subtree of the scope block satisfies compact 
dataflow condition.
+   * i.e. all the blocks in the scope's subtree must be either complete block 
or reduction block
+   * 3) `block` and `loop` are under the same scope, `loop` is not the 
ancestor of `block`
+   * 4) The block is not an output block,
+   * i.e. the buffer regions written by the block are allocated under the 
current scope
+   * 5) All the consumers of the block are under the given loop
+   * \param block_rv The block to be moved
+   * \param loop_rv The loop where the block to be moved under
+   * \param preserve_unit_loops Whether to keep the trivial loops whose 
extents are 1
+   */
+  virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
+                         bool preserve_unit_loops) = 0;
+  /*!
+   * \brief Move a consumer block under the specific loop, and regenerate the 
loops induced by the
+   * block so that the buffer region generated by the consumer block could 
cover those regions read
+   * by the consumers. It requires:

Review comment:
       :eyes:

##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -439,6 +439,44 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& 
block_rv, int write_buff
 
 /******** Schedule: Compute location ********/
 
+void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& 
loop_rv,
+                                     bool preserve_unit_loops) {
+  static StmtSRef inline_mark = StmtSRef::InlineMark();
+  static StmtSRef root_mark = StmtSRef::RootMark();
+  StmtSRef loop_sref = this->GetSRef(loop_rv);
+  if (loop_sref.same_as(root_mark)) {
+    return;
+  } else if (loop_sref.same_as(inline_mark)) {
+    TVM_TIR_SCHEDULE_BEGIN();
+    tir::ComputeInline(state_, this->GetSRef(block_rv));
+    TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
+  } else {
+    TVM_TIR_SCHEDULE_BEGIN();
+    tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, 
preserve_unit_loops);
+    TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
+  }
+  this->state_->DebugVerify();
+}
+
+void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const 
LoopRV& loop_rv,
+                                            bool preserve_unit_loops) {
+  static StmtSRef inline_mark = StmtSRef::InlineMark();
+  static StmtSRef root_mark = StmtSRef::RootMark();
+  StmtSRef loop_sref = this->GetSRef(loop_rv);
+  if (loop_sref.same_as(root_mark)) {
+    // do nothing

Review comment:
       Why do we do nothing here, while `ComputeAt` returns directly at line 448?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to