This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e62169cc8a [Unity] Add Global info (#14132)
e62169cc8a is described below

commit e62169cc8afb5b9062a40072d6d44fc817408d77
Author: Hongyi Jin <[email protected]>
AuthorDate: Sun Feb 26 11:05:47 2023 -0500

    [Unity] Add Global info (#14132)
---
 include/tvm/ir/global_info.h                       | 80 ++++++++++++++++++++++
 include/tvm/ir/module.h                            | 16 ++++-
 include/tvm/script/ir_builder/base.h               |  2 +
 include/tvm/script/ir_builder/ir/frame.h           |  7 ++
 python/tvm/ir/__init__.py                          |  1 +
 .../parser/ir/__init__.py => ir/global_info.py}    | 28 ++++++--
 python/tvm/ir/module.py                            | 30 +++++++-
 python/tvm/script/ir_builder/base.py               | 11 +++
 python/tvm/script/ir_builder/ir/__init__.py        |  9 ++-
 python/tvm/script/ir_builder/ir/ir.py              | 39 ++++++++++-
 python/tvm/script/parser/ir/__init__.py            |  4 +-
 python/tvm/script/parser/ir/parser.py              | 11 ++-
 src/ir/global_info.cc                              | 32 +++++++++
 src/ir/module.cc                                   | 25 +++++--
 src/script/ir_builder/base.cc                      |  6 ++
 src/script/ir_builder/ir/frame.cc                  |  3 +-
 src/script/ir_builder/ir/ir.cc                     | 24 +++++++
 src/script/printer/ir/ir.cc                        | 15 ++++
 tests/python/relax/test_tvmscript_parser.py        | 42 ++++++++++++
 19 files changed, 365 insertions(+), 20 deletions(-)

diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h
new file mode 100644
index 0000000000..65b5e0a3d2
--- /dev/null
+++ b/include/tvm/ir/global_info.h
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ir/global_info.h
+ * \brief GlobalInfo are globally static object that are referred by the IR itself.
+ */
+
+#ifndef TVM_IR_GLOBAL_INFO_H_
+#define TVM_IR_GLOBAL_INFO_H_
+
+#include "tvm/ir/expr.h"
+
+namespace tvm {
+
+/*!
+ * \brief GlobalInfo are globally static object that are referred by the IR itself.
+ *        Base node for all global info that can appear in the IR
+ */
+class GlobalInfoNode : public Object {
+ public:
+  static constexpr const char* _type_key = "GlobalInfoNode";
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object);
+};
+
+/*!
+ * \brief Managed reference to GlobalInfoNode.
+ * \sa GlobalInfoNode
+ */
+class GlobalInfo : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode);
+};
+
+/*!
+ * \brief A dummy global info sub-class for testing purpose.
+ */
+class DummyGlobalInfoNode : public GlobalInfoNode {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+  static constexpr const char* _type_key = "DummyGlobalInfo";
+
+  TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const {
+    return true;
+  }
+
+  TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {}
+  TVM_DECLARE_FINAL_OBJECT_INFO(DummyGlobalInfoNode, GlobalInfoNode);
+};
+
+/*!
+ * \brief Managed reference to DummyGlobalInfoNode.
+ * \sa DummyGlobalInfoNode
+ */
+class DummyGlobalInfo : public GlobalInfo {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(DummyGlobalInfo, GlobalInfo, DummyGlobalInfoNode);
+};
+
+}  // namespace tvm
+
+#endif  // TVM_IR_GLOBAL_INFO_H_
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 538ff64ca3..4c2d5cd812 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -27,6 +27,7 @@
 #include <tvm/ir/adt.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/function.h>
+#include <tvm/ir/global_info.h>
 #include <tvm/ir/source_map.h>
 #include <tvm/ir/type.h>
 #include <tvm/runtime/container/array.h>
@@ -63,6 +64,8 @@ class IRModuleNode : public Object {
   SourceMap source_map;
   /* \brief Additional attributes storing meta-data about the module. */
   DictAttrs attrs;
+  /*! \brief Globally static object that are referred by the IR itself */
+  Map<String, Array<GlobalInfo>> global_infos;
   /*!
    * \brief A map from string names to global variables that
    * ensures global uniqueness.
@@ -151,6 +154,7 @@ class IRModuleNode : public Object {
     v->Visit("global_type_var_map_", &global_type_var_map_);
     v->Visit("source_map", &source_map);
     v->Visit("attrs", &attrs);
+    v->Visit("global_infos", &global_infos);
   }
 
   TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
@@ -210,6 +214,13 @@ class IRModuleNode : public Object {
    */
   TVM_DLL void UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type);
 
+  /*!
+   * \brief Update an array of global infos in the global environment.
+   * \param name The name of the global info.
+   * \param info The new array of global infos.
+   */
+  TVM_DLL void UpdateGlobalInfo(const String& name, const Array<GlobalInfo>& info);
+
   /*!
    * \brief Remove a function from the global environment.
    * \param var The name of the global function to update.
@@ -359,12 +370,13 @@ class IRModule : public ObjectRef {
    * \param type_definitions Type definitions in the module.
    * \param import_set Set of imported files in the module.
    * \param map The module source map.
-   * \param attrs The module attributes.
+   * \param attrs The module meta-data attributes.
+   * \param global_infos Global infos in the module.
    */
   TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
                             Map<GlobalTypeVar, TypeData> type_definitions = {},
                             std::unordered_set<String> import_set = {}, SourceMap map = {},
-                            DictAttrs attrs = {});
+                            DictAttrs attrs = {}, Map<String, Array<GlobalInfo>> global_infos = {});
 
   /*! \brief default constructor */
   IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h
index 61ca3eb9f7..a00ea5768e 100644
--- a/include/tvm/script/ir_builder/base.h
+++ b/include/tvm/script/ir_builder/base.h
@@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef {
    * \sa tvm::support::With
    */
   static IRBuilder Current();
+  /*! \brief See if the current thread-local scope has an IRBuilder. */
+  static bool IsInScope();
   /*!
    * \brief Give a string name to the `obj`
    * \tparam TObjectRef The type of the object to name.
diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h
index dacfc361a6..6e758372b9 100644
--- a/include/tvm/script/ir_builder/ir/frame.h
+++ b/include/tvm/script/ir_builder/ir/frame.h
@@ -21,6 +21,7 @@
 
 #include <tvm/ir/expr.h>
 #include <tvm/ir/function.h>
+#include <tvm/ir/module.h>
 #include <tvm/node/node.h>
 #include <tvm/script/ir_builder/base.h>
 
@@ -45,11 +46,17 @@ class IRModuleFrameNode : public IRBuilderFrameNode {
    * \note Only defined functions are in the map, while declared functions are not included.
    */
   Map<GlobalVar, BaseFunc> functions;
+  /*! \brief IRModule's attributes. */
+  Map<String, ObjectRef> attrs;
+  /*! \brief IRModule's global_infos */
+  Map<String, Array<GlobalInfo>> global_infos;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     IRBuilderFrameNode::VisitAttrs(v);
     v->Visit("global_vars", &global_var_map);
     v->Visit("functions", &functions);
+    v->Visit("attrs", &attrs);
+    v->Visit("global_infos", &global_infos);
   }
 
   static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame";
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 4f63cbecd9..01fea2abbd 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -34,6 +34,7 @@ from .base import (
 from .container import Array, Map
 from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
 from .function import BaseFunc, CallingConv
+from .global_info import GlobalInfo, DummyGlobalInfo
 from .memory_pools import (
     ConstantMemoryPools,
     ConstantPoolInfo,
diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/ir/global_info.py
similarity index 53%
copy from python/tvm/script/parser/ir/__init__.py
copy to python/tvm/ir/global_info.py
index fedd2f0a14..17011e76a6 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/ir/global_info.py
@@ -14,9 +14,29 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The ir module parser"""
+"""Global Info."""
+import tvm
+from tvm.runtime.object import Object
+from . import _ffi_api
 
-from . import parser as _parser
-from .entry import ir_module
 
-__all__ = ["ir_module"]
+class GlobalInfo(Object):
+    """Base node for all global info that can appear in the IR"""
+
+    def __eq__(self, other):
+        """Compare two struct info for structural equivalence."""
+        return tvm.ir.structural_equal(self, other)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def same_as(self, other):
+        """Overload with structural equality."""
+        return super().__eq__(other)
+
+
+class DummyGlobalInfo(GlobalInfo):
+    def __init__(self) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.DummyGlobalInfo,
+        )
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 6a151d5a89..707d46d0cd 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -42,7 +42,7 @@ class IRModule(Node, Scriptable):
         Map of global var to BaseFunc
     """
 
-    def __init__(self, functions=None, type_definitions=None):
+    def __init__(self, functions=None, type_definitions=None, attrs=None, global_infos=None):
         if functions is None:
             functions = {}
         elif isinstance(functions, dict):
@@ -65,7 +65,20 @@ class IRModule(Node, Scriptable):
                     raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
                 mapped_type_defs[k] = v
             type_definitions = mapped_type_defs
-        self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
+
+        attrs = None if not attrs else attrs
+        if attrs is not None:
+            attrs = ast.literal_eval(str(attrs))
+            attrs = tvm.ir.make_node("DictAttrs", **attrs)
+        if global_infos is None:
+            global_infos = {}
+        self.__init_handle_by_constructor__(
+            _ffi_api.IRModule,
+            functions,
+            type_definitions,
+            attrs,
+            global_infos,
+        )
 
     def __setitem__(self, var, val):
         """Add a mapping to the module.
@@ -140,6 +153,19 @@ class IRModule(Node, Scriptable):
         """
         return _ffi_api.Module_UpdateFunction(self, var, func)
 
+    def update_global_info(self, name, global_info):
+        """Update global info in the module
+
+        Parameters
+        ----------
+        name: str
+            The name for the global info.
+
+        global_info: List[GlobalInfo]
+            The global info to be updated.
+        """
+        return _ffi_api.Module_UpdateGlobalInfo(self, name, global_info)
+
     def get_global_var(self, name):
         """Get a global variable in the function by name.
 
diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py
index b35bbd0a7d..1d5d050444 100644
--- a/python/tvm/script/ir_builder/base.py
+++ b/python/tvm/script/ir_builder/base.py
@@ -138,6 +138,17 @@ class IRBuilder(_Object):
         """
         return _ffi_api.IRBuilderCurrent()  # type: ignore[attr-defined] # pylint: disable=no-member
 
+    @staticmethod
+    def is_in_scope() -> bool:
+        """See if the current thread-local scope has an IRBuilder.
+
+        Returns
+        -------
+        bool
+            Whether the current thread-local scope has an IRBuilder
+        """
+        return _ffi_api.IRBuilderIsInScope()  # type: ignore[attr-defined] # pylint: disable=no-member
+
     def get(self) -> _Object:
         """Get the constructed IR."""
         return _ffi_api.IRBuilderGet(self)  # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py
index 946be263a7..68eda2cfee 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -16,4 +16,11 @@
 # under the License.
 """Package tvm.script.ir_builder.ir"""
 from .frame import IRModuleFrame
-from .ir import decl_function, def_function, ir_module
+from .ir import (
+    decl_function,
+    def_function,
+    ir_module,
+    module_attrs,
+    module_global_infos,
+    dummy_global_info,
+)
diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py
index 796d6f3aad..53c48b4cc5 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -16,7 +16,11 @@
 # under the License.
 """Package tvm.script.ir_builder.ir.ir"""
 
-from tvm.ir import BaseFunc, GlobalVar
+from typing import Dict, List
+
+from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo
+from tvm.runtime import Object as tvm_Object
+
 
 from . import _ffi_api
 from .frame import IRModuleFrame
@@ -67,3 +71,36 @@ def def_function(func_name: str, func: BaseFunc) -> None:
         The given function implementation
     """
     return _ffi_api.DefFunction(func_name, func)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def module_attrs(attrs: Dict[str, tvm_Object]) -> None:
+    """Specify the attrs of the ir_module frame.
+    Parameters
+    ----------
+    attrs: Dict[str, Object]
+        The module attrs.
+    """
+    return _ffi_api.ModuleAttrs(attrs)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def module_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> None:
+    """Specify the global infos of the ir_module frame.
+    Parameters
+    ----------
+    global_infos: Dict[str, List[GlobalInfo]]
+        The module global infos.
+    """
+    return _ffi_api.ModuleGlobalInfos(global_infos)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+############################### GlobalInfo ###############################
+
+
+def dummy_global_info() -> DummyGlobalInfo:
+    """Create a dummy global info expression.
+    Returns
+    -------
+    res : DummyGlobalInfo
+        The result dummy global info.
+    """
+    return DummyGlobalInfo()  # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py
index fedd2f0a14..f8c9d4f0af 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -15,8 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """The ir module parser"""
-
+from ...ir_builder.ir import *  # pylint: disable=redefined-builtin
 from . import parser as _parser
 from .entry import ir_module
 
-__all__ = ["ir_module"]
+__all__ = ["ir_module", "module_attrs", "module_global_infos", "dummy_global_info"]
diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py
index 13b3e29859..201c99074f 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None:
 
     with self.var_table.with_frame():
         with I.ir_module():
+            with self.with_dispatch_token("ir"):
+                for stmt in node.body:
+                    if not isinstance(stmt, doc.FunctionDef):
+                        self.visit(stmt)
             for stmt in node.body:
                 if isinstance(stmt, doc.FunctionDef):
                     self.visit_tvm_declare_function(stmt)
             with self.with_dispatch_token("ir"):
-                self.visit_body(node.body)
+                for stmt in node.body:
+                    if isinstance(stmt, doc.FunctionDef):
+                        self.visit(stmt)
 
 
 @dispatch.register(token="ir", type_name="Assign")
@@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None:
 
 
 @dispatch.register(token="ir", type_name="Expr")
-def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
+def _visit_expr(self: Parser, node: doc.Expr) -> None:
     """The expression visiting method for ir module.
 
     Parameters
@@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
     node : doc.ClassDef
         The doc AST expression node.
     """
+    self.eval_expr(node.value)
 
 
 @dispatch.register(token="default", type_name="Assign")
diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc
new file mode 100644
index 0000000000..48f56d60d6
--- /dev/null
+++ b/src/ir/global_info.cc
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/global_info.cc
+ * \brief Module global info.
+ */
+
+#include <tvm/ir/global_info.h>
+namespace tvm {
+TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode);
+TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
+  auto n = DummyGlobalInfo(make_object<DummyGlobalInfoNode>());
+  return n;
+});
+}  // namespace tvm
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 8f23f19d35..da1f3942c7 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -34,7 +34,8 @@ namespace tvm {
 
 IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
                    tvm::Map<GlobalTypeVar, TypeData> type_definitions,
-                   std::unordered_set<String> import_set, SourceMap source_map, DictAttrs attrs) {
+                   std::unordered_set<String> import_set, SourceMap source_map, DictAttrs attrs,
+                   Map<String, Array<GlobalInfo>> global_infos) {
   auto n = make_object<IRModuleNode>();
   n->functions = std::move(functions);
   n->type_definitions = std::move(type_definitions);
@@ -44,6 +45,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
   n->import_set_ = std::move(import_set);
   n->source_map = source_map;
   n->attrs = std::move(attrs);
+  n->global_infos = std::move(global_infos);
 
   for (const auto& kv : n->functions) {
     // set global var map
@@ -64,7 +66,10 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
 
 bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
   if (!equal(this->attrs, other->attrs)) return false;
-
+  if (this->global_infos.size() != other->global_infos.size()) return false;
+  for (const auto& kv : this->global_infos) {
+    if (!equal(kv.second, other->global_infos[kv.first])) return false;
+  }
   if (functions.size() != other->functions.size()) return false;
   // Update GlobalVar remap
   for (const auto& gv : this->GetGlobalVars()) {
@@ -116,6 +121,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
   }
   reduce_temp();
   hash_reduce(this->attrs);
+  hash_reduce(this->global_infos);
 }
 
 bool IRModuleNode::ContainGlobalVar(const String& name) const {
@@ -239,6 +245,10 @@ void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type)
   this->AddTypeDef(var, type, true);
 }
 
+void IRModuleNode::UpdateGlobalInfo(const String& name, const Array<GlobalInfo>& info) {
+  this->global_infos.Set(name, info);
+}
+
 void IRModuleNode::Remove(const GlobalVar& var) {
   auto functions_node = this->functions.CopyOnWrite();
   functions_node->erase(var);
@@ -359,9 +369,9 @@ IRModule IRModule::FromText(const String& text, const String& source_path) {
 TVM_REGISTER_NODE_TYPE(IRModuleNode);
 
 TVM_REGISTER_GLOBAL("ir.IRModule")
-    .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
-                       tvm::Map<GlobalTypeVar, TypeData> types) {
-      return IRModule(funcs, types, {});
+    .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types,
+                       tvm::DictAttrs attrs, Map<String, Array<GlobalInfo>> global_infos) {
+      return IRModule(funcs, types, {}, {}, attrs, global_infos);
     });
 
 TVM_REGISTER_GLOBAL("ir.Module_Add")
@@ -423,6 +433,11 @@ TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule
 TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
     .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); });
 
+TVM_REGISTER_GLOBAL("ir.Module_UpdateGlobalInfo")
+    .set_body_typed([](IRModule mod, String name, Array<GlobalInfo> global_info) {
+      mod->UpdateGlobalInfo(name, global_info);
+    });
+
 TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
   mod->Import(path);
 });
diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc
index 8303efff4f..879db4f3d7 100644
--- a/src/script/ir_builder/base.cc
+++ b/src/script/ir_builder/base.cc
@@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() {
   return stack->back();
 }
 
+bool IRBuilder::IsInScope() {
+  std::vector<IRBuilder>* stack = ThreadLocalBuilderStack();
+  return !stack->empty();
+}
+
 namespace details {
 
 Namer::FType& Namer::vtable() {
@@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current);
+TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet")
     .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>);
 TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>);
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc
index addf129284..3d917cee88 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() {
   }
   IRBuilder builder = IRBuilder::Current();
   ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
-  builder->result = tvm::IRModule(func_map);
+  auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+  builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs, global_infos);
 }
 
 TVM_REGISTER_NODE_TYPE(IRModuleFrameNode);
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index da2330b577..148e90b28c 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -69,9 +69,33 @@ void DefFunction(const String& func_name, const BaseFunc& func) {
   }
 }
 
+void ModuleAttrs(Map<String, ObjectRef> attrs) {
+  if (IRBuilder::IsInScope()) {
+    // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope
+    IRModuleFrame frame = FindModuleFrame("I.ModuleAttr");
+    if (!frame->attrs.empty()) {
+      LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs;
+    }
+    frame->attrs = attrs;
+  }
+}
+
+void ModuleGlobalInfos(Map<String, Array<GlobalInfo>> global_infos) {
+  if (IRBuilder::IsInScope()) {
+    IRModuleFrame frame = FindModuleFrame("I.ModuleGlobalInfos");
+    if (!frame->global_infos.empty()) {
+      LOG(FATAL) << "ValueError: Duplicate module global_infos, previous one is:\n"
+                 << frame->global_infos;
+    }
+    frame->global_infos = global_infos;
+  }
+}
+
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos);
 
 }  // namespace ir
 }  // namespace ir_builder
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index e6f4a1eaee..62919246b0 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -64,6 +64,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       std::sort(functions.begin(), functions.end());
       With<IRFrame> f(d);
       (*f)->AddDispatchToken(d, "ir");
+      if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
+        (*f)->stmts.push_back(
+            ExprStmtDoc(IR(d, "module_attrs")  //
+                            ->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
+      }
+      if (mod->global_infos.defined() && !mod->global_infos.empty()) {
+        (*f)->stmts.push_back(ExprStmtDoc(
+            IR(d, "module_global_infos")  //
+                ->Call({d->AsDoc<ExprDoc>(mod->global_infos, p->Attr("global_infos"))})));
+      }
       for (const auto& entry : functions) {
         const GlobalVar& gv = entry.gv;
         const BaseFunc& func = entry.func;
@@ -92,6 +102,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))});
     });
 
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<DummyGlobalInfo>("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc {
+      return IR(d, "dummy_global_info")->Call({});
+    });
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
       return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 7724c8e761..9636a98b41 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -22,6 +22,7 @@ import tvm
 import tvm.script
 import tvm.testing
 from tvm import IRModule, relax, tir, topi
+from tvm.ir import DummyGlobalInfo
 from tvm.script.parser import ir as I
 from tvm.script.parser import relax as R
 from tvm.script.parser import tir as T
@@ -183,6 +184,47 @@ def test_simple_module():
     _check(TestModule, bb.get())
 
 
+def test_module_with_attr_and_global_info():
+    @I.ir_module
+    class TestModule:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "dummy": [
+                    I.dummy_global_info(),  # dummy[0]
+                    I.dummy_global_info(),  # dummy[1]
+                ]
+            }
+        )
+
+        @T.prim_func
+        def tir_func(
+            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j in T.grid(T.int64(128), T.int64(128)):
+                with T.block():
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    y[vi, vj] = x[vi, vj] + 1.0
+
+        @R.function
+        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
+            # TODO(Siyuan): Need to change to `TestModule.tir_func`
+            gv0 = R.call_tir(tir_func, x, R.Tensor((128, 128), dtype="float32"))
+            return gv0
+
+    x = relax.Var("x", R.Tensor((128, 128), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x,)):
+        out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
+        bb.emit_func_output(out)
+    mod = bb.get()
+    mod.update_global_info("dummy", [DummyGlobalInfo(), DummyGlobalInfo()])
+    mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10))
+    _check(TestModule, mod)
+
+
 def test_relax_tensor_op():
     @R.function
     def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):

Reply via email to