liangW-intellif commented on code in PR #13033:
URL: https://github.com/apache/tvm/pull/13033#discussion_r997974823


##########
src/tir/schedule/primitive/rolling_buffer.cc:
##########
@@ -0,0 +1,443 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <functional>
+
+#include "../ir_comparator.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+namespace {
+
+/*!
+ * \brief Plain aggregate describing how a buffer is turned into a rolling buffer.
+ *
+ * An instance is returned by RollingBufferInfoCollector::CheckAndGetRollingBufferInfo.
+ * The integer fields carry default member initializers so that a default-constructed
+ * instance never holds indeterminate values.
+ */
+struct RollingBufferInfo {
+  /*! \brief Presumably the buffer before the transformation (TODO confirm against the writer). */
+  Buffer old_buffer;
+  /*! \brief Presumably the buffer that replaces old_buffer (TODO confirm against the writer). */
+  Buffer new_buffer;
+  /*! \brief Index of the rolling axis; -1 means "not determined yet". */
+  int rolling_axis = -1;
+  /*! \brief Extent associated with the rolling axis; 0 until it is computed. */
+  int rolling_extent = 0;
+  /*! \brief One integer overlap entry per axis (exact meaning is defined by the collector). */
+  std::vector<int> axis_overlaps;
+  /*! \brief One optional variable per axis; an entry may be empty. */
+  std::vector<Optional<Var>> axis_iter_vars;
+  /*! \brief The map used for ScheduleStateNode::Replace. */
+  Map<Block, Block> block_reuse;
+};
+
+/*!
+ * \brief Relax a buffer region over the given variable domains.
+ *
+ * The result of GetBindings(realize) is substituted into the region, the substituted
+ * region is evaluated to integer sets over dom_map, and each set is converted back to a
+ * range via CoverRange, using [0, shape[dim]) of the buffer as the covering range.
+ *
+ * \param realize The block realize whose bindings are substituted into the region.
+ * \param buffer_region The buffer region to relax.
+ * \param dom_map The domains of the variables to relax over.
+ * \return A region on the same buffer, with one relaxed range per dimension.
+ */
+BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region,
+                                    const Map<Var, arith::IntSet>& dom_map) {
+  const Buffer& buffer = buffer_region->buffer;
+  const Array<arith::IntSet> int_sets =
+      arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map);
+  const size_t ndim = int_sets.size();
+
+  Region result;
+  result.reserve(ndim);
+  for (size_t dim = 0; dim < ndim; ++dim) {
+    // The covering range spans the whole dimension of the underlying buffer.
+    const Range full_extent = Range::FromMinExtent(0, buffer->shape[dim]);
+    result.push_back(int_sets[dim].CoverRange(full_extent));
+  }
+  return BufferRegion(buffer, result);
+}
+
+/*!
+ * \brief Schedule error thrown when the target buffer region does not match the rolling
+ *        pattern (see RollingBufferInfoCollector::CheckAndGetRollingBufferInfo).
+ */
+class RollingBufferMatchError : public ScheduleError {
+ public:
+  /*!
+   * \param mod The module returned by mod().
+   * \param block The block reported as the location of interest.
+   * \param buffer_region The buffer region whose buffer name and region appear in the
+   *        detailed message.
+   */
+  RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region)
+      : mod_(std::move(mod)),
+        block_(std::move(block)),
+        buffer_region_(std::move(buffer_region)) {}
+
+  /*! \brief One-line description of the failure. */
+  String FastErrorString() const final {
+    // Adjacent literals are concatenated verbatim, so the separating space must be
+    // spelled out explicitly at the join.
+    return "ScheduleError: rolling_buffer expect the buffer region to have at least one "
+           "dimension matching the rolling pattern such as: hh.outer * stride + hh.inner";
+  }
+
+  /*! \brief Detailed description naming the offending buffer and its region. */
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The target buffer " << buffer_region_->buffer->name << " with region "
+       << buffer_region_->region
+       << " should have at least one dimension range that matches a rolling pattern "
+          "such as hh.outer * stride + hh.inner. ";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+ private:
+  /*! \brief The module in which the error occurred. */
+  IRModule mod_;
+  /*! \brief The block reported as the location of interest. */
+  Block block_;
+  /*! \brief The buffer region that failed to match. */
+  BufferRegion buffer_region_;
+};
+
+/*!
+ * \brief Schedule error reporting that the lca of the access location of the target
+ *        buffer is not a for loop, so rolling_buffer cannot be injected.
+ */
+class RollingBufferInsertionError : public ScheduleError {
+ public:
+  /*!
+   * \param mod The module returned by mod().
+   * \param buffer The target buffer; its name appears in the detailed message.
+   * \param block The block reported as the location of interest ({0} in the template).
+   */
+  RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block)
+      : mod_(mod), buffer_(std::move(buffer)), block_(block) {}
+
+  IRModule mod() const final { return mod_; }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return Array<ObjectRef>{block_}; }
+
+  /*! \brief One-line description of the failure. */
+  String FastErrorString() const final {
+    return "ScheduleError: rolling_buffer injection is invalid, the lca of the access "
+           "location of the target buffer is not a for loop. ";
+  }
+
+  /*! \brief Detailed description; {0} refers to the block in LocationsOfInterest(). */
+  String DetailRenderTemplate() const final {
+    std::ostringstream message;
+    message << "rolling_buffer injection is invalid. The block {0} should be tiled so that ";
+    message << "the lca of the access location of the target buffer ";
+    message << buffer_->name;
+    message << " is a for loop. ";
+    return message.str();
+  }
+
+ private:
+  /*! \brief The module in which the error occurred. */
+  IRModule mod_;
+  /*! \brief The target buffer named in the detailed message. */
+  Buffer buffer_;
+  /*! \brief The block reported as the location of interest. */
+  Block block_;
+};
+
+class RollingBufferInfoCollector {
+ public:
+  static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod,
+                                                        const StmtSRef& 
block_sref,
+                                                        const BufferRegion& 
buffer_region) {
+    RollingBufferInfoCollector collector;
+    if (!collector.MatchRollingBuffer(block_sref, buffer_region)) {
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+      throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region);
+    }
+    return collector.info_;
+  }
+
+ private:
+  bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& 
buffer_region) {
+    const Buffer& buffer = buffer_region->buffer;
+    const Region& region = buffer_region->region;
+
+    std::vector<Optional<Var>> bound_iter_vars;
+    std::vector<int> bound_overlaps;
+    auto stride = 0;
+    auto divisor = 1;
+    Optional<Var> iter_var;
+    for (auto bound : region) {
+      divisor = 1;
+      if (auto floor_div = bound->min.as<FloorDivNode>()) {

Review Comment:
   Good idea! I have updated this part to use TIR pattern matching; please take
   another look.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to