This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e02ea74  Add DictAttrs to IRModule and refactor DictAttrs utility functions (#8750)
e02ea74 is described below

commit e02ea7430589fa345ab4472f02511ae8d6c08dea
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Mon Aug 16 22:44:59 2021 -0700

    Add DictAttrs to IRModule and refactor DictAttrs utility functions (#8750)
    
    * Add DictAttrs to IRModuleNode
    
    Move GetAttrs to be a member of DictAttrs
    
    Generalize WithAttrs to work with IRModule and move to attrs.h
    
    Change func->GetAttr to func->attrs.GetAttr
    
    * lint
    
    * Fix documentation
    
    * fix typo
    
    * Another typo!
    
    * Revert GetAttrs to ->attrs.GetAttrs change
    
    * Didn't mean to revert these
    
    * Revert a few more things
    
    * Add GetAttrs to IRModuleNode
---
 include/tvm/ir/attrs.h    | 108 ++++++++++++++++++++++++++++++++++++++++++++++
 include/tvm/ir/function.h |  57 ++----------------------
 include/tvm/ir/module.h   |  54 +++++++++++++++++++++++
 3 files changed, 165 insertions(+), 54 deletions(-)

diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index da7bc12..fa18610 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode {
   void VisitNonDefaultAttrs(AttrVisitor* v) final;
   void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
   Array<AttrFieldInfo> ListFieldInfo() const final;
+
   // type info
   static constexpr const char* _type_key = "DictAttrs";
   TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
@@ -232,6 +233,72 @@ class DictAttrs : public Attrs {
    */
   TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
 
+  // Utils for accessing attributes
+  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
+  // value if DictAttrsNode is not defined.
+  /*!
+   * \brief Get an attribute value by key.
+   *
+   * \param attr_key The attribute key.
+   * \param default_value The default value if the key does not exist, defaults to nullptr.
+   *
+   * \return The result
+   *
+   * \tparam TObjectRef the expected object type.
+   * \throw Error if the key exists but the value does not match TObjectRef
+   *
+   * \code
+   *
+   *  void GetAttrExample(const BaseFunc& f) {
+   *    auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
+   *  }
+   *
+   * \endcode
+   */
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(
+      const std::string& attr_key,
+      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
+    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
+                  "Can only call GetAttr with ObjectRef types.");
+    if (!defined()) return default_value;
+    const DictAttrsNode* node = this->as<DictAttrsNode>();
+
+    auto it = node->dict.find(attr_key);
+    if (it != node->dict.end()) {
+      return Downcast<Optional<TObjectRef>>((*it).second);
+    } else {
+      return default_value;
+    }
+  }
+  // variant that uses TObjectRef to enable implicit conversion to default value.
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+  }
+  /*!
+   * \brief Check whether the dictionary has a non-zero integer attr.
+   *
+   * This function can be used to check whether an optional
+   * attribute mark(e.g. inline) exists.
+   *
+   * \param attr_key The key to the attribute.
+   * \return The check result.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const BaseFunc& f) {
+   *    if (f->attrs.HasNonzeroAttr(attr::kInline)) {
+   *      // inline the function.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const {
+    return GetAttr<Integer>(attr_key, 0) != 0;
+  }
+
   TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
 };
@@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() {
   return TAttrs(n);
 }
 
+/*!
+ * \brief Copy the function or module, but set the attribute
+ *        attr_key to attr_value.
+ *
+ * \param input The thing to annotate (BaseFunc or IRModule)
+ * \param attr_key The attribute key.
+ * \param attr_value The attribute value.
+ *
+ * \tparam TFunc The corresponding function or module type.
+ *
+ * \returns The new function or module with updated attributes.
+ *
+ * \note This function performs copy on write optimization for func and module.
+ *       If we move a uniquely referenced func or module into WithAttr,
+ *       then no additional copy will be performed.
+ *
+ *       This is also why we make it as a function instead of a member function
+ *       and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithAttr(std::move(func), "key1", value1);
+ *  func = WithAttr(std::move(func), "key2", value2);
+ *
+ * \endcode
+ */
+template <typename TFunc>
+inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
+  using TNode = typename TFunc::ContainerType;
+  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+  TNode* node = input.CopyOnWrite();
+  if (node->attrs.defined()) {
+    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
+  } else {
+    Map<String, ObjectRef> dict = {{attr_key, attr_value}};
+    node->attrs = DictAttrs(dict);
+  }
+  return input;
+}
+
 // Namespace containing detail implementations
 namespace detail {
 using runtime::TVMArgValue;
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 09c074c..13b984d 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -102,21 +102,14 @@ class BaseFuncNode : public RelayExprNode {
   Optional<TObjectRef> GetAttr(
       const std::string& attr_key,
       Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
-    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
-                  "Can only call GetAttr with ObjectRef types.");
-    if (!attrs.defined()) return default_value;
-    auto it = attrs->dict.find(attr_key);
-    if (it != attrs->dict.end()) {
-      return Downcast<Optional<TObjectRef>>((*it).second);
-    } else {
-      return default_value;
-    }
+    return attrs.GetAttr(attr_key, default_value);
   }
   // variant that uses TObjectRef to enable implicit conversion to default value.
   template <typename TObjectRef>
   Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
     return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
   }
+
   /*!
    * \brief Check whether the function has an non-zero integer attr.
    *
@@ -136,9 +129,7 @@ class BaseFuncNode : public RelayExprNode {
    *
    * \endcode
    */
-  bool HasNonzeroAttr(const std::string& attr_key) const {
-    return GetAttr<Integer>(attr_key, 0) != 0;
-  }
+  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
 
   static constexpr const char* _type_key = "BaseFunc";
   static constexpr const uint32_t _type_child_slots = 2;
@@ -155,48 +146,6 @@ class BaseFunc : public RelayExpr {
 };
 
 /*!
- * \brief Create a new function that copies func, but overrides
- *        the attribute value key with the value.
- *
- * \param func The input function.
- * \param attr_key The attribute key.
- * \param attr_value The value attribute value.
- *
- * \tparam TFunc The corresponding function type.
- *
- * \returns The new function with updated attributes.
- *
- * \note This function performs copy on write optimization for func.
- *       If we move a uniquely referenced func into WithAttr,
- *       then no additional copy will be performed.
- *
- *       This is also why we make it as a function instead of a member function
- *       and why we pass by value in the first argument.
- *
- * \code
- *
- *  // Recommended way to trigger copy on write
- *  func = WithAttr(std::move(func), "key1", value1);
- *  func = WithAttr(std::move(func), "key2", value2);
- *
- * \endcode
- */
-template <typename TFunc,
-          typename = typename std::enable_if<std::is_base_of<BaseFunc, TFunc>::value>::type>
-inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) {
-  using TNode = typename TFunc::ContainerType;
-  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
-  TNode* node = func.CopyOnWrite();
-  if (node->attrs.defined()) {
-    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
-  } else {
-    Map<String, ObjectRef> dict = {{attr_key, attr_value}};
-    node->attrs = DictAttrs(dict);
-  }
-  return func;
-}
-
-/*!
  * \brief Generic attribute names that can be attached to any function.
  *
  * \sa tvm::tir::attr, tvm::relay::attr
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 638f132..9ca27ec 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -58,6 +58,60 @@ class IRModuleNode : public Object {
   Map<GlobalTypeVar, TypeData> type_definitions;
   /*! \brief The source map for the module. */
   parser::SourceMap source_map;
+  /*! \brief Additional attributes storing meta-data about the module. */
+  DictAttrs attrs;
+
+  /*!
+   * \brief Get a module attribute.
+   *
+   * \param attr_key The attribute key.
+   * \param default_value The default value if the key does not exist, defaults to nullptr.
+   *
+   * \return The result
+   *
+   * \tparam TObjectRef the expected object type.
+   * \throw Error if the key exists but the value does not match TObjectRef
+   *
+   * \code
+   *
+   *  void GetAttrExample(const IRModule& mod) {
+   *    auto value = mod->GetAttr<Integer>("AttrKey", 0);
+   *  }
+   *
+   * \endcode
+   */
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(
+      const std::string& attr_key,
+      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
+    return attrs.GetAttr(attr_key, default_value);
+  }
+  // variant that uses TObjectRef to enable implicit conversion to default value.
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+  }
+
+  /*!
+   * \brief Check whether the module has a non-zero integer attr.
+   *
+   * This function can be used to check whether an optional
+   * attribute mark(e.g. inline) exists.
+   *
+   * \param attr_key The key to the attribute.
+   * \return The check result.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const IRModule& mod) {
+   *    if (mod->HasNonzeroAttr(attr::kInline)) {
+   *      // inline the function.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
 
   IRModuleNode() : source_map() {}
 

Reply via email to