zhiics commented on a change in pull request #5030: [RELAY] Added a AnnotatedRegion utility class URL: https://github.com/apache/incubator-tvm/pull/5030#discussion_r396646449
########## File path: src/relay/analysis/annotated_region_set.h ########## @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/pass/annotated_region_set.h + * \brief Define data structures to extract and manipulate regions from + * a relay function. Regions are denoted by region_begin and region_end + * annotations that exist on all the input and output edges of the region. 
+ */ + +#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ +#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ + +#include <tvm/relay/analysis.h> +#include <tvm/relay/attrs/annotation.h> +#include <tvm/relay/expr.h> +#include <tvm/ir/error.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/transform.h> + +#include <string> +#include <unordered_set> +#include <utility> +#include <vector> +#include <list> + +namespace tvm { +namespace relay { + +class AnnotatedRegion; +class AnnotatedRegionSet; + +class AnnotatedRegionNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("id", &id); + Array<Expr> nodes_array(nodes.begin(), nodes.end()); + v->Visit("nodes", &nodes_array); + Array<Expr> args_array(ins.begin(), ins.end()); + v->Visit("args", &args_array); + Array<Expr> rets_array(outs.begin(), outs.end()); + v->Visit("rets", &rets_array); + } + + /*! \brief Get the region ID. */ + int GetID() const { + return id; + } + + /*! \brief Get the region's inputs. */ + std::list<Expr> GetInputs() const { + return ins; + } + + /*! \brief Get the region's outputs. */ + std::list<Expr> GetOutputs() const { + return outs; + } + + /*! \brief Get the region's nodes. */ + std::unordered_set<Expr, ObjectHash, ObjectEqual> GetNodes() const { + return nodes; + } + + static constexpr const char* _type_key = "relay.AnnotatedRegion"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object); + + protected: + /*! \brief The region ID. */ + int id{-1}; + /*! \brief The inputs to this region. */ + std::list<Expr> ins; + /*! \brief The outputs of this region */ + std::list<Expr> outs; + /*! \brief Nodes in this region. */ + std::unordered_set<Expr, ObjectHash, ObjectEqual> nodes; + + friend class AnnotatedRegionSet; + friend class AnnotatedRegionSetNode; +}; + +/*! + * \brief An object to hold the properties of a region as used by the + * AnnotatedRegionSet class. This should be considered read-only. 
+*/ +class AnnotatedRegion : public ObjectRef { + public: + AnnotatedRegion() { + auto n = make_object<AnnotatedRegionNode>(); + data_ = std::move(n); + } + + /*! + * \brief Construct from an object pointer. + * \param n The object pointer. + */ + explicit AnnotatedRegion(ObjectPtr<Object> n) : ObjectRef(n) {} + + /*! \return Mutable pointers to the node. */ + AnnotatedRegionNode* operator->() const { + auto* ptr = get_mutable(); + CHECK(ptr != nullptr); + return static_cast<AnnotatedRegionNode*>(ptr); + } +}; + +class AnnotatedRegionSetNode : public Object { + using UnorderedRegionSet = + std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>; + // Create iterator alias for a RegionSet object. + using iterator = UnorderedRegionSet::iterator; + using const_iterator = UnorderedRegionSet::const_iterator; + + public: + /*! \brief Default constructor. */ + AnnotatedRegionSetNode() = default; + + /*! \return The begin iterator */ + iterator begin() { + return regions_.begin(); + } + /*! \return The end iterator */ + iterator end() { + return regions_.end(); + } + /*! \return The const begin iterator */ + const_iterator begin() const { + return regions_.begin(); + } + /*! \return The const end iterator */ + const_iterator end() const { + return regions_.end(); + } + + /*! + * \brief Get the region that an expression belongs to. + * + * \param expr Which expr to get the region for. + * + * \return A pointer to the region, nullptr if the expression + * doesn't belong to a region. + */ + AnnotatedRegion GetRegion(const Expr& expr) const; + + /*! + * \brief Merge region 1 into region 2. + * + * \param region1 A region to merge. + * \param region2 A region to merge. 
+ */ + void MergeRegions(AnnotatedRegion region1, AnnotatedRegion region2); + + void VisitAttrs(AttrVisitor* v) { + Array<AnnotatedRegion> regions_array(regions_.begin(), regions_.end()); + v->Visit("regions", &regions_array); + } + + static constexpr const char* _type_key = "relay.AnnotatedRegionSet"; + TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionSetNode, Object); + + private: + /*! + * \brief Add an expression to a region. + * + * \param region The region to add the expression to. + * \param expr The expression. + */ + void AddToRegion(AnnotatedRegion region, const Expr& expr); + + /*! + * \brief Make a new region. + * + * \return The new region. + */ + AnnotatedRegion MakeRegion(); + + std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual> regions_; + + friend class AnnotatedRegionSet; +}; + +/*! + * \brief A class to hold a set of regions produced from a relay expression + * that contains 'region_begin' and 'region_end' style annotations. The + * regions should be disjoint. The class provides both a method to construct + * the region set of a given relay expression as well as additional methods + * to update and query regions. + */ +class AnnotatedRegionSet : public ObjectRef { + using UnorderedRegionSet = + std::unordered_set<AnnotatedRegion, ObjectHash, ObjectEqual>; + // Create iterator alias for a RegionSet object. + using iterator = UnorderedRegionSet::iterator; + using const_iterator = UnorderedRegionSet::const_iterator; + + public: + AnnotatedRegionSet() { + auto n = make_object<AnnotatedRegionSetNode>(); + data_ = std::move(n); + } + + /*! + * \brief Construct from an object pointer. + * + * \param n The object pointer. + */ + explicit AnnotatedRegionSet(ObjectPtr<Object> n) : ObjectRef(n) {} + + /*! \return The begin iterator. */ + iterator begin() { + auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + iterator end() { + auto* n = operator->(); + CHECK(n); + return n->end(); + } + /*!
\return The begin iterator. */ + const_iterator begin() const { + const auto* n = operator->(); + CHECK(n); + return n->begin(); + } + /*! \return The end iterator. */ + const_iterator end() const { + const auto *n = operator->(); + CHECK(n); + return n->end(); + } + + /*! \return mutable pointers to the node. */ Review comment: Indentation ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
