This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 11c13ace0b5cef71f50193248ecaac7e845ee25e Author: Siyuan Feng <[email protected]> AuthorDate: Wed Feb 8 22:31:47 2023 +0800 [TVMScript] IRModule TVMScript Parser. This PR adds the TVMScript parser/ir_builder support based on the blockbuilder. This commit contains the non-relax portions from https://github.com/apache/tvm/pull/13932. Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Tianqi Chen <[email protected]> Co-authored-by: Yuchen Jin <[email protected]> Co-authored-by: Steven S. Lyubomirsky <[email protected]> Co-authored-by: Yong Wu <[email protected]> --- include/tvm/script/ir_builder/ir/frame.h | 11 ++++-- include/tvm/script/ir_builder/ir/ir.h | 17 +++++++++ python/tvm/script/ir_builder/base.py | 6 ++-- python/tvm/script/ir_builder/ir/__init__.py | 2 +- python/tvm/script/ir_builder/ir/ir.py | 45 +++++++++++++++++++++++ python/tvm/script/parser/core/diagnostics.py | 2 +- python/tvm/script/parser/core/evaluator.py | 2 +- python/tvm/script/parser/core/parser.py | 50 ++++++++++++++++++-------- python/tvm/script/parser/ir/parser.py | 4 +++ python/tvm/script/parser/tir/entry.py | 4 +-- python/tvm/script/parser/tir/parser.py | 26 ++++++++++++++ src/script/ir_builder/ir/frame.cc | 12 ++++--- src/script/ir_builder/ir/ir.cc | 32 ++++++++++++++++- src/script/ir_builder/ir/{frame.cc => utils.h} | 30 +++++++++------- src/script/ir_builder/tir/frame.cc | 15 ++++++-- src/script/ir_builder/tir/utils.h | 2 +- 16 files changed, 213 insertions(+), 47 deletions(-) diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccff..dacfc361a6 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array<GlobalVar> global_vars; - Array<BaseFunc> functions; + /*! 
\brief A map from string names to global variables that ensures global uniqueness. */ + Map<String, GlobalVar> global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Defined functions are always in the map; a declared-but-undefined function appears only if a signature was passed to `DeclFunction`. + */ + Map<GlobalVar, BaseFunc> functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c..49bdcf60e6 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,23 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a function without providing its implementation. + * \note Typically used to enable cross-function calls; the implementation is supplied later via `DefFunction`. + * \param func_name The unique name of the function. + * \param func_signature A function without a body, used to specify the function signature + * (i.e. func params and func return type/shape). + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature); + +/*! + * \brief Define a function that was previously declared. + * \param func_name The unique name of the function.
+ * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c..b35bbd0a7d 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ class IRBuilderFrame(_Object): _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737..946be263a7 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463c..796d6f3aad 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,54 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . 
import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start an ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str, func_signature: BaseFunc) -> GlobalVar: + """Declare a function without providing its implementation. + Parameters + ---------- + func_name : str + The unique name of the function. + + func_signature : Optional[BaseFunc] + A function without a body, used to specify the function signature + (i.e. func params and func return type/shape). + + Note + ---- + Typically used to enable cross-function calls; the implementation is supplied later via `def_function`. + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + + return _ffi_api.DeclFunction( # type: ignore[attr-defined] # pylint: disable=no-member + func_name, func_signature + ) + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define a function that was previously declared. + Parameters + ---------- + func_name : str + The unique name of the function. + func : BaseFunc + The function implementation. + """ + return _ffi_api.DefFunction(func_name, func) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae50347..2767a97f60 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ class Diagnostics: level : diagnostics.DiagnosticLevel The diagnostic level.
""" - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index 3a72a3c331..075aedd891 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -203,7 +203,7 @@ class ExprEvaluator: else: value = self._eval_expr(node.__class__(**fields)) except Exception as e: # pylint: disable=broad-except,invalid-name - self.parser.report_error(node, str(e)) + self.parser.report_error(node, e) return self._add_intermediate_result(value) def _eval_lambda(self, node: doc.Lambda) -> Any: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index fdccabcd23..837b7cce5d 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -60,6 +60,10 @@ def _deferred(exit_f: Callable[[], None]): return context() +def _do_nothing(*args, **kwargs): # pylint: disable=unused-argument + pass + + class VarTableFrame: """The variable table frame. A frame of variable table stores the variables created in one block or scope. 
@@ -260,6 +264,17 @@ class Parser(doc.NodeVisitor): node = self.diag.source.as_ast() self.visit(node) + def get_dispatch_token(self, node: doc.FunctionDef) -> str: + if not isinstance(node, doc.FunctionDef): + self.report_error(node, "Only can get dispatch token for function.") + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + return decorator.dispatch_token + def with_dispatch_token(self, token: str): """Add a new dispatching token as with statement. @@ -389,6 +404,8 @@ class Parser(doc.NodeVisitor): # Only take the last line of the error message if isinstance(err, TVMError): msg = list(filter(None, str(err).split("\n")))[-1] + elif isinstance(err, KeyError): + msg = "KeyError: " + str(err) else: msg = str(err) self.diag.error(node, msg) @@ -458,30 +475,33 @@ class Parser(doc.NodeVisitor): """ return _dispatch(self, "tvm_annotation")(self, node) - def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name - """The general function definition visiting method. + def visit_FunctionDef(self, node: doc.FunctionDef) -> None: # pylint: disable=invalid-name + """The general function definition visit method. Parameters ---------- node : doc.FunctionDef - The doc AST function definition node. - - Returns - ------- - res : Any - The visiting result. + The doc FunctionDef node. 
""" - if not node.decorator_list: - self.report_error(node, "Function must be decorated") - # TODO: only the last decorator is parsed - decorator = self.eval_expr(node.decorator_list[-1]) - if not hasattr(decorator, "dispatch_token"): - self.report_error(node, "The parser does not understand the decorator") - token = decorator.dispatch_token + token = self.get_dispatch_token(node) + current_token = self.dispatch_tokens[-1] func = dispatch.get(token=token, type_name="FunctionDef", default=None) if func is None: self.report_error(node, "The parser does not understand the decorator") + pre_func = dispatch.get( + token=current_token, type_name="pre_token_switch", default=_do_nothing + ) + post_func = dispatch.get( + token=current_token, type_name="post_token_switch", default=_do_nothing + ) + pre_func(self, node) _dispatch_wrapper(func)(self, node) + post_func(self, node) + + def visit_tvm_declare_function(self, node: doc.FunctionDef) -> None: + token = self.get_dispatch_token(node) + with self.with_dispatch_token(token): + _dispatch(self, "tvm_declare_function")(self, node) def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name """The general class definition visiting method. diff --git a/python/tvm/script/parser/ir/parser.py b/python/tvm/script/parser/ir/parser.py index e0268412d2..13b3e29859 100644 --- a/python/tvm/script/parser/ir/parser.py +++ b/python/tvm/script/parser/ir/parser.py @@ -32,8 +32,12 @@ def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: node : doc.ClassDef The doc AST class definition node. 
""" + with self.var_table.with_frame(): with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + self.visit_tvm_declare_function(stmt) with self.with_dispatch_token("ir"): self.visit_body(node.body) diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tir/entry.py index 411a7f8f3c..649f817411 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tir/entry.py @@ -83,7 +83,7 @@ class BufferProxy: return self(keys) if len(keys) >= 2 and not isinstance(keys[1], str): return self(keys) - return self(*keys) # pylint: disable=no-member # type: ignore + return self(*keys) # type: ignore[attr-defined] # pylint: disable=no-member class PtrProxy: @@ -93,7 +93,7 @@ class PtrProxy: def __call__(self, dtype, storage_scope="global"): if callable(dtype): dtype = dtype().dtype - return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + return ptr(dtype, storage_scope) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.Ptr[...]", "T.handle(...)") def __getitem__(self, keys): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a3..63171f6722 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -24,6 +24,7 @@ import tvm from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var +from ...ir_builder import ir as I from ...ir_builder import tir as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame @@ -473,3 +474,28 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. 
""" self.report_error(node, "Return is not allowed.") + + +@dispatch.register(token="tir", type_name="tvm_declare_function") +def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> None: + """The function declaration step for tir. + + Parameters + ---------- + self : Parser + The visiting parser. + + node : doc.FunctionDef + The doc AST function definition node. + """ + + ret_type = None + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + # Only ret_type is needed for func_signature. + func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + global_var = I.decl_function(node.name, func_signature) + self.var_table.add(node.name, global_var) diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922d..addf129284 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map<GlobalVar, BaseFunc> func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule.
Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f..5764e90c8d 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include <tvm/runtime/registry.h> #include <tvm/script/ir_builder/ir/ir.h> +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,40 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name, const BaseFunc& func_signature) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->global_var_map.Set(func_name, gv); + if (func_signature.defined()) { + frame->functions.Set(gv, func_signature); + } + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + 
frame->functions.Set(gv, func); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type_; + } +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/utils.h similarity index 59% copy from src/script/ir_builder/ir/frame.cc copy to src/script/ir_builder/ir/utils.h index a81c56922d..58d5e53f70 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/utils.h @@ -16,8 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -#include <tvm/ir/module.h> -#include <tvm/runtime/registry.h> +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + #include <tvm/script/ir_builder/ir/frame.h> namespace tvm { @@ -25,21 +26,24 @@ namespace script { namespace ir_builder { namespace ir { -void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); - Map<GlobalVar, BaseFunc> func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); - } +inline IRModuleFrame FindModuleFrame(const String& method) { IRBuilder builder = IRBuilder::Current(); - ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - builder->result = tvm::IRModule(func_map); + if (Optional<IRModuleFrame> frame = builder->FindFrame<IRModuleFrame>()) { + const Optional<IRModuleFrame>& last_module_frame = builder->GetLastFrame<IRModuleFrame>(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. 
Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; } -TVM_REGISTER_NODE_TYPE(IRModuleFrameNode); - } // namespace ir } // namespace ir_builder } // namespace script } // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40..dd8d3c2ed3 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include <tvm/script/ir_builder/ir/ir.h> #include <tvm/script/ir_builder/tir/frame.h> #include <tvm/tir/function.h> @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional<ir::IRModuleFrame> opt_frame = builder->FindFrame<ir::IRModuleFrame>()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // First time visiting the function: declare it. + ir::DeclFunction(func_name, func); + } + // Define the function. + // Note: `DefFunction` does not reject redefinition (it overwrites the entry); duplicate names are rejected by `DeclFunction`.
+ ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tir/utils.h index 7ccc132fa1..f3b547532c 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tir/utils.h @@ -87,7 +87,7 @@ inline PrimFuncFrame FindPrimFuncFrame(const String& method) { * \return The top frame of BlockFrame. */ inline BlockFrame FindBlockFrame(const String& method) { - if (Optional<BlockFrame> frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) { + if (Optional<BlockFrame> frame = IRBuilder::Current()->FindFrame<BlockFrame>()) { return frame.value(); } else if (Optional<BlockFrame> frame = IRBuilder::Current()->FindFrame<BlockFrame>()) { LOG(FATAL) << "ValueError: " << method << " must be called at the top of a T.block(). "
