This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ff5118f3981b0216c6c961185513d5625a37f82f Author: Hongyi Jin <[email protected]> AuthorDate: Sun Feb 26 11:05:47 2023 -0500 [TVMScript] Expose IRModule::attrs as I.module_attrs This is an upstreaming of the non-relax portions of https://github.com/apache/tvm/pull/14132, including a unit test specically to validate `I.module_attrs`. --- include/tvm/script/ir_builder/base.h | 2 ++ include/tvm/script/ir_builder/ir/frame.h | 3 +++ python/tvm/ir/module.py | 14 ++++++++++++-- python/tvm/script/ir_builder/base.py | 11 +++++++++++ python/tvm/script/ir_builder/ir/__init__.py | 7 ++++++- python/tvm/script/ir_builder/ir/ir.py | 14 ++++++++++++++ python/tvm/script/parser/ir/__init__.py | 4 ++-- python/tvm/script/parser/ir/parser.py | 11 +++++++++-- src/ir/module.cc | 6 ++---- src/script/ir_builder/base.cc | 6 ++++++ src/script/ir_builder/ir/frame.cc | 3 ++- src/script/ir_builder/ir/ir.cc | 12 ++++++++++++ src/script/printer/ir/ir.cc | 5 +++++ tests/python/unittest/test_tvmscript_roundtrip.py | 14 ++++++++++++++ 14 files changed, 100 insertions(+), 12 deletions(-) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 61ca3eb9f7..a00ea5768e 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -237,6 +237,8 @@ class IRBuilder : public runtime::ObjectRef { * \sa tvm::support::With */ static IRBuilder Current(); + /*! \brief See if the current thread-local scope has an IRBuilder. */ + static bool IsInScope(); /*! * \brief Give a string name to the `obj` * \tparam TObjectRef The type of the object to name. diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index dacfc361a6..ed425cf614 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -45,11 +45,14 @@ class IRModuleFrameNode : public IRBuilderFrameNode { * \note Only defined functions are in the map, while declared functions are not included. */ Map<GlobalVar, BaseFunc> functions; + /*! \brief IRModule's attributes. */ + Map<String, ObjectRef> attrs; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); + v->Visit("attrs", &attrs); } static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3daffb2640..232c70aa93 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -37,7 +37,7 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ - def __init__(self, functions=None, type_definitions=None): + def __init__(self, functions=None, type_definitions=None, attrs=None): if functions is None: functions = {} elif isinstance(functions, dict): @@ -60,7 +60,17 @@ class IRModule(Node, Scriptable): raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]") mapped_type_defs[k] = v type_definitions = mapped_type_defs - self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) + + attrs = None if not attrs else attrs + if attrs is not None: + attrs = ast.literal_eval(str(attrs)) + attrs = tvm.ir.make_node("DictAttrs", **attrs) + self.__init_handle_by_constructor__( + _ffi_api.IRModule, + functions, + type_definitions, + attrs, + ) def __setitem__(self, var, val): """Add a mapping to the module. diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index b35bbd0a7d..1d5d050444 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -138,6 +138,17 @@ class IRBuilder(_Object): """ return _ffi_api.IRBuilderCurrent() # type: ignore[attr-defined] # pylint: disable=no-member + @staticmethod + def is_in_scope() -> bool: + """See if the current thread-local scope has an IRBuilder. + + Returns + ------- + bool + Whether the current thread-local scope has an IRBuilder + """ + return _ffi_api.IRBuilderIsInScope() # type: ignore[attr-defined] # pylint: disable=no-member + def get(self) -> _Object: """Get the constructed IR.""" return _ffi_api.IRBuilderGet(self) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index 946be263a7..b796de8113 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,9 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import decl_function, def_function, ir_module +from .ir import ( + decl_function, + def_function, + ir_module, + module_attrs, +) diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 796d6f3aad..c5276f8d13 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,6 +16,10 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from typing import Dict + +from tvm.runtime import Object as tvm_Object + from tvm.ir import BaseFunc, GlobalVar from . import _ffi_api @@ -67,3 +71,13 @@ def def_function(func_name: str, func: BaseFunc) -> None: The given function implementation """ return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member + + +def module_attrs(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the ir_module frame. + Parameters + ---------- + attrs: Dict[str, Object] + The module attrs. + """ + return _ffi_api.ModuleAttrs(attrs) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/ir/__init__.py b/python/tvm/script/parser/ir/__init__.py index fedd2f0a14..adda176012 100644 --- a/python/tvm/script/parser/ir/__init__.py +++ b/python/tvm/script/parser/ir/__init__.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. """The ir module parser""" - +from ...ir_builder.ir import * # pylint: disable=redefined-builtin from . import parser as _parser from .entry import ir_module -__all__ = ["ir_module"] +__all__ = ["ir_module", "module_attrs"] diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index 13b3e29859..201c99074f 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -35,11 +35,17 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: with self.var_table.with_frame(): with I.ir_module(): + with self.with_dispatch_token("ir"): + for stmt in node.body: + if not isinstance(stmt, doc.FunctionDef): + self.visit(stmt) for stmt in node.body: if isinstance(stmt, doc.FunctionDef): self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): - self.visit_body(node.body) + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit(stmt) @dispatch.register(token="ir", type_name="Assign") @@ -57,7 +63,7 @@ def _visit_assign(_self: Parser, _node: doc.Assign) -> None: @dispatch.register(token="ir", type_name="Expr") -def _visit_expr(_self: Parser, _node: doc.Expr) -> None: +def _visit_expr(self: Parser, node: doc.Expr) -> None: """The expression visiting method for ir module. Parameters @@ -68,6 +74,7 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None: node : doc.ClassDef The doc AST expression node. """ + self.eval_expr(node.value) @dispatch.register(token="default", type_name="Assign") diff --git a/src/ir/module.cc b/src/ir/module.cc index 4d5bebf708..ba66a66894 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -382,10 +382,8 @@ IRModule IRModule::FromText(const String& text, const String& source_path) { TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") - .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, - tvm::Map<GlobalTypeVar, TypeData> types) { - return IRModule(funcs, types, {}); - }); + .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types, + tvm::DictAttrs attrs) { return IRModule(funcs, types, {}, {}, attrs); }); TVM_REGISTER_GLOBAL("ir.Module_Add") .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 8303efff4f..879db4f3d7 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -77,6 +77,11 @@ IRBuilder IRBuilder::Current() { return stack->back(); } +bool IRBuilder::IsInScope() { + std::vector<IRBuilder>* stack = ThreadLocalBuilderStack(); + return !stack->empty(); +} + namespace details { Namer::FType& Namer::vtable() { @@ -106,6 +111,7 @@ TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilder").set_body_typed([]() { return TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderEnter").set_body_method(&IRBuilder::EnterWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderExit").set_body_method(&IRBuilder::ExitWithScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderCurrent").set_body_typed(IRBuilder::Current); +TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderIsInScope").set_body_typed(IRBuilder::IsInScope); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderGet") .set_body_method<IRBuilder>(&IRBuilderNode::Get<ObjectRef>); TVM_REGISTER_GLOBAL("script.ir_builder.IRBuilderName").set_body_typed(IRBuilder::Name<ObjectRef>); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index addf129284..92470ec653 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -38,7 +38,8 @@ void IRModuleFrameNode::ExitWithScope() { } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs); + builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs); } TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 5764e90c8d..0c34f85246 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -60,9 +60,21 @@ void DefFunction(const String& func_name, const BaseFunc& func) { } } +void ModuleAttrs(Map<String, ObjectRef> attrs) { + if (IRBuilder::IsInScope()) { + // TODO(hongyi): add comments to explain why we need to check if the module frame is in scope + IRModuleFrame frame = FindModuleFrame("I.ModuleAttr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate module attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs); } // namespace ir } // namespace ir_builder diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index 065cfe5168..1c751d40f2 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -64,6 +64,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) std::sort(functions.begin(), functions.end()); With<IRFrame> f(d); (*f)->AddDispatchToken(d, "ir"); + if (mod->attrs.defined() && !mod->attrs->dict.empty()) { + (*f)->stmts.push_back( + ExprStmtDoc(IR(d, "module_attrs") // + ->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))}))); + } for (const auto& entry : functions) { const GlobalVar& gv = entry.gv; const BaseFunc& func = entry.func; diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index bbc6dd45a8..52d99550be 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3725,6 +3725,19 @@ def tvm_struct_set_generated_in_cpp(): return tvm.tir.transform.LowerTVMBuiltin()(Module) +def ir_module_with_attrs(): + @I.ir_module + class Module: + I.module_attrs({"attr": 10}) + + @T.prim_func + def tir_func(A: T.Buffer(16, "int32"), B: T.Buffer(16, "int32")): + for i in range(16): + B[i] = A[i] + + return Module + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3791,6 +3804,7 @@ ir_generator = tvm.testing.parameter( if_then_else_var, tvm_shfl_builtins, tvm_struct_set_generated_in_cpp, + ir_module_with_attrs, )
