This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e62169cc8a [Unity] Add Global info (#14132)
e62169cc8a is described below
commit e62169cc8afb5b9062a40072d6d44fc817408d77
Author: Hongyi Jin <[email protected]>
AuthorDate: Sun Feb 26 11:05:47 2023 -0500
[Unity] Add Global info (#14132)
---
include/tvm/ir/global_info.h | 80 ++++++++++++++++++++++
include/tvm/ir/module.h | 16 ++++-
include/tvm/script/ir_builder/base.h | 2 +
include/tvm/script/ir_builder/ir/frame.h | 7 ++
python/tvm/ir/__init__.py | 1 +
.../parser/ir/__init__.py => ir/global_info.py} | 28 ++++++--
python/tvm/ir/module.py | 30 +++++++-
python/tvm/script/ir_builder/base.py | 11 +++
python/tvm/script/ir_builder/ir/__init__.py | 9 ++-
python/tvm/script/ir_builder/ir/ir.py | 39 ++++++++++-
python/tvm/script/parser/ir/__init__.py | 4 +-
python/tvm/script/parser/ir/parser.py | 11 ++-
src/ir/global_info.cc | 32 +++++++++
src/ir/module.cc | 25 +++++--
src/script/ir_builder/base.cc | 6 ++
src/script/ir_builder/ir/frame.cc | 3 +-
src/script/ir_builder/ir/ir.cc | 24 +++++++
src/script/printer/ir/ir.cc | 15 ++++
tests/python/relax/test_tvmscript_parser.py | 42 ++++++++++++
19 files changed, 365 insertions(+), 20 deletions(-)
diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h
new file mode 100644
index 0000000000..65b5e0a3d2
--- /dev/null
+++ b/include/tvm/ir/global_info.h
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/global_info.h
+ * \brief GlobalInfo are globally static objects that are referred to by the IR itself.
+ */
+
+#ifndef TVM_IR_GLOBAL_INFO_H_
+#define TVM_IR_GLOBAL_INFO_H_
+
+#include "tvm/ir/expr.h"
+
+namespace tvm {
+
+/*!
+ * \brief GlobalInfo are globally static objects that are referred to by the IR itself.
+ * Base node for all global info that can appear in the IR
+ */
+class GlobalInfoNode : public Object {
+ public:
+ static constexpr const char* _type_key = "GlobalInfoNode";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
+ static constexpr const bool _type_has_method_shash_reduce = true;
+ TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object);
+};
+
+/*!
+ * \brief Managed reference to GlobalInfoNode.
+ * \sa GlobalInfoNode
+ */
+class GlobalInfo : public ObjectRef {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode);
+};
+
+/*!
+ * \brief A dummy global info sub-class for testing purpose.
+ */
+class DummyGlobalInfoNode : public GlobalInfoNode {
+ public:
+ void VisitAttrs(tvm::AttrVisitor* v) {}
+ static constexpr const char* _type_key = "DummyGlobalInfo";
+
+ TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer
equal) const {
+ return true;
+ }
+
+ TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {}
+ TVM_DECLARE_FINAL_OBJECT_INFO(DummyGlobalInfoNode, GlobalInfoNode);
+};
+
+/*!
+ * \brief Managed reference to DummyGlobalInfoNode.
+ * \sa DummyGlobalInfoNode
+ */
+class DummyGlobalInfo : public GlobalInfo {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(DummyGlobalInfo, GlobalInfo,
DummyGlobalInfoNode);
+};
+
+} // namespace tvm
+
+#endif // TVM_IR_GLOBAL_INFO_H_
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 538ff64ca3..4c2d5cd812 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -27,6 +27,7 @@
#include <tvm/ir/adt.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
+#include <tvm/ir/global_info.h>
#include <tvm/ir/source_map.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/container/array.h>
@@ -63,6 +64,8 @@ class IRModuleNode : public Object {
SourceMap source_map;
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;
+ /*! \brief Globally static objects that are referred to by the IR itself. */
+ Map<String, Array<GlobalInfo>> global_infos;
/*!
* \brief A map from string names to global variables that
* ensures global uniqueness.
@@ -151,6 +154,7 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
+ v->Visit("global_infos", &global_infos);
}
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal)
const;
@@ -210,6 +214,13 @@ class IRModuleNode : public Object {
*/
TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);
+ /*!
+ * \brief Update an array of global infos in the global environment.
+ * \param name The name of the global info.
+ * \param info The new array of global infos.
+ */
+ TVM_DLL void UpdateGlobalInfo(const String& name, const Array<GlobalInfo>&
info);
+
/*!
* \brief Remove a function from the global environment.
* \param var The name of the global function to update.
@@ -359,12 +370,13 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module.
* \param map The module source map.
- * \param attrs The module attributes.
+ * \param attrs The module meta-data attributes.
+ * \param global_infos Global infos in the module.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {},
SourceMap map = {},
- DictAttrs attrs = {});
+ DictAttrs attrs = {}, Map<String,
Array<GlobalInfo>> global_infos = {});
/*! \brief default constructor */
IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
diff --git a/include/tvm/script/ir_builder/base.h
b/include/tvm/script/ir_builder/base.h
index 61ca3eb9f7..a00ea5768e 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
* \sa tvm::support::With
*/
static IRBuilder Current();
+ /*! \brief See if the current thread-local scope has an IRBuilder. */
+ static bool IsInScope();
/*!
* \brief Give a string name to the `obj`
* \tparam TObjectRef The type of the object to name.
diff --git a/include/tvm/script/ir_builder/ir/frame.h
b/include/tvm/script/ir_builder/ir/frame.h
index dacfc361a6..6e758372b9 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -21,6 +21,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
+#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/script/ir_builder/base.h>
@@ -45,11 +46,17 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
* \note Only defined functions are in the map, while declared functions are
not included.
*/
Map<GlobalVar, BaseFunc> functions;
+ /*! \brief IRModule's attributes. */
+ Map<String, ObjectRef> attrs;
+ /*! \brief IRModule's global_infos */
+ Map<String, Array<GlobalInfo>> global_infos;
void VisitAttrs(tvm::AttrVisitor* v) {
IRBuilderFrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_var_map);
v->Visit("functions", &functions);
+ v->Visit("attrs", &attrs);
+ v->Visit("global_infos", &global_infos);
}
static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4f63cbecd9..01fea2abbd 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -34,6 +34,7 @@ from .base import (
from .container import Array, Map
from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
from .function import BaseFunc, CallingConv
+from .global_info import GlobalInfo, DummyGlobalInfo
from .memory_pools import (
ConstantMemoryPools,
ConstantPoolInfo,
diff --git a/python/tvm/script/parser/ir/__init__.py
b/python/tvm/ir/global_info.py
similarity index 53%
copy from python/tvm/script/parser/ir/__init__.py
copy to python/tvm/ir/global_info.py
index fedd2f0a14..17011e76a6 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/ir/global_info.py
@@ -14,9 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""The ir module parser"""
+"""Global Info."""
+import tvm
+from tvm.runtime.object import Object
+from . import _ffi_api
-from . import parser as _parser
-from .entry import ir_module
-__all__ = ["ir_module"]
+class GlobalInfo(Object):
+    """Base node for all global info that can appear in the IR"""
+
+    def __eq__(self, other):
+        """Compare two global infos for structural equivalence."""
+        return tvm.ir.structural_equal(self, other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def same_as(self, other):
+        """Compare using the parent class's ``__eq__`` instead of structural equality."""
+        return super().__eq__(other)
+
+
+class DummyGlobalInfo(GlobalInfo):
+    """A dummy global info sub-class, constructed through ``_ffi_api.DummyGlobalInfo``."""
+
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.DummyGlobalInfo,
+        )
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 6a151d5a89..707d46d0cd 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -42,7 +42,7 @@ class IRModule(Node, Scriptable):
Map of global var to BaseFunc
"""
- def __init__(self, functions=None, type_definitions=None):
+ def __init__(self, functions=None, type_definitions=None, attrs=None,
global_infos=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
@@ -65,7 +65,20 @@ class IRModule(Node, Scriptable):
raise TypeError("Expect type_definitions to be
Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
- self.__init_handle_by_constructor__(_ffi_api.IRModule, functions,
type_definitions)
+
+ attrs = None if not attrs else attrs
+ if attrs is not None:
+ attrs = ast.literal_eval(str(attrs))
+ attrs = tvm.ir.make_node("DictAttrs", **attrs)
+ if global_infos is None:
+ global_infos = {}
+ self.__init_handle_by_constructor__(
+ _ffi_api.IRModule,
+ functions,
+ type_definitions,
+ attrs,
+ global_infos,
+ )
def __setitem__(self, var, val):
"""Add a mapping to the module.
@@ -140,6 +153,19 @@ class IRModule(Node, Scriptable):
"""
return _ffi_api.Module_UpdateFunction(self, var, func)
+ def update_global_info(self, name, global_info):
+ """Update global info in the module
+
+ Parameters
+ ----------
+ name: str
+ The name for the global info.
+
+ global_info: List[GlobalInfo]
+ The global info to be updated.
+ """
+ return _ffi_api.Module_UpdateGlobalInfo(self, name, global_info)
+
def get_global_var(self, name):
"""Get a global variable in the function by name.
diff --git a/python/tvm/script/ir_builder/base.py
b/python/tvm/script/ir_builder/base.py
index b35bbd0a7d..1d5d050444 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -138,6 +138,17 @@ class IRBuilder(_Object):
"""
return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] #
pylint: disable=no-member
+ @staticmethod
+ def is_in_scope() -> bool:
+ """See if the current thread-local scope has an IRBuilder.
+
+ Returns
+ -------
+ bool
+ Whether the current thread-local scope has an IRBuilder
+ """
+ return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] #
pylint: disable=no-member
+
def get(self) -> _Object:
"""Get the constructed IR."""
return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] #
pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/ir/__init__.py
b/python/tvm/script/ir_builder/ir/__init__.py
index 946be263a7..68eda2cfee 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -16,4 +16,11 @@
# under the License.
"""Package tvm.script.ir_builder.ir"""
from .frame import IRModuleFrame
-from .ir import decl_function, def_function, ir_module
+from .ir import (
+ decl_function,
+ def_function,
+ ir_module,
+ module_attrs,
+ module_global_infos,
+ dummy_global_info,
+)
diff --git a/python/tvm/script/ir_builder/ir/ir.py
b/python/tvm/script/ir_builder/ir/ir.py
index 796d6f3aad..53c48b4cc5 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -16,7 +16,11 @@
# under the License.
"""Package tvm.script.ir_builder.ir.ir"""
-from tvm.ir import BaseFunc, GlobalVar
+from typing import Dict, List
+
+from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo
+from tvm.runtime import Object as tvm_Object
+
from . import _ffi_api
from .frame import IRModuleFrame
@@ -67,3 +71,36 @@ def def_function(func_name: str, func: BaseFunc) -> None:
The given function implementation
"""
return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined]
# pylint: disable=no-member
+
+
+def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
+ """Specify the attrs of the ir_module frame.
+ Parameters
+ ----------
+ attrs: Dict[str, Object]
+ The module attrs.
+ """
+ return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint:
disable=no-member
+
+
+def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None:
+ """Specify the global infos of the ir_module frame.
+ Parameters
+ ----------
+ global_infos: Dict[str, List[GlobalInfo]]
+ The module global infos.
+ """
+ return _ffi_api.ModuleGlobalInfos(global_infos) # type:
ignore[attr-defined] # pylint: disable=no-member
+
+
+############################### GlobalInfo ###############################
+
+
+def dummy_global_info() -> DummyGlobalInfo:
+ """Create a dummy global info expression.
+ Returns
+ -------
+ res : DummyGlobalInfo
+ The result dummy global info.
+ """
+ return DummyGlobalInfo() # type: ignore[attr-defined] # pylint:
disable=no-member
diff --git a/python/tvm/script/parser/ir/__init__.py
b/python/tvm/script/parser/ir/__init__.py
index fedd2f0a14..f8c9d4f0af 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""The ir module parser"""
-
+from ...ir_builder.ir import * # pylint: disable=redefined-builtin
from . import parser as _parser
from .entry import ir_module
-__all__ = ["ir_module"]
+__all__ = ["ir_module", "module_attrs", "module_global_infos",
"dummy_global_info"]
diff --git a/python/tvm/script/parser/ir/parser.py
b/python/tvm/script/parser/ir/parser.py
index 13b3e29859..201c99074f 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) ->
None:
with self.var_table.with_frame():
with I.ir_module():
+ with self.with_dispatch_token("ir"):
+ for stmt in node.body:
+ if not isinstance(stmt, doc.FunctionDef):
+ self.visit(stmt)
for stmt in node.body:
if isinstance(stmt, doc.FunctionDef):
self.visit_tvm_declare_function(stmt)
with self.with_dispatch_token("ir"):
- self.visit_body(node.body)
+ for stmt in node.body:
+ if isinstance(stmt, doc.FunctionDef):
+ self.visit(stmt)
@dispatch.register(token="ir", type_name="Assign")
@@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
@dispatch.register(token="ir", type_name="Expr")
-def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+def _visit_expr(self: Parser, node: doc.Expr) -> None:
"""The expression visiting method for ir module.
Parameters
@@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
node : doc.ClassDef
The doc AST expression node.
"""
+ self.eval_expr(node.value)
@dispatch.register(token="default", type_name="Assign")
diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc
new file mode 100644
index 0000000000..48f56d60d6
--- /dev/null
+++ b/src/ir/global_info.cc
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/global_info.cc
+ * \brief Module global info.
+ */
+
+#include <tvm/ir/global_info.h>
+namespace tvm {
+TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode);
+// FFI constructor: builds a fresh DummyGlobalInfoNode and hands back its managed reference.
+TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
+  return DummyGlobalInfo(make_object<DummyGlobalInfoNode>());
+});
+} // namespace tvm
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 8f23f19d35..da1f3942c7 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -34,7 +34,8 @@ namespace tvm {
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
- std::unordered_set<String> import_set, SourceMap
source_map, DictAttrs attrs) {
+ std::unordered_set<String> import_set, SourceMap
source_map, DictAttrs attrs,
+ Map<String, Array<GlobalInfo>> global_infos) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
@@ -44,6 +45,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
n->import_set_ = std::move(import_set);
n->source_map = source_map;
n->attrs = std::move(attrs);
+ n->global_infos = std::move(global_infos);
for (const auto& kv : n->functions) {
// set global var map
@@ -64,7 +66,10 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer
equal) const {
if (!equal(this->attrs, other->attrs)) return false;
-
+  // Compare global_infos entry by entry. Equal sizes do not imply equal key sets, and
+  // indexing a Map with a missing key raises instead of yielding a null value, so test
+  // membership first and report inequality rather than failing inside the lookup.
+  if (this->global_infos.size() != other->global_infos.size()) return false;
+  for (const auto& kv : this->global_infos) {
+    if (!other->global_infos.count(kv.first)) return false;
+    if (!equal(kv.second, other->global_infos[kv.first])) return false;
+  }
if (functions.size() != other->functions.size()) return false;
// Update GlobalVar remap
for (const auto& gv : this->GetGlobalVars()) {
@@ -116,6 +121,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce)
const {
}
reduce_temp();
hash_reduce(this->attrs);
+ hash_reduce(this->global_infos);
}
bool IRModuleNode::ContainGlobalVar(const String& name) const {
@@ -239,6 +245,10 @@ void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var,
const TypeData& type)
this->AddTypeDef(var, type, true);
}
+void IRModuleNode::UpdateGlobalInfo(const String& name, const
Array<GlobalInfo>& info) {
+ this->global_infos.Set(name, info);
+}
+
void IRModuleNode::Remove(const GlobalVar& var) {
auto functions_node = this->functions.CopyOnWrite();
functions_node->erase(var);
@@ -359,9 +369,9 @@ IRModule IRModule::FromText(const String& text, const
String& source_path) {
TVM_REGISTER_NODE_TYPE(IRModuleNode);
TVM_REGISTER_GLOBAL("ir.IRModule")
- .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
- tvm::Map<GlobalTypeVar, TypeData> types) {
- return IRModule(funcs, types, {});
+ .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
tvm::Map<GlobalTypeVar, TypeData> types,
+ tvm::DictAttrs attrs, Map<String, Array<GlobalInfo>>
global_infos) {
+ return IRModule(funcs, types, {}, {}, attrs, global_infos);
});
TVM_REGISTER_GLOBAL("ir.Module_Add")
@@ -423,6 +433,11 @@
TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule
TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
.set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) {
mod->Update(gv, func); });
+TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo")
+ .set_body_typed([](IRModule mod, String name, Array<GlobalInfo>
global_info) {
+ mod->UpdateGlobalInfo(name, global_info);
+ });
+
TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String
path) {
mod->Import(path);
});
diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc
index 8303efff4f..879db4f3d7 100644
--- a/src/script/ir_builder/base.cc
+++ b/src/script/ir_builder/base.cc
@@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() {
return stack->back();
}
+bool IRBuilder::IsInScope() {
+  // In scope iff the builder stack returned by ThreadLocalBuilderStack() is non-empty.
+  return !ThreadLocalBuilderStack()->empty();
+}
+
namespace details {
Namer::FType& Namer::vtable() {
@@ -106,6 +111,7 @@
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
.set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
diff --git a/src/script/ir_builder/ir/frame.cc
b/src/script/ir_builder/ir/frame.cc
index addf129284..3d917cee88 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() {
}
IRBuilder builder = IRBuilder::Current();
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has
already been set";
- builder->result = tvm::IRModule(func_map);
+ auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+ builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs,
global_infos);
}
TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index da2330b577..148e90b28c 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -69,9 +69,33 @@ void DefFunction(const String& func_name, const BaseFunc&
func) {
}
}
+// Store `attrs` on the IRModuleFrame found via FindModuleFrame.
+// Does nothing when no IRBuilder is in scope; raises a fatal error if the frame
+// already holds a non-empty attrs map.
+void ModuleAttrs(Map<String, ObjectRef> attrs) {
+  if (IRBuilder::IsInScope()) {
+    // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope
+    // The label passed to FindModuleFrame matches the registered name "ModuleAttrs" and
+    // follows the same form as the sibling "I.ModuleGlobalInfos".
+    IRModuleFrame frame = FindModuleFrame("I.ModuleAttrs");
+    if (!frame->attrs.empty()) {
+      LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs;
+    }
+    frame->attrs = attrs;
+  }
+}
+
+// Store `global_infos` on the IRModuleFrame found via FindModuleFrame.
+// Does nothing when no IRBuilder is in scope; raises a fatal error if the frame
+// already holds a non-empty global_infos map.
+void ModuleGlobalInfos(Map<String, Array<GlobalInfo>> global_infos) {
+  if (IRBuilder::IsInScope()) {
+    IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos");
+    if (!frame->global_infos.empty()) {
+      LOG(FATAL) << "ValueError: Duplicate module global_infos, previous one is:\n"
+                 << frame->global_infos;
+    }
+    frame->global_infos = global_infos;
+  }
+}
+
TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos);
} // namespace ir
} // namespace ir_builder
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index e6f4a1eaee..62919246b0 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -64,6 +64,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
std::sort(functions.begin(), functions.end());
With<IRFrame> f(d);
(*f)->AddDispatchToken(d, "ir");
+ if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
+ (*f)->stmts.push_back(
+ ExprStmtDoc(IR(d, "module_attrs") //
+ ->Call({d->AsDoc<ExprDoc>(mod->attrs,
p->Attr("attrs"))})));
+ }
+ if (mod->global_infos.defined() && !mod->global_infos.empty()) {
+ (*f)->stmts.push_back(ExprStmtDoc(
+ IR(d, "module_global_infos") //
+ ->Call({d->AsDoc<ExprDoc>(mod->global_infos,
p->Attr("global_infos"))})));
+ }
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
const BaseFunc& func = entry.func;
@@ -92,6 +102,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint,
p->Attr("name_hint"))});
});
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<DummyGlobalInfo>("", [](GlobalInfo ginfo, ObjectPath p,
IRDocsifier d) -> Doc {
+ return IR(d, "dummy_global_info")->Call({});
+ });
+
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 7724c8e761..9636a98b41 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -22,6 +22,7 @@ import tvm
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
+from tvm.ir import DummyGlobalInfo
from tvm.script.parser import ir as I
from tvm.script.parser import relax as R
from tvm.script.parser import tir as T
@@ -183,6 +184,47 @@ def test_simple_module():
_check(TestModule, bb.get())
+def test_module_with_attr_and_global_info():
+ @I.ir_module
+ class TestModule:
+ I.module_attrs({"attr": 10})
+ I.module_global_infos(
+ {
+ "dummy": [
+ I.dummy_global_info(), # dummy[0]
+ I.dummy_global_info(), # dummy[1]
+ ]
+ }
+ )
+
+ @T.prim_func
+ def tir_func(
+ x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for i, j in T.grid(T.int64(128), T.int64(128)):
+ with T.block():
+ vi, vj = T.axis.remap("SS", [i, j])
+ y[vi, vj] = x[vi, vj] + 1.0
+
+ @R.function
+ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128),
"float32"):
+ # TODO(Siyuan): Need to change to `TestModule.tir_func`
+ gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128),
dtype="float32"))
+ return gv0
+
+ x = relax.Var("x", R.Tensor((128, 128), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,)):
+ out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
+ bb.emit_func_output(out)
+ mod = bb.get()
+ mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()])
+ mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10))
+ _check(TestModule, mod)
+
+
def test_relax_tensor_op():
@R.function
def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):