wrongtest-intellif commented on code in PR #13033:
URL: https://github.com/apache/tvm/pull/13033#discussion_r1000130417


##########
src/tir/schedule/primitive/rolling_buffer.cc:
##########
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
#include <functional>
#include <utility>

#include "../ir_comparator.h"
#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace {
+
/*!
 * \brief Necessary information collected for the rolling_buffer transformation.
 */
struct RollingBufferInfo {
  /*! \brief The buffer to be rolled over. */
  Buffer old_buffer;
  /*! \brief The new buffer, whose rolling axis is shrunk to the rolling extent. */
  Buffer new_buffer;
  /*! \brief The index of the buffer dimension chosen as the rolling axis. */
  int rolling_axis;
  /*! \brief The extent of the accessed region along the rolling axis. */
  int rolling_extent;
  /*! \brief Per-dimension overlap (extent - stride) between adjacent iterations. */
  std::vector<int> axis_overlaps;
  /*! \brief Per-dimension loop var driving the access, or NullOpt if the
   *   dimension is not rollable. */
  std::vector<Optional<Var>> axis_iter_vars;
  /*! \brief The map used for ScheduleStateNode::Replace. */
  Map<Block, Block> block_reuse;
};
+
+BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const 
BufferRegion& buffer_region,
+                                    const Map<Var, arith::IntSet>& dom_map) {
+  Array<arith::IntSet> relaxed_intsets =
+      arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), 
dom_map);
+  Region relaxed_region;
+  relaxed_region.reserve(relaxed_intsets.size());
+  for (size_t i = 0; i < relaxed_intsets.size(); ++i) {
+    relaxed_region.push_back(
+        relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, 
buffer_region->buffer->shape[i])));
+  }
+  return BufferRegion(buffer_region->buffer, relaxed_region);
+}
+
+class RollingBufferMatchError : public ScheduleError {
+ public:
+  RollingBufferMatchError(IRModule mod, Block block, BufferRegion 
buffer_region)
+      : mod_(mod), block_(block), buffer_region_(buffer_region) {}
+  String FastErrorString() const final {
+    return "ScheduleError: rolling_buffer expect the buffer region to have at 
least one dimention"
+           "matching the rolling pattern such as: hh.outer * stride + 
hh.inner";
+  }
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The target buffer " << buffer_region_->buffer->name << " with 
region "
+       << buffer_region_->region
+       << " should have at least one dimension range that matches a rolling 
pattern "
+          "such as hh.outer * stride + hh.inner. ";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Block block_;
+  BufferRegion buffer_region_;
+};
+
+class RollingBufferInsertionError : public ScheduleError {
+ public:
+  RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block)
+      : mod_(mod), buffer_(std::move(buffer)), block_(block) {}
+  String FastErrorString() const final {
+    return "ScheduleError: rolling_buffer injection is invalid, the lca of the 
access "
+           "location of the target buffer is not a for loop. ";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "rolling_buffer injection is invalid. The block {0} should be tiled 
so that "
+       << "the lca of the access location of the target buffer " << 
buffer_->name
+       << " is a for loop. ";
+    return os.str();
+  }
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  IRModule mod_;
+  Buffer buffer_;
+  Block block_;
+};
+
/*!
 * \brief Collector that checks whether a buffer region follows a rolling access
 *  pattern and, if so, gathers the information needed to rewrite the buffer.
 */
class RollingBufferInfoCollector {
 public:
  /*!
   * \brief Match the buffer region against the rolling pattern, or throw on failure.
   * \param mod The module, used for error reporting.
   * \param block_sref The sref of the block accessing the target buffer.
   * \param buffer_region The target buffer region.
   * \return The collected rolling-buffer information.
   * \throws RollingBufferMatchError If no dimension matches a rolling pattern.
   */
  static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod,
                                                        const StmtSRef& block_sref,
                                                        const BufferRegion& buffer_region) {
    RollingBufferInfoCollector collector;
    if (!collector.MatchRollingBuffer(block_sref, buffer_region)) {
      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
      throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region);
    }
    return collector.info_;
  }

 private:
  // Try to match each dimension's min against `iter_var * stride` (or a plain
  // var, or a floordiv thereof for fractional strides). On success fills info_
  // and returns true; returns false when no rollable dimension is found.
  bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) {
    const Buffer& buffer = buffer_region->buffer;
    const Region& region = buffer_region->region;

    // Per-dimension results: the driving loop var (NullOpt when not rollable)
    // and the overlap between adjacent iterations along that dimension.
    std::vector<Optional<Var>> bound_iter_vars;
    std::vector<int> bound_overlaps;
    auto stride = 0;
    auto divisor = 1;
    Optional<Var> iter_var;
    for (auto bound : region) {
      divisor = 1;  // reset per dimension; `stride` is only read when iter_var is set below
      if (auto floor_div = bound->min.as<FloorDivNode>()) {
        // Handle the case of fractional strides
        // They take this form: floordiv(hh.outer, 2)
        // Strip the floordiv and keep track of the divisor
        divisor = Downcast<IntImm>(floor_div->b)->value;
        bound = Range::FromMinExtent(floor_div->a, bound->extent, bound->span);
      }
      if (bound->min.as<IntImmNode>()) {
        // If the bound is an int, we can't roll over it
        iter_var = NullOpt;
      } else if (auto var = bound->min.as<VarNode>()) {
        // If the bound is just a Var, that implies the stride is 1
        iter_var = GetRef<Var>(var);
        stride = 1;
      } else {
        // Otherwise, it's the iter var multiplied by the stride
        // If not we're in unknown behaviour
        if (auto mul = bound->min.as<MulNode>()) {
          if (mul->a->IsInstance<VarNode>() && mul->b->IsInstance<IntImmNode>()) {
            iter_var = Downcast<Var>(mul->a);
            stride = Downcast<IntImm>(mul->b)->value;
          } else {
            return false;
          }
        } else {
          return false;
        }
      }
      // Fold the fractional divisor into the stride, rounding up.
      stride = std::ceil(static_cast<float>(stride) / divisor);
      auto bound_overlap = 0;
      if (iter_var.defined()) {
        // The extent must be a constant to compute the overlap.
        auto extent = bound->extent.as<IntImmNode>();
        ICHECK(extent);
        bound_overlap = extent->value - stride;
        // Since Pass CompactBufferAllocation will be responsible for compacting the buffer
        // allocation region, there is no need to roll over the axis where the overlap is not
        // positive, so reset iter_var to NullOpt.
        if (bound_overlap <= 0) {
          iter_var = NullOpt;
        }
      }
      bound_iter_vars.push_back(iter_var);
      bound_overlaps.push_back(bound_overlap);
    }
    Array<StmtSRef> loop_srefs = GetLoops(block_sref);
    // Pick the outermost iter_var that's mentioned in the bounds
    // to be the rolling axis
    Optional<Var> roll_iter_var;
    int roll_axis;
    for (const tir::StmtSRef& loop_sref : loop_srefs) {
      auto loop_var = loop_sref->StmtAs<ForNode>()->loop_var;

      // Find the buffer dimension (if any) driven by this loop var.
      auto it{std::find_if(bound_iter_vars.begin(), bound_iter_vars.end(), [&](Optional<Var> var) {
        return var && (var.get() == loop_var.get());
      })};
      if (it != bound_iter_vars.end()) {
        auto i = std::distance(bound_iter_vars.begin(), it);
        roll_iter_var = loop_var;
        roll_axis = i;
        break;
      }
    }

    // No rollable dimension is bound to any surrounding loop: no match.
    if (!roll_iter_var.defined()) {
      return false;
    }
    // Shrink the rolling axis of the new buffer to the accessed region's extent
    // (the rolling window size).
    Array<PrimExpr> new_shape = buffer->shape;
    new_shape.Set(roll_axis, region[roll_axis]->extent);
    Buffer new_buffer = buffer;
    new_buffer.CopyOnWrite()->shape = new_shape;

    info_.old_buffer = buffer;
    info_.new_buffer = new_buffer;
    info_.rolling_axis = roll_axis;
    info_.rolling_extent = static_cast<int>(Downcast<IntImm>(region[roll_axis]->extent)->value);
    info_.axis_overlaps = bound_overlaps;
    info_.axis_iter_vars = bound_iter_vars;

    return true;
  }

  /*! \brief The collected rolling buffer information. */
  RollingBufferInfo info_;
};
+
/*!
 * \brief Mutator that rewrites the scope to use the rolling buffer: replaces
 *  the allocation with the shrunk buffer, wraps accesses of the rolling axis
 *  in modulo arithmetic, and appends a block predicate to skip recomputation.
 */
class RollingBufferRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Rewrite the subtree rooted at the scope sref.
   * \param scope_sref The sref of the scope block that allocates the target buffer.
   * \param info The rolling-buffer info; its block_reuse map is populated as a side effect.
   * \return The rewritten statement.
   */
  static Stmt Rewrite(const StmtSRef& scope_sref, RollingBufferInfo* info) {
    RollingBufferRewriter rewriter(scope_sref, info);
    return rewriter(GetRef<Stmt>(scope_sref->stmt));
  }

 private:
  explicit RollingBufferRewriter(const StmtSRef& scope_sref, RollingBufferInfo* info)
      : scope_sref_(scope_sref), info_(info) {}

  // Replace any region touching the old buffer with the freshly inferred
  // region; other regions are kept unchanged.
  void RewriteAccessRegion(Array<BufferRegion>* old_access_regions,
                           const Array<BufferRegion>& infered_access_regions) {
    auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) {
      if (buffer_region->buffer.same_as(info_->old_buffer)) {
        ICHECK(infered_access_regions.size() == 1);
        return infered_access_regions[0];
      }
      return buffer_region;
    };
    (*old_access_regions).MutateByApply(fmutate);
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Block old_stmt = GetRef<Block>(block);
    Block stmt = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
    if (block == scope_sref_->stmt) {
      // Scope block: swap the old allocation for the shrunk rolling buffer.
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());

      Array<Buffer> new_alloc_buffers;
      for (const Buffer& buffer : stmt->alloc_buffers) {
        if (buffer != info_->old_buffer) {
          new_alloc_buffers.push_back(buffer);
        } else {
          new_alloc_buffers.push_back(info_->new_buffer);
        }
      }
      n->alloc_buffers = std::move(new_alloc_buffers);
      stmt = Block(n);
    } else {
      // Inner block: opaque-ify the rolling axis and refresh access regions.
      Array<IterVar> new_iter_bindings;
      for (size_t i = 0; i < stmt->iter_vars.size(); ++i) {
        auto old_iter_var = stmt->iter_vars[i];
        if (static_cast<int>(i) == info_->rolling_axis) {
          // All inner loops of the rolling axis has a loop carried dependency
          // (i.e. each iteration calculation of the rolling axis depends on
          // the calculation results of all the historical iterations of inner loops),
          // so annotate the iteration type of the rolling axis as 'opaque',
          // avoid the iterative range of its inner loop from being compressed
          // during lowering phase.
          IterVar new_iter_var =
              IterVar(old_iter_var->dom, old_iter_var->var, IterVarType::kOpaque);
          new_iter_bindings.push_back(new_iter_var);
        } else {
          new_iter_bindings.push_back(old_iter_var);
        }
      }
      // Re-infer read/write regions against the new (shrunk) buffer.
      Map<Var, Buffer> buffer_data_to_buffer = {{info_->new_buffer->data, info_->new_buffer}};
      auto infered_access_regions = GetBlockReadWriteRegion(stmt, buffer_data_to_buffer);

      BlockNode* n = stmt.CopyOnWrite();
      n->iter_vars = std::move(new_iter_bindings);
      RewriteAccessRegion(&n->reads, infered_access_regions[0]);
      RewriteAccessRegion(&n->writes, infered_access_regions[1]);
    }
    // Record the mapping for ScheduleStateNode::Replace.
    info_->block_reuse.Set(old_stmt, stmt);
    return std::move(stmt);
  }

  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
    BlockRealize stmt = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize));
    // Append block predicate to avoid recomputing elements.
    // rewrite_block_predicate_ is set by the BufferStore visitor below, so this
    // fires only for the realize enclosing a store into the rolling buffer.
    if (rewrite_block_predicate_) {
      rewrite_block_predicate_ = false;
      PrimExpr condition = stmt->predicate;
      for (size_t i = 0; i < info_->axis_iter_vars.size(); ++i) {
        auto iter_var = info_->axis_iter_vars[i];
        if (iter_var && info_->axis_overlaps[i] > 0) {
          Var var = iter_var.value();
          // Evaluate the bound value with the rolling var pinned to 0 to get
          // the non-rolling remainder of the index expression.
          const Map<Var, arith::IntSet> dmap = {std::make_pair(var, arith::IntSet::Interval(0, 0))};
          auto iter_value = realize->iter_values[i];
          arith::Analyzer analyzer;
          auto term_2 = analyzer.int_set(iter_value, dmap).min();
          // Skip elements already computed by the previous iteration: either we
          // are in the first iteration (var < 1) or past the overlap region.
          condition = analyzer.Simplify(
              And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i]))));
        }
      }
      BlockRealizeNode* n = stmt.CopyOnWrite();
      n->predicate = condition;
    }
    return std::move(stmt);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore stmt = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    if (stmt->buffer.same_as(info_->old_buffer)) {
      Array<PrimExpr> new_indices;
      new_indices.reserve(stmt->indices.size());
      // First modify the access indices to use modulo arithmetic
      // for the rolling axis
      for (size_t i = 0; i < stmt->indices.size(); ++i) {
        auto index = stmt->indices[i];
        if (static_cast<int>(i) == info_->rolling_axis) {
          new_indices.push_back(FloorMod(index, info_->rolling_extent));
        } else {
          new_indices.push_back(index);
        }
      }
      BufferStoreNode* n = stmt.CopyOnWrite();
      // Replace the stored buffer with the new buffer.
      n->buffer = info_->new_buffer;
      n->indices = std::move(new_indices);
      // Need to add predicate to the current block to avoid recomputing elements.
      rewrite_block_predicate_ = true;
    }
    return std::move(stmt);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    BufferLoad stmt = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (stmt->buffer.same_as(info_->old_buffer)) {
      Array<PrimExpr> new_indices;
      new_indices.reserve(stmt->indices.size());
      // Wrap the rolling-axis index with modulo, same as in the store visitor.
      for (size_t i{0}; i < stmt->indices.size(); ++i) {
        auto index = stmt->indices[i];
        if (static_cast<int>(i) == info_->rolling_axis) {
          new_indices.push_back(FloorMod(index, info_->rolling_extent));
        } else {
          new_indices.push_back(index);
        }
      }
      BufferLoadNode* n = stmt.CopyOnWrite();
      // Replace the loaded buffer with the new buffer.
      n->buffer = info_->new_buffer;
      n->indices = std::move(new_indices);
    }
    return std::move(stmt);
  }

 private:
  /*! \brief The sref of the scope block that allocates the target buffer. */
  const StmtSRef& scope_sref_;
  /*! \brief The rolling-buffer info; block_reuse is filled during rewriting. */
  RollingBufferInfo* info_;
  /*! \brief Flag set when a store to the old buffer was rewritten, telling the
   *   enclosing BlockRealize visitor to append the predicate. */
  bool rewrite_block_predicate_ = false;
};
+
+}  // namespace
+
+void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int 
write_buffer_index) {
+  /*!
+   *  Check
+   *    - The block is not an output block.
+   *    - The block is tiled and there is access overlap between adjacent 
tiles.
+   *  Mutate
+   *    - Select the outermost rollable axis appeared in the block's loop nest
+   *      as the 'rolling axis', trim the target buffer from the rolling axis.
+   *    - Use modulo arithmetic to modify the target buffer's read and load
+   *      indices to circularize the buffer along the rolling dimension.
+   *    - Append block predicate to avoid recomputing overlapping elements.
+   */
+  Map<Var, arith::IntSet> dom_map;
+  const BlockRealize& realize = GetBlockRealize(self, block_sref);
+  const Block& block = realize->block;
+
+  // Step 1. Checking index, getting the target buffer region and the parent 
scope.
+  const BufferRegion& buffer_region =
+      GetNthAccessBufferRegion(self, block, write_buffer_index, 
BufferIndexType::kWrite);
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, 
/*require_stage_pipeline=*/false);

Review Comment:
   Ok, that makes sense to me. I agree with using `require_stage_pipeline=False`. Then 
could we try our best to check that there are no complex read-write dependencies? 
E.g., is the consumer-producer relation unique? 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to