This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e02ea74 Add DictAttrs to IRModule and refactor DictAttrs utility
functions (#8750)
e02ea74 is described below
commit e02ea7430589fa345ab4472f02511ae8d6c08dea
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Mon Aug 16 22:44:59 2021 -0700
Add DictAttrs to IRModule and refactor DictAttrs utility functions (#8750)
* Add DictAttrs to IRModuleNode
Move GetAttrs to be a member of DictAttrs
Generalize WithAttrs to work with IRModule and move to attrs.h
Change func->GetAttr to func->attrs.GetAttr
* lint
* Fix documentation
* fix typo
* Another typo!
* Revert GetAttrs to ->attrs.GetAttrs change
* Didn't mean to revert these
* Revert a few more things
* Add GetAttrs to IRModuleNode
---
include/tvm/ir/attrs.h | 108 ++++++++++++++++++++++++++++++++++++++++++++++
include/tvm/ir/function.h | 57 ++----------------------
include/tvm/ir/module.h | 54 +++++++++++++++++++++++
3 files changed, 165 insertions(+), 54 deletions(-)
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index da7bc12..fa18610 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown)
final;
Array<AttrFieldInfo> ListFieldInfo() const final;
+
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
@@ -232,6 +233,72 @@ class DictAttrs : public Attrs {
*/
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
+ // Utils for accessing attributes
+ // This needs to be on DictAttrs, not DictAttrsNode because we return the
default
+ // value if DictAttrsNode is not defined.
+ /*!
+ * \brief Get a function attribute.
+ *
+ * \param attr_key The attribute key.
+ * \param default_value The default value if the key does not exist,
defaults to nullptr.
+ *
+ * \return The result
+ *
+ * \tparam TOBjectRef the expected object type.
+ * \throw Error if the key exists but the value does not match TObjectRef
+ *
+ * \code
+ *
+ * void GetAttrExample(const BaseFunc& f) {
+ * auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
+ * }
+ *
+ * \endcode
+ */
+ template <typename TObjectRef>
+ Optional<TObjectRef> GetAttr(
+ const std::string& attr_key,
+ Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr))
const {
+ static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
+ "Can only call GetAttr with ObjectRef types.");
+ if (!defined()) return default_value;
+ const DictAttrsNode* node = this->as<DictAttrsNode>();
+
+ auto it = node->dict.find(attr_key);
+ if (it != node->dict.end()) {
+ return Downcast<Optional<TObjectRef>>((*it).second);
+ } else {
+ return default_value;
+ }
+ }
+ // variant that uses TObjectRef to enable implicit conversion to default
value.
+ template <typename TObjectRef>
+ Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef
default_value) const {
+ return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+ }
+  /*!
+   * \brief Check whether the dictionary has a non-zero integer attribute.
+   *
+   * This function can be used to check whether an optional
+   * attribute mark (e.g. inline) exists.
+   *
+   * \param attr_key The key to the attribute.
+   * \return true if attr_key is present and holds a non-zero Integer; false if the key
+   *         is absent or this DictAttrs is undefined.
+   * \throw Error if the key exists but the value is not an Integer.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const BaseFunc& f) {
+   *    if (f->attrs.HasNonzeroAttr(attr::kInline)) {
+   *      // inline the function.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const {
+    return GetAttr<Integer>(attr_key, 0) != 0;
+  }
+
   TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
@@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() {
return TAttrs(n);
}
+/*!
+ * \brief Copy the function or module, but override
+ *  the attribute value stored under key with the given value.
+ *
+ * \param input The thing to annotate (BaseFunc or IRModule)
+ * \param attr_key The attribute key.
+ * \param attr_value The attribute value.
+ *
+ * \tparam TFunc The corresponding function or module type.
+ *
+ * \returns The new function or module with updated attributes.
+ *
+ * \note This function performs copy on write optimization for func and module.
+ *  If we move a uniquely referenced func or module into WithAttr,
+ *  then no additional copy will be performed.
+ *
+ *  This is also why we make it as a function instead of a member function
+ *  and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithAttr(std::move(func), "key1", value1);
+ *  func = WithAttr(std::move(func), "key2", value2);
+ *
+ * \endcode
+ */
+template <typename TFunc>
+inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
+  using TNode = typename TFunc::ContainerType;
+  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+  TNode* node = input.CopyOnWrite();
+  if (node->attrs.defined()) {
+    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
+  } else {
+    // No attributes yet: start a fresh dictionary holding only the new entry.
+    Map<String, ObjectRef> dict = {{attr_key, attr_value}};
+    node->attrs = DictAttrs(std::move(dict));
+  }
+  return input;
+}
+
// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 09c074c..13b984d 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -102,21 +102,14 @@ class BaseFuncNode : public RelayExprNode {
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr))
const {
- static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
- "Can only call GetAttr with ObjectRef types.");
- if (!attrs.defined()) return default_value;
- auto it = attrs->dict.find(attr_key);
- if (it != attrs->dict.end()) {
- return Downcast<Optional<TObjectRef>>((*it).second);
- } else {
- return default_value;
- }
+ return attrs.GetAttr(attr_key, default_value);
}
// variant that uses TObjectRef to enable implicit conversion to default
value.
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef
default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
+
/*!
* \brief Check whether the function has an non-zero integer attr.
*
@@ -136,9 +129,7 @@ class BaseFuncNode : public RelayExprNode {
*
* \endcode
*/
-  bool HasNonzeroAttr(const std::string& attr_key) const {
-    return GetAttr<Integer>(attr_key, 0) != 0;
-  }
+  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
   static constexpr const char* _type_key = "BaseFunc";
   static constexpr const uint32_t _type_child_slots = 2;
@@ -155,48 +146,6 @@ class BaseFunc : public RelayExpr {
};
/*!
- * \brief Create a new function that copies func, but overrides
- * the attribute value key with the value.
- *
- * \param func The input function.
- * \param attr_key The attribute key.
- * \param attr_value The value attribute value.
- *
- * \tparam TFunc The corresponding function type.
- *
- * \returns The new function with updated attributes.
- *
- * \note This function performs copy on write optimization for func.
- * If we move a uniquely referenced func into WithAttr,
- * then no additional copy will be performed.
- *
- * This is also why we make it as a function instead of a member function
- * and why we pass by value in the first argument.
- *
- * \code
- *
- * // Recommended way to trigger copy on write
- * func = WithAttr(std::move(func), "key1", value1);
- * func = WithAttr(std::move(func), "key2", value2);
- *
- * \endcode
- */
-template <typename TFunc,
- typename = typename std::enable_if<std::is_base_of<BaseFunc,
TFunc>::value>::type>
-inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef
attr_value) {
- using TNode = typename TFunc::ContainerType;
- static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
- TNode* node = func.CopyOnWrite();
- if (node->attrs.defined()) {
- node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
- } else {
- Map<String, ObjectRef> dict = {{attr_key, attr_value}};
- node->attrs = DictAttrs(dict);
- }
- return func;
-}
-
-/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 638f132..9ca27ec 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -58,6 +58,60 @@ class IRModuleNode : public Object {
Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The source map for the module. */
parser::SourceMap source_map;
+ /* \brief Additional attributes storing meta-data about the module. */
+ DictAttrs attrs;
+
+  /*!
+   * \brief Get a module attribute.
+   *
+   * \param attr_key The attribute key.
+   * \param default_value The default value if the key does not exist, defaults to nullptr.
+   *
+   * \return The value stored under attr_key, or default_value when the key is absent
+   *         or the module has no attributes.
+   *
+   * \tparam TObjectRef the expected object type.
+   * \throw Error if the key exists but the value does not match TObjectRef
+   *
+   * \code
+   *
+   *  void GetAttrExample(const IRModule& mod) {
+   *    auto value = mod->GetAttr<Integer>("AttrKey", 0);
+   *  }
+   *
+   * \endcode
+   */
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(
+      const std::string& attr_key,
+      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
+    // Name TObjectRef explicitly so this forwards straight to the Optional<TObjectRef>
+    // overload of DictAttrs::GetAttr instead of relying on template partial ordering.
+    return attrs.GetAttr<TObjectRef>(attr_key, default_value);
+  }
+  // Variant that takes a TObjectRef so the default value can be implicitly converted.
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+  }
+
+  /*!
+   * \brief Check whether the module has a non-zero integer attribute.
+   *
+   * This function can be used to check whether an optional
+   * flag attribute is set on the module.
+   *
+   * \param attr_key The key to the attribute.
+   * \return true if attr_key is present and holds a non-zero Integer; false if the key
+   *         is absent or the module has no attributes.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const IRModule& mod) {
+   *    if (mod->HasNonzeroAttr("AttrKey")) {
+   *      // handle the flagged module.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
   IRModuleNode() : source_map() {}