junrushao1994 commented on a change in pull request #8693:
URL: https://github.com/apache/tvm/pull/8693#discussion_r688010398



##########
File path: src/tir/schedule/primitive/block_annotate.cc
##########
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Schedule error thrown when the `axis` given to storage_align lies outside
+ *        [-ndim, ndim), where `ndim` is the number of dimensions of the target buffer.
+ */
+class StorageAlignAxisOutOfRangeError : public ScheduleError {
+ public:
+  explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {}
+
+  /*!
+   * \brief Validate `axis` against the number of dimensions of `buffer` and normalize it.
+   * \param mod The IRModule recorded in the error if the check fails
+   * \param buffer The buffer whose shape determines the valid range
+   * \param axis The axis to check; negative values are accepted within [-ndim, 0)
+   * \return The equivalent non-negative axis
+   * \throw StorageAlignAxisOutOfRangeError if `axis` is not in [-ndim, ndim)
+   */
+  static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) {
+    const int rank = static_cast<int>(buffer->shape.size());
+    const bool in_range = -rank <= axis && axis < rank;
+    if (!in_range) {
+      throw StorageAlignAxisOutOfRangeError(mod, buffer, axis);
+    }
+    // Shift a negative axis by the rank to obtain its non-negative equivalent.
+    return axis < 0 ? axis + rank : axis;
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input `axis` is out of range. It is required to be in range "
+           "[-ndim, ndim) "
+           "where `ndim` is the number of dimensions of the buffer to set storage alignment.";
+  }
+
+  String DetailRenderTemplate() const final {
+    const int rank = static_cast<int>(buffer_->shape.size());
+    std::ostringstream message;
+    message << "The buffer to set storage alignment of, " << buffer_->name << ", has " << rank
+            << " dimension(s), so `axis` is required to be in [" << -rank << ", " << rank
+            << ") for storage_align. However, the input `axis` is " << axis_
+            << ", which is out of the expected range.";
+    return message.str();
+  }
+
+ private:
+  /*! \brief The module in which the error occurred */
+  IRModule mod_;
+  /*! \brief The buffer whose storage alignment was being set */
+  Buffer buffer_;
+  /*! \brief The offending axis, exactly as passed in by the caller */
+  int axis_;
+};
+
+/*!
+ * \brief Find the defining site of the buffer in the given block and its ancestors
+ * \param block_sref The block sref
+ * \param buffer The buffer
+ * \return The defining site of the buffer and whether the buffer is allocated (otherwise the
+ *         buffer is from match_buffer). The defining site is NullOpt if no block on the path
+ *         from `block_sref` up to the root defines the buffer.
+ */
+std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref,
+                                                          const Buffer& buffer) {
+  // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or
+  // match_buffers.
+  const StmtSRefNode* defining_site_sref = block_sref.get();
+  while (defining_site_sref != nullptr) {
+    const auto* block = defining_site_sref->StmtAs<BlockNode>();
+    // If this sref is not a block sref, skip it.
+    if (block == nullptr) {
+      defining_site_sref = defining_site_sref->parent;
+      continue;
+    }
+    // Try to find the buffer in `alloc_buffers`
+    for (const Buffer& alloc_buffer : block->alloc_buffers) {
+      if (buffer.same_as(alloc_buffer)) {
+        return {GetRef<StmtSRef>(defining_site_sref), true};
+      }
+    }
+    // We do not allow the buffer being defined in `match_buffer`.
+    // Compare against the buffer introduced by the match_buffer region: the region object
+    // itself is a different kind of node and can never be the same object as `buffer`.
+    for (const MatchBufferRegion& match_buffer : block->match_buffers) {
+      if (buffer.same_as(match_buffer->buffer)) {
+        return {GetRef<StmtSRef>(defining_site_sref), false};
+      }
+    }
+    defining_site_sref = defining_site_sref->parent;
+  }
+  // If we cannot find the defining site block, it means that the buffer must be in the
+  // function's buffer_map, which isn't an intermediate buffer.
+  return {NullOpt, false};
+}
+
+/*!
+ * \brief Schedule error thrown when a buffer is not found in the `alloc_buffers` of the given
+ *        block or any of its ancestors, i.e. it is either a function parameter or defined in
+ *        `match_buffer` of a block.
+ */
+class NonAllocatedBufferError : public ScheduleError {
+ public:
+  explicit NonAllocatedBufferError(IRModule mod, Buffer buffer)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input buffer is not allocated by a block. This means the buffer "
+           "is either a function parameter or defined in `match_buffer` of a block.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The input buffer " << buffer_->name
+       << " is not allocated by a block. This means the buffer is either a function parameter or "
+          "defined in `match_buffer` of a block.";
+    return os.str();
+  }
+
+  /*!
+   * \brief Check that `buffer` is allocated by `block_sref` or one of its ancestor blocks.
+   * \param mod The IRModule recorded in the error if the check fails
+   * \param block_sref The sref of the block from which the search starts
+   * \param buffer The buffer to look up
+   * \throw NonAllocatedBufferError if no defining block is found, or if the buffer is defined
+   *        through `match_buffer` rather than allocated
+   */
+  static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref,
+                                   const Buffer& buffer) {
+    Optional<StmtSRef> defining_site_sref;
+    bool is_alloc;
+    std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer);
+    if (!defining_site_sref || !is_alloc) {
+      throw NonAllocatedBufferError(mod, buffer);
+    }
+  }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+  IRModule mod() const final { return mod_; }
+
+ private:
+  /*! \brief The module in which the error occurred */
+  IRModule mod_;
+  /*! \brief The buffer that is not allocated by any enclosing block */
+  Buffer buffer_;
+};
+
+class StorageAlignInvalidFactorError : public ScheduleError {
+ public:
+  explicit StorageAlignInvalidFactorError(IRModule mod, int factor)
+      : mod_(std::move(mod)), factor_(factor) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input `factor` of storage_align is expected to 
be a positive "
+           "number.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The input `factor` of storage_align is expected to be a positive 
number. However, the "
+          "input `factor` is "
+       << factor_ << ", which is out of the expected range.";
+    return os.str();
+  }
+
+  static void Check(const IRModule& mod, int factor) {
+    if (factor <= 0) {
+      throw StorageAlignInvalidFactorError(mod, factor);
+    }
+  }
+
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+  IRModule mod() const final { return mod_; }
+
+ private:
+  IRModule mod_;
+  int factor_;
+};
+
+class StorageAlignInvalidAnnotationError : public ScheduleError {
+ public:
+  explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The block annotation for storage align is expected 
to be an array of "
+           "3-integer-tuples (axis, factor, offset).";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "The block annotation for storage align is expected to be an array 
of 3-integer-tuples "
+          "(axis, factor, offset). However, the block annotation with key "
+       << attr::buffer_dim_align << " of the block {0} is "
+       << block_->annotations.at(attr::buffer_dim_align) << ", which is 
unexpected.";
+    return os.str();
+  }
+
+  static Array<Array<Array<Integer>>> CheckAndGetAnnotation(const IRModule& 
mod,
+                                                            const Block& 
block) {
+    // Get existing annotation value.
+    auto it = block->annotations.find(attr::buffer_dim_align);
+    if (it != block->annotations.end()) {
+      if (!IsValidAnnotation(block, (*it).second)) {
+        throw StorageAlignInvalidAnnotationError(mod, block);
+      }
+      return Downcast<Array<Array<Array<Integer>>>>((*it).second);
+    }
+
+    // Create new annotation value
+    Array<Array<Array<Integer>>> storage_align_annotation;
+    storage_align_annotation.resize(block->writes.size());

Review comment:
       Yes, I was proposing to lazily store buffer_index as well: 
https://github.com/apache/tvm/pull/8693#discussion_r686489336




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to