junrushao1994 commented on a change in pull request #7847: URL: https://github.com/apache/tvm/pull/7847#discussion_r613641015
########## File path: include/tvm/tir/schedule/schedule.h ########## @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ +#define TVM_TIR_SCHEDULE_SCHEDULE_H_ + +#include <tvm/tir/schedule/state.h> + +namespace tvm { +namespace tir { + +/**************** Random variable: BlockRV ****************/ + +/*! \brief A random variable that evaluates to a TensorIR block */ +class BlockRVNode : public runtime::Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.BlockRV"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object); +}; + +/*! + * \brief Managed reference to BlockRVNode + * \sa BlockRVNode + */ +class BlockRV : public runtime::ObjectRef { + public: + /*! \brief Constructor */ + TVM_DLL BlockRV(); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode); +}; + +/**************** Random variable: LoopRV ****************/ + +/*! 
\brief A random variable that evaluates to a TensorIR for loop */ +class LoopRVNode : public runtime::Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.LoopRV"; + TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object); +}; + +/*! + * \brief Managed reference to LoopRVNode + * \sa LoopRVNode + */ +class LoopRV : public runtime::ObjectRef { + public: + /*! \brief Constructor */ + TVM_DLL LoopRV(); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode); +}; + +/**************** Random variable: IntRV ****************/ + +/*! \brief An integer random variable */ +using IntRV = PrimExpr; + +using IntRVNode = PrimExprNode; + +/**************** The schedule class ****************/ + +class Schedule; + +/*! + * \brief The user-facing abstract schedule class + */ +class ScheduleNode : public runtime::Object { + friend class Schedule; + + public: + virtual ~ScheduleNode() = default; + + static constexpr const char* _type_key = "tir.Schedule"; + TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object); + + public: + /*! \brief Take the IRModule out of the schedule */ + virtual IRModule mod() const { return state()->mod; } + /*! \return The internal state of scheduling */ + virtual ScheduleState state() const = 0; + /*! + * \brief Returns a copy of the schedule, including both the state and the symbol table, + * guaranteeing that + * 1) SRef tree is completely reconstructed; + * 2) The IRModule being scheduled is untouched; + * 3) All the random variables are valid in the copy, pointing to the corresponding sref + * reconstructed + */ + virtual Schedule Copy() const = 0; + /*!
+ * \brief Seed the randomness + * \param seed The new random seed, -1 if use device random, otherwise non-negative + */ + virtual void Seed(int64_t seed = -1) { + LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; + } + + public: + /******** Lookup/Remove random variables ********/ + /*! + * \brief Get the block corresponding to the specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return The corresponding block + */ + virtual Block Get(const BlockRV& block_rv) const = 0; + /*! + * \brief Get the for loop corresponding to the specific LoopRV + * \param loop_rv The LoopRV to be looked up + * \return The corresponding for loop + */ + virtual For Get(const LoopRV& loop_rv) const = 0; + /*! + * \brief Get the value corresponding to the specific random variable + * \param var_rv The random variable to be looked up + * \return The corresponding value + */ + virtual int64_t Get(const IntRV& var_rv) const = 0; + /*! + * \brief Get the block sref corresponding to the specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return The corresponding block sref + */ + virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0; + /*! + * \brief Get the loop sref corresponding to the specific LoopRV + * \param loop_rv The LoopRV to be looked up + * \return The corresponding loop sref + */ + virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Get the block/loop sref corresponding to the specific statement + * \param stmt The statement to be looked up + * \return The corresponding block/loop sref + */ + virtual StmtSRef GetSRef(const StmtNode* stmt) const; + /*! + * \brief Get the block/loop sref corresponding to the specific statement + * \param stmt The statement to be looked up + * \return The corresponding block/loop sref + */ + StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); } + /*! 
+ * \brief Remove a block random variable from the symbol table + * \param block_rv The random variable to be removed + */ + virtual void RemoveRV(const BlockRV& block_rv) = 0; + /*! + * \brief Remove a loop random variable from the symbol table + * \param loop_rv The random variable to be removed + */ + virtual void RemoveRV(const LoopRV& loop_rv) = 0; + /*! + * \brief Remove an integer random variable from the symbol table + * \param var_rv The random variable to be removed + */ + virtual void RemoveRV(const IntRV& var_rv) = 0; + + public: + /******** Block/Loop relation ********/ + /*! + * \brief Retrieve a block in a specific function with its name Review comment: Yeah, we had a debate in the previous PR (https://github.com/apache/tvm/pull/7765), and our conclusion was that we need to enforce name uniqueness. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
