junrushao1994 commented on a change in pull request #8693: URL: https://github.com/apache/tvm/pull/8693#discussion_r686373290
########## File path: src/tir/schedule/primitive/block_annotate.cc ########## @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +class StorageAlignAxisOutOfRangeError : public ScheduleError { + public: + explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) + : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} + + String FastErrorString() const final { + return "ScheduleError: The input `axis` is out of range. It is required to be in range " + "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " + "storage alignment."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + int ndim = static_cast<int>(buffer_->shape.size()); + os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim + << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim + << ") for storage_align. 
However, the input `axis` is " << axis_ + << ", which is out of the expected range."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {}; } + + static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { + int ndim = static_cast<int>(buffer->shape.size()); + if (axis < -ndim || axis >= ndim) { + throw StorageAlignAxisOutOfRangeError(mod, buffer, axis); + } + // If axis is negative, convert it to a non-negative one. + if (axis < 0) { + axis += ndim; + } + return axis; + } + + private: + IRModule mod_; + Buffer buffer_; + int axis_; +}; + +/*! + * \brief Find the defining site of the buffer in the given block and its ancestors + * \param block_sref The block sref + * \param buffer The buffer + * \return The defining site of the buffer and whether the buffer is allocated (otherwise the + * buffer is from match_buffer). + */ +std::pair<StmtSRef, bool> GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer) { Review comment: Use `Optional` to explicitly indicate the case where the buffer is defined in `PrimFuncNode::buffer_map`, in which case no block among `block_sref` and its ancestors defines it, so there is no defining site to return: ```suggestion std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref, const Buffer& buffer) { ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
