liangW-intellif commented on code in PR #13033: URL: https://github.com/apache/tvm/pull/13033#discussion_r997974823
########## src/tir/schedule/primitive/rolling_buffer.cc: ########## @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <functional> + +#include "../ir_comparator.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +namespace { + +struct RollingBufferInfo { + Buffer old_buffer; + Buffer new_buffer; + int rolling_axis; + int rolling_extent; + std::vector<int> axis_overlaps; + std::vector<Optional<Var>> axis_iter_vars; + /*! \brief The map used for ScheduleStateNode::Replace. */ + Map<Block, Block> block_reuse; +}; + +BufferRegion GetRelaxedBufferRegion(const BlockRealize& realize, const BufferRegion& buffer_region, + const Map<Var, arith::IntSet>& dom_map) { + Array<arith::IntSet> relaxed_intsets = + arith::EvalSet(Substitute(buffer_region->region, GetBindings(realize)), dom_map); + Region relaxed_region; + relaxed_region.reserve(relaxed_intsets.size()); + for (size_t i = 0; i < relaxed_intsets.size(); ++i) { + relaxed_region.push_back( + relaxed_intsets[i].CoverRange(Range::FromMinExtent(0, buffer_region->buffer->shape[i]))); + } + return BufferRegion(buffer_region->buffer, relaxed_region); +} + +class RollingBufferMatchError : public ScheduleError { + public: + RollingBufferMatchError(IRModule mod, Block block, BufferRegion buffer_region) + : mod_(mod), block_(block), buffer_region_(buffer_region) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer expect the buffer region to have at least one dimention" + "matching the rolling pattern such as: hh.outer * stride + hh.inner"; + } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The target buffer " << buffer_region_->buffer->name << " with region " + << buffer_region_->region + << " should have at least one dimension range that matches a rolling pattern " + "such as hh.outer * stride + hh.inner. "; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + BufferRegion buffer_region_; +}; + +class RollingBufferInsertionError : public ScheduleError { + public: + RollingBufferInsertionError(IRModule mod, Buffer buffer, Block block) + : mod_(mod), buffer_(std::move(buffer)), block_(block) {} + String FastErrorString() const final { + return "ScheduleError: rolling_buffer injection is invalid, the lca of the access " + "location of the target buffer is not a for loop. "; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "rolling_buffer injection is invalid. The block {0} should be tiled so that " + << "the lca of the access location of the target buffer " << buffer_->name + << " is a for loop. "; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Buffer buffer_; + Block block_; +}; + +class RollingBufferInfoCollector { + public: + static RollingBufferInfo CheckAndGetRollingBufferInfo(const IRModule& mod, + const StmtSRef& block_sref, + const BufferRegion& buffer_region) { + RollingBufferInfoCollector collector; + if (!collector.MatchRollingBuffer(block_sref, buffer_region)) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); + throw RollingBufferMatchError(mod, GetRef<Block>(block), buffer_region); + } + return collector.info_; + } + + private: + bool MatchRollingBuffer(const StmtSRef& block_sref, const BufferRegion& buffer_region) { + const Buffer& buffer = buffer_region->buffer; + const Region& region = buffer_region->region; + + std::vector<Optional<Var>> bound_iter_vars; + std::vector<int> bound_overlaps; + auto stride = 0; + auto divisor = 1; + Optional<Var> iter_var; + for (auto bound : region) { + divisor = 1; + if (auto floor_div = bound->min.as<FloorDivNode>()) { Review Comment: Good idea! I updated this part with tir pattern matching, please take a look again. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
