MasterJH5574 commented on a change in pull request #8943: URL: https://github.com/apache/tvm/pull/8943#discussion_r704577677
########## File path: src/tir/schedule/primitive/compute_at.cc ########## @@ -0,0 +1,584 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +/******** Error Classes ********/ + +/*! + * \brief An error raised when not all required blocks are under the given loop. + * \tparam is_consumer Indicates if all the required blocks are consumers or producers + */ +template <bool is_consumer> +class NotAllRequiredBlocksAreVisitedError : public ScheduleError { + public: + explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited, + const Array<StmtSRef>& required) + : mod_(mod), num_not_visited_(num_not_visited) { + required_.reserve(required.size()); + for (const StmtSRef& block_sref : required) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + required_.push_back(GetRef<Block>(block)); + } + } + + String FastErrorString() const final { + return "ScheduleError: Not all required blocks are under the loop scope"; + } + + String DetailRenderTemplate() const final { + String relation = is_consumer ? 
"consumer(s)" : "producer(s)"; + std::ostringstream os; + os << "The primitive requires all the " << relation + << " of the given block to be present under the target loop. However, there are " + << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the " + << relation << ":"; + for (int i = 0, n = required_.size(); i < n; ++i) { + os << "{" << i << "}"; + } + return os.str(); + } + + IRModule mod() const final { return mod_; } + + Array<ObjectRef> LocationsOfInterest() const final { + return {required_.begin(), required_.end()}; + } + + private: + IRModule mod_; + int num_not_visited_; + Array<Block> required_; +}; + +/*! + * \brief An error raised when the given block is not in the same block scope as the given loop, + * or the given loop is the ancestor of the given block. + */ +class NotInSameScopeError : public ScheduleError { + public: + static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, const StmtSRef& scope_root_sref, + arith::Analyzer* analyzer) { + for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) { + if (const ForNode* loop = p->StmtAs<ForNode>()) { + analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else if (p != scope_root_sref.get()) { + throw NotInSameScopeError(self->mod, block_sref, loop_sref); + } else { + break; + } + } + for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) { + if (p == loop_sref.get()) { + throw NotInSameScopeError(self->mod, block_sref, loop_sref); + } + } + } + + String FastErrorString() const final { + return "ScheduleError: Expected the block and loop to be under the same block scope, and loop " + "not to be the ancestor of block"; + } + String DetailRenderTemplate() const final { + return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, " + "and loop not to be the ancestor of block"; + } + IRModule mod() const 
final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; } + + private: + explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref) + : mod_(mod), + block_(GetRef<Block>(block_sref->StmtAs<BlockNode>())), + loop_(GetRef<For>(loop_sref->StmtAs<ForNode>())) {} + + IRModule mod_; + Block block_; + For loop_; +}; + +/******** Helper Functions/Classes ********/ + +/*! + * \brief Find a point where the block can be inserted under the loop + * \tparam require_all_producers_visited Requires all producer blocks to be present under the loop + * \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop + * \param self The schedule state + * \param subtrees The subtrees under the loop, among which the insertion points are sought + * \param producer_srefs The producer blocks + * \param consumer_srefs The consumer blocks + * \param block2realize A cache that maps a block to its realize + * \return The last position the new block can be inserted onto, and the + * producer-consumer-relationship is still satisfied. + */ +template <bool require_all_producers_visited, bool require_all_consumers_visited> +int FindInsertionPoint( + const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs, + const Array<StmtSRef>& consumer_srefs, + std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) { + ProducerConsumerSplit split = + ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize); + // Step 1. Check if all the producers are visited in the subtrees, if required to + if (require_all_producers_visited) { + int num_producers = producer_srefs.size(); + if (split.n_producers_visited < num_producers) { + throw NotAllRequiredBlocksAreVisitedError<false>( + self->mod, num_producers - split.n_producers_visited, producer_srefs); + } + } + // Step 2. 
Check if all the consumers are visited in the subtrees, if required to + if (require_all_consumers_visited) { + int num_consumers = consumer_srefs.size(); + if (split.n_consumers_visited < num_consumers) { + throw NotAllRequiredBlocksAreVisitedError<true>( + self->mod, num_consumers - split.n_consumers_visited, consumer_srefs); + } + } + // Step 3. Check if there is at least one index of the position can be inserted into + // The valid indices are: (last_producer_position, first_consumer_position] + ICHECK(split.last_producer_position < split.first_consumer_position); + // Step 4. Return the last valid insertion point + return split.first_consumer_position; +} + +/*! + * \brief A helper to reconstruct the block scope where the given block is moved under the given + * loop, and the given block's induced loop nest is regenerated to satisfy the required region. + */ +class ScopeReconstructor : private StmtMutator { + public: + explicit ScopeReconstructor(Block scope_root, Block block, For loop) + : scope_root_(scope_root), block_(block), loop_(loop) {} + + using StmtMutator::operator(); Review comment: What does this line mean? ########## File path: src/tir/schedule/primitive/compute_at.cc ########## @@ -0,0 +1,584 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +/******** Error Classes ********/ + +/*! + * \brief An error raised when not all required blocks are under the given loop. + * \tparam is_consumer Indicates if all the required blocks are consumers or producers + */ +template <bool is_consumer> +class NotAllRequiredBlocksAreVisitedError : public ScheduleError { + public: + explicit NotAllRequiredBlocksAreVisitedError(IRModule mod, int num_not_visited, + const Array<StmtSRef>& required) + : mod_(mod), num_not_visited_(num_not_visited) { + required_.reserve(required.size()); + for (const StmtSRef& block_sref : required) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + required_.push_back(GetRef<Block>(block)); + } + } + + String FastErrorString() const final { + return "ScheduleError: Not all required blocks are under the loop scope"; + } + + String DetailRenderTemplate() const final { + String relation = is_consumer ? "consumer(s)" : "producer(s)"; + std::ostringstream os; + os << "The primitive requires all the " << relation + << " of the given block to be present under the target loop. However, there are " + << num_not_visited_ << " " << relation << " not satisfying the constraint. List of the " + << relation << ":"; + for (int i = 0, n = required_.size(); i < n; ++i) { + os << "{" << i << "}"; + } + return os.str(); + } + + IRModule mod() const final { return mod_; } + + Array<ObjectRef> LocationsOfInterest() const final { + return {required_.begin(), required_.end()}; + } + + private: + IRModule mod_; + int num_not_visited_; + Array<Block> required_; +}; + +/*! + * \brief An error raised when the given block is not in the same block scope as the given loop, + * or the given loop is the ancestor of the given block. 
+ */ +class NotInSameScopeError : public ScheduleError { + public: + static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& loop_sref, const StmtSRef& scope_root_sref, + arith::Analyzer* analyzer) { + for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) { + if (const ForNode* loop = p->StmtAs<ForNode>()) { + analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } else if (p != scope_root_sref.get()) { + throw NotInSameScopeError(self->mod, block_sref, loop_sref); + } else { + break; + } + } + for (const StmtSRefNode* p = block_sref->parent; p != scope_root_sref.get(); p = p->parent) { + if (p == loop_sref.get()) { + throw NotInSameScopeError(self->mod, block_sref, loop_sref); + } + } + } + + String FastErrorString() const final { + return "ScheduleError: Expected the block and loop to be under the same block scope, and loop " + "not to be the ancestor of block"; + } + String DetailRenderTemplate() const final { + return "ScheduleError: Expected the block {0} and loop {1} to be under the same block scope, " + "and loop not to be the ancestor of block"; + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_, loop_}; } + + private: + explicit NotInSameScopeError(IRModule mod, const StmtSRef& block_sref, const StmtSRef& loop_sref) + : mod_(mod), + block_(GetRef<Block>(block_sref->StmtAs<BlockNode>())), + loop_(GetRef<For>(loop_sref->StmtAs<ForNode>())) {} + + IRModule mod_; + Block block_; + For loop_; +}; + +/******** Helper Functions/Classes ********/ + +/*! 
+ * \brief Find a point where the block can be inserted under the loop + * \tparam require_all_producers_visited Requires all producer blocks to be present under the loop + * \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop + * \param self The schedule state + * \param subtrees The subtrees under the loop, among which the insertion points are sought + * \param producer_srefs The producer blocks + * \param consumer_srefs The consumer blocks + * \param block2realize A cache that maps a block to its realize + * \return The last position the new block can be inserted onto, and the + * producer-consumer-relationship is still satisfied. + */ Review comment: ```suggestion * \return The last position the new block can be inserted onto, and the * producer-consumer-relationship is still satisfied. * \throws ScheduleError If not all the required producer or consumer blocks are present under the loop */ ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
