junrushao1994 commented on a change in pull request #7765: URL: https://github.com/apache/tvm/pull/7765#discussion_r604371963
########## File path: src/tir/schedule/state.cc ########## @@ -0,0 +1,863 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./utils.h" + +namespace tvm { +namespace tir { + +template <class K, class V> +using SMap = std::unordered_map<K, V, ObjectPtrHash, ObjectPtrEqual>; + +/**************** Utility functions ****************/ + +/*! + * \brief Set the `StmtSRefNode::seq_index` field for stmt + * \param self The schedule class + * \param stmt The statement, or the realize node of the statement whose sref to be set + * \param seq_index The seq_index to be set + * \note The method is NOP for statements that are not scheduleable, i.e. 
not For or Block + */ +void SetSeqIndex(ScheduleStateNode* self, const Stmt& stmt, int seq_index) { + if (const auto* realize = stmt.as<BlockRealizeNode>()) { + const BlockNode* block = realize->block.get(); + ICHECK(self->stmt2ref.count(block)); + self->stmt2ref.at(block)->seq_index = seq_index; + } else if (const auto* block = stmt.as<BlockNode>()) { + ICHECK(self->stmt2ref.count(block)); + self->stmt2ref.at(block)->seq_index = seq_index; + } else if (const auto* loop = stmt.as<ForNode>()) { + ICHECK(self->stmt2ref.count(loop)); + self->stmt2ref.at(loop)->seq_index = seq_index; + } else { + // do nothing + } +} + +/*! + * \brief Update seq_index of the children of a SeqStmt + * \param self The schedule class + * \param seq_stmt The SeqStmt whose children need updating + */ +void SetSeqIndexInChildren(ScheduleStateNode* self, const SeqStmtNode* seq_stmt) { + int i = 0; + for (const Stmt& stmt : seq_stmt->seq) { + SetSeqIndex(self, stmt, i); + ++i; + } +} + +/*! + * \brief Update the sref information on the schedule class, as well as the statement of sref itself + * More specifically, update + * `sref->stmt` to `new_stmt` + * `self->stmt2ref`, remove the old statement that sref points to, and add the new statement + * \param self The schedule class to be updated + * \param sref The sref to be updated + * \param new_stmt The statement that replaces the statement inside the sref + */ +void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) { + ICHECK(new_stmt->IsInstance<BlockNode>() || new_stmt->IsInstance<ForNode>()); + const StmtNode* old_stmt = sref->stmt; + ICHECK_NE(new_stmt, old_stmt); + self->stmt2ref[new_stmt] = GetRef<StmtSRef>(sref); + self->stmt2ref.erase(sref->stmt); + sref->stmt = new_stmt; +} + +/*! 
+ * \brief Get PrimFunc and GlobalVar that the root block belongs to + * \param mod The IRModule + * \param root_block The root block of the PrimFunc + * \param result_g_var The result GlobalVar + * \return The result PrimFunc where the root block belongs to + * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write + */ +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var) { Review comment: I brought the discussion here: https://github.com/apache/tvm/pull/7765#issuecomment-810515117 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
