This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new c662638 Rename tvm.hybrid.script to tvm.script. (#6522)
c662638 is described below
commit c6626387720dbbc5f03b3effe8d7315be2819135
Author: Tristan Konolige <[email protected]>
AuthorDate: Sat Sep 26 15:32:08 2020 -0700
Rename tvm.hybrid.script to tvm.script. (#6522)
---
docs/dev/hybrid_script.rst | 14 +-
python/tvm/__init__.py | 3 -
python/tvm/{hybrid => script}/__init__.py | 4 +-
python/tvm/{hybrid => script}/_ffi_api.py | 4 +-
python/tvm/{hybrid => script}/intrin.py | 2 +-
python/tvm/{hybrid => script}/meta_unparser.py | 0
python/tvm/{hybrid => script}/parser.py | 48 +++----
python/tvm/{hybrid => script}/registry.py | 16 +--
python/tvm/{hybrid => script}/scope_emitter.py | 4 +-
python/tvm/{hybrid => script}/scope_handler.py | 8 +-
python/tvm/{hybrid => script}/special_stmt.py | 2 +-
python/tvm/{hybrid => script}/ty.py | 12 +-
python/tvm/{hybrid => script}/utils.py | 33 +++--
src/printer/tir_text_printer.cc | 42 +++---
...{tir_hybrid_printer.cc => tvmscript_printer.cc} | 145 +++++++++++----------
src/tir/ir/stmt.cc | 2 +-
...or_report.py => test_tvmscript_error_report.py} | 26 ++--
...id_roundtrip.py => test_tvmscript_roundtrip.py} | 26 ++--
18 files changed, 200 insertions(+), 191 deletions(-)
diff --git a/docs/dev/hybrid_script.rst b/docs/dev/hybrid_script.rst
index 939cf05..33a65f2 100644
--- a/docs/dev/hybrid_script.rst
+++ b/docs/dev/hybrid_script.rst
@@ -31,7 +31,7 @@ Features
Software Emulation
~~~~~~~~~~~~~~~~~~
-In software emulation, the most interesting thing is the decorator
``tvm.hybrid.script``.
+In software emulation, the most interesting thing is the decorator
``tvm.te.hybrid.script``.
This decorator helps 2 things:
1. Importing runtime variables
@@ -40,7 +40,7 @@ This decorator helps 2 things:
Correct me if I am wrong: I believe that how 1. is implemented is dangerous,
but I have no
choice. What I did is to add those names into python dict ``func.__global__``
and after
-the call to ``func`` is done, those names will be cleaned up.
+the call to ``func`` is done, those names will be cleaned up.
Overload is simple: the decorator checks the arguments' types and determines
which function
should be actually called.
@@ -49,16 +49,16 @@ should be actually called.
Backend Compilation
~~~~~~~~~~~~~~~~~~~
-Compilation is a large module, you can see ``python/tvm/hybrid/var_decl.py``
and
-``python/tvm/hybrid/parser.py`` for more details. The first stage determines
the
-usage, or more accurately the declaration of each variable and the second
stage does
-the actual IR generation.
+Compilation is a large module, you can see ``python/tvm/te/hybrid/`` for more
+details. The first stage determines the usage, or more accurately the
+declaration of each variable and the second stage does the actual IR
+generation.
Attributes
~~~~~~~~~~
So far, ONLY tensors' `shape` attribute is supported. You can see
``visit_Subscript``
-in ``python/tvm/hybrid/parser.py`` for more details. This is a hacky solution,
I just
+in ``python/tvm/te/hybrid/parser.py`` for more details. This is a hacky
solution, I just
check the attributes when subscript.
Loops
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index d3473c6..c3c5e3d 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -57,9 +57,6 @@ from .driver import build, lower
# tvm.parser
from . import parser
-# tvm tir hybrid script
-from . import hybrid
-
# others
from . import arith
diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/script/__init__.py
similarity index 87%
rename from python/tvm/hybrid/__init__.py
rename to python/tvm/script/__init__.py
index 7c3ef75..4b9f073 100644
--- a/python/tvm/hybrid/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script APIs of TVM Python Package, aimed to support TIR"""
+"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
-from .utils import create_module, ashybrid, script
+from .utils import create_module, asscript, tir, module
from .parser import from_source
diff --git a/python/tvm/hybrid/_ffi_api.py b/python/tvm/script/_ffi_api.py
similarity index 91%
rename from python/tvm/hybrid/_ffi_api.py
rename to python/tvm/script/_ffi_api.py
index 929a65c..92c3890 100644
--- a/python/tvm/hybrid/_ffi_api.py
+++ b/python/tvm/script/_ffi_api.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.hybrid"""
+"""FFI APIs for tvm.script"""
import tvm._ffi
-tvm._ffi._init_api("hybrid", __name__)
+tvm._ffi._init_api("script", __name__)
diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/script/intrin.py
similarity index 98%
rename from python/tvm/hybrid/intrin.py
rename to python/tvm/script/intrin.py
index fdd48f3..21570b9 100644
--- a/python/tvm/hybrid/intrin.py
+++ b/python/tvm/script/intrin.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Parser Intrinsic Functions
+"""TVM Script Parser Intrinsic Functions
IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins
"""
diff --git a/python/tvm/hybrid/meta_unparser.py
b/python/tvm/script/meta_unparser.py
similarity index 100%
rename from python/tvm/hybrid/meta_unparser.py
rename to python/tvm/script/meta_unparser.py
diff --git a/python/tvm/hybrid/parser.py b/python/tvm/script/parser.py
similarity index 94%
rename from python/tvm/hybrid/parser.py
rename to python/tvm/script/parser.py
index a1aa652..56710fc 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/script/parser.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Parser For TIR"""
+"""TVM Script Parser For TIR"""
# pylint: disable=invalid-name, missing-docstring,
inconsistent-return-statements, no-else-return
# pylint: disable=unnecessary-comprehension, unused-argument,
import-outside-toplevel
# pylint: disable=unused-import
@@ -35,16 +35,16 @@ from .registry import Registry
from . import _ffi_api
-class HybridParserError(RuntimeError):
- """Hybrid Parser Runtime Error"""
+class TVMScriptParserError(RuntimeError):
+ """TVM Script Parser Runtime Error"""
-class HybridParser(ast.NodeVisitor):
+class TVMScriptParser(ast.NodeVisitor):
"""Python AST visitor pass which finally lowers it to TIR
Notes for extension:
1. To support new types of AST nodes. Add a function visit_xxx().
2. To support new functions
- We divide allowed function calls in hybrid script into 3 categories,
+ We divide allowed function calls in TVM script into 3 categories,
which is intrin, scope_handler and special_stmt.
1) intrin functions ought to have return value.
User can also register intrin category function into parser.
@@ -168,7 +168,7 @@ class HybridParser(ast.NodeVisitor):
lineno = self.current_lineno
if col_offset is None:
col_offset = self.current_col_offset
- raise HybridParserError(self.wrap_line_col(message, lineno,
col_offset))
+ raise TVMScriptParserError(self.wrap_line_col(message, lineno,
col_offset))
def get_body(self):
body = []
@@ -196,20 +196,20 @@ class HybridParser(ast.NodeVisitor):
"""Module visitor
AST abstract grammar:
Module(stmt* body, type_ignore* type_ignore)
- By now we support two format of hybrid script shown below.
+ By now we support two formats of TVM script, shown below.
Example
-------
- 1. Generate a Function(If the code is printed, then it may bring meta)
+ 1. Generate a PrimFunc (if the code is printed, it may also contain metadata)
.. code-block:: python
import tvm
- @tvm.hybrid.script
+ @tvm.script
def A(...):
...
- # call hybrid parser when call this function, get a Function
+ # returns a PrimFunc
func = A
2. Generate an IRModule
@@ -217,7 +217,7 @@ class HybridParser(ast.NodeVisitor):
import tvm
- @tvm.hybrid.script
+ @tvm.script
class MyMod():
def A(...):
...
@@ -227,7 +227,7 @@ class HybridParser(ast.NodeVisitor):
__tvm_meta__ = ...
- # call hybrid parser during construction, get an IRModule
+ # returns an IRModule
mod = MyMod()
"""
@@ -237,7 +237,7 @@ class HybridParser(ast.NodeVisitor):
elif len(node.body) == 2:
if isinstance(node.body[0], ast.Assign):
node.body[0], node.body[1] = node.body[1], node.body[0]
- if isinstance(node.body[0], ast.FunctionDef) and
HybridParser.is_meta(node.body[1]):
+ if isinstance(node.body[0], ast.FunctionDef) and
TVMScriptParser.is_meta(node.body[1]):
# function with meta
self.init_meta(MetaUnparser().visit(node.body[1].value))
return self.visit(node.body[0])
@@ -257,7 +257,7 @@ class HybridParser(ast.NodeVisitor):
for body_element in node.body:
if isinstance(body_element, ast.FunctionDef):
pass
- elif HybridParser.is_meta(body_element) and not count:
+ elif TVMScriptParser.is_meta(body_element) and not count:
count = True
self.init_meta(MetaUnparser().visit(body_element.value))
else:
@@ -526,9 +526,9 @@ class HybridParser(ast.NodeVisitor):
lhs = self.visit(node.left)
rhs = self.visit(node.right)
- if not isinstance(node.op, tuple(HybridParser._binop_maker.keys())):
+ if not isinstance(node.op, tuple(TVMScriptParser._binop_maker.keys())):
self.report_error("BinOp " + str(type(node.op)) + " is not
supported now")
- return HybridParser._binop_maker[type(node.op)](lhs, rhs)
+ return TVMScriptParser._binop_maker[type(node.op)](lhs, rhs)
def visit_Compare(self, node):
"""Compare visitor
@@ -542,7 +542,7 @@ class HybridParser(ast.NodeVisitor):
for i in range(len(node.ops)):
lhs = ops[i]
rhs = ops[i + 1]
- res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs))
+ res.append(TVMScriptParser._binop_maker[type(node.ops[i])](lhs,
rhs))
return _all(*res)
def visit_BoolOp(self, node):
@@ -552,7 +552,7 @@ class HybridParser(ast.NodeVisitor):
"""
values = [self.visit(value) for value in node.values]
- return HybridParser._binop_maker[type(node.op)](*values)
+ return TVMScriptParser._binop_maker[type(node.op)](*values)
def visit_UnaryOp(self, node):
"""UnaryOp visitor
@@ -561,9 +561,9 @@ class HybridParser(ast.NodeVisitor):
"""
operand = self.visit(node.operand)
- if not isinstance(node.op, tuple(HybridParser._unaryop_maker.keys())):
+ if not isinstance(node.op,
tuple(TVMScriptParser._unaryop_maker.keys())):
self.report_error("UnaryOp " + str(type(node.op)) + " is not
supported now")
- return HybridParser._unaryop_maker[type(node.op)](operand)
+ return TVMScriptParser._unaryop_maker[type(node.op)](operand)
def visit_Subscript(self, node):
"""Subscript visitor
@@ -734,11 +734,11 @@ def from_source(src, func_lineno=0):
"""
root = ast.parse(src)
- parser = HybridParser(src, func_lineno)
+ parser = TVMScriptParser(src, func_lineno)
try:
return parser.visit(root)
- except HybridParserError as e:
+ except TVMScriptParserError as e:
raise e
except TVMError as e:
# TVM internal c++ error, we have to process the error message and
inject line info
@@ -752,7 +752,7 @@ def from_source(src, func_lineno=0):
raise TVMError("\n".join(inject_e))
except Exception as e:
inject_e = parser.wrap_line_col(str(e), parser.current_lineno,
parser.current_col_offset)
- raise HybridParserError(inject_e)
+ raise TVMScriptParserError(inject_e)
-tvm._ffi._init_api("hybrid", __name__)
+tvm._ffi._init_api("script", __name__)
diff --git a/python/tvm/hybrid/registry.py b/python/tvm/script/registry.py
similarity index 95%
rename from python/tvm/hybrid/registry.py
rename to python/tvm/script/registry.py
index a1b2b3c..acbc444 100644
--- a/python/tvm/hybrid/registry.py
+++ b/python/tvm/script/registry.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Parser Function Registry """
+"""TVM Script Parser Function Registry """
# pylint: disable=inconsistent-return-statements
import inspect
from enum import Enum
@@ -217,14 +217,14 @@ def get_arg_list(origin_func, category, with_var=False):
if not with_var:
if len(args) < 3 or args[0] != "parser" or args[1] != "node" or
args[2] != "body":
raise RuntimeError(
- "TVM Hybrid Script register error : the first three
arguments of "
+ "TVM Script register error : the first three arguments of "
"this with scope handler must be parser, node, body"
)
args = args[3:]
else:
if len(args) < 2 or args[0] != "parser" or args[1] != "node":
raise RuntimeError(
- "TVM Hybrid Script register error : the first two
arguments of "
+ "TVM Script register error : the first two arguments of "
"this with scope handler must be parser, node"
)
args = args[2:]
@@ -237,26 +237,24 @@ def get_arg_list(origin_func, category, with_var=False):
or args[3] != "loop_vars"
):
raise RuntimeError(
- "TVM Hybrid Script register error : the first three arguments
of for scope handler"
+ "TVM Script register error : the first three arguments of for
scope handler"
"must be parser, node, body, loop_vars"
)
args = args[4:]
elif category == Category.SPECIAL_STMT:
if len(args) < 2 or args[0] != "parser" or args[1] != "node":
raise RuntimeError(
- "TVM Hybrid Script register error : the first three arguments
of special stmt"
+ "TVM Script register error : the first three arguments of
special stmt"
"must be parser, node"
)
args = args[2:]
if full_arg_spec.varkw is not None:
raise RuntimeError(
- "TVM Hybrid Script register error : variable keyword argument is
not supported now"
+ "TVM Script register error : variable keyword argument is not
supported now"
)
if not len(full_arg_spec.kwonlyargs) == 0:
- raise RuntimeError(
- "TVM Hybrid Script register error : keyword only argument is not
supported now"
- )
+ raise RuntimeError("TVM Script register error : keyword only argument
is not supported now")
pos_only = list()
for arg in args[: len(args) - len(defaults)]:
diff --git a/python/tvm/hybrid/scope_emitter.py
b/python/tvm/script/scope_emitter.py
similarity index 94%
rename from python/tvm/hybrid/scope_emitter.py
rename to python/tvm/script/scope_emitter.py
index 629f44b..69ad267 100644
--- a/python/tvm/hybrid/scope_emitter.py
+++ b/python/tvm/script/scope_emitter.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Scope Emitter for TIR"""
+"""TVM Script Scope Emitter for TIR"""
from tvm.te import schedule
@@ -52,7 +52,7 @@ class ScopeEmitter:
if name in symbols:
symbols.pop(name)
return
- raise RuntimeError("Internal error of hybrid parser: no symbol named"
+ name)
+ raise RuntimeError("Internal error of tvm script parser: no symbol
named" + name)
def lookup_symbol(self, name):
"""Look up symbol by name"""
diff --git a/python/tvm/hybrid/scope_handler.py
b/python/tvm/script/scope_handler.py
similarity index 97%
rename from python/tvm/hybrid/scope_handler.py
rename to python/tvm/script/scope_handler.py
index 126a3dc..08cd7ca 100644
--- a/python/tvm/hybrid/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Parser Scope Handler Functions
+"""TVM Script Parser Scope Handler Functions
This module provides the functions registered into parser under with_scope or
for_scope category.
Scope handler nodes are StmtNodes with body, which are used to handle such
scenarios.
1. For scope handler
When registering a for scope handler, the first 4 arguments must be parser,
node, body, loop_vars
-and these arguments will provided by Hybrid Script parser automatically
+and these arguments will be provided by TVM Script parser automatically
.. code-block:: python
for loop_vars in tir.xxx():
2. With scope handler
@@ -41,14 +41,14 @@ Example : None atm
with tir.xxx() as target:
3) without as & concise
the first 3 arguments must be parser, node, body
-Hybrid Script parser will parse the body automatically
+TVM Script parser will parse the body automatically
Example : tir.allocate()/tir.realize()/tir.attr()
.. code-block:: python
tir.xxx()
with tir.xxx():
4) without as & not concise
the first 3 arguments must be parser, node, body
-Hybrid Script parser will parse the body automatically
+TVM Script parser will parse the body automatically
Example : tir.assert()/tir.let()
.. code-block:: python
with tir.xxx():
diff --git a/python/tvm/hybrid/special_stmt.py
b/python/tvm/script/special_stmt.py
similarity index 98%
rename from python/tvm/hybrid/special_stmt.py
rename to python/tvm/script/special_stmt.py
index f080071..53c01d4 100644
--- a/python/tvm/hybrid/special_stmt.py
+++ b/python/tvm/script/special_stmt.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Parser Special Stmt Functions
+"""TVM Script Parser Special Stmt Functions
This module provides the functions registered into parser under special_stmt
category.
special_stmt functions don't correspond to an IRNode in the AST directly. It
is usually
used for some information that is not suitable to be printed directly.
diff --git a/python/tvm/hybrid/ty.py b/python/tvm/script/ty.py
similarity index 82%
rename from python/tvm/hybrid/ty.py
rename to python/tvm/script/ty.py
index c309fbe..430a746 100644
--- a/python/tvm/hybrid/ty.py
+++ b/python/tvm/script/ty.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Hybrid Script Parser Typing Class
+"""TVM Script Parser Typing Class
-This module provides typing class for hybrid script type annotation usage, it
can be viewed as
+This module provides typing classes for TVM script type annotations; it can be viewed as
a wrapper for uniform Type system in IR
"""
# pylint: disable=invalid-name
@@ -24,14 +24,14 @@ import tvm
class TypeGeneric:
- """Base class for all the hybrid script typing class"""
+ """Base class for all the TVM script typing classes"""
def evaluate(self):
raise TypeError("Cannot get tvm.Type from a generic type")
class ConcreteType(TypeGeneric):
- """Hybrid script typing class for uniform Type objects"""
+ """TVM script typing class for uniform Type objects"""
def __init__(self, vtype):
self.type = vtype
@@ -41,7 +41,7 @@ class ConcreteType(TypeGeneric):
class GenericPtrType(TypeGeneric):
- """Hybrid script typing class generator for PtrType
+ """TVM script typing class generator for PtrType
[] operator is overloaded, accepts a ConcreteType and returns a
ConcreteType wrapping PtrType
"""
@@ -51,7 +51,7 @@ class GenericPtrType(TypeGeneric):
class GenericTupleType(TypeGeneric):
- """Hybrid script typing class generator for TupleType
+ """TVM script typing class generator for TupleType
[] operator is overloaded, accepts a list of ConcreteType and returns a
ConcreteType
wrapping TupleType
diff --git a/python/tvm/hybrid/utils.py b/python/tvm/script/utils.py
similarity index 74%
rename from python/tvm/hybrid/utils.py
rename to python/tvm/script/utils.py
index 7880fd7..f510ddb 100644
--- a/python/tvm/hybrid/utils.py
+++ b/python/tvm/script/utils.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Helper functions in Hybrid Script Parser"""
+"""Helper functions in TVM Script Parser"""
import inspect
from tvm import IRModule
@@ -40,7 +40,7 @@ def create_module(functions=None):
return IRModule(functions=functions)
-def ashybrid(input_ir, show_meta=False):
+def asscript(input_ir, show_meta=False):
"""Transform a PrimFunc or IRModule to python syntax script
Parameters
@@ -57,13 +57,13 @@ def ashybrid(input_ir, show_meta=False):
The Python script
"""
- return _ffi_api.AsHybrid(input_ir, show_meta)
+ return _ffi_api.AsTVMScript(input_ir, show_meta)
-def script(script_in):
- """Decorate a python function or class as hybrid script.
+def tir(script_in):
+ """Decorate a python function or class as TVM script.
- The hybrid function or parsing support parsing to the internal TIR.
+ A decorated function is parsed into TIR directly; a decorated class is parsed into TIR when it is instantiated.
Returns
-------
@@ -75,22 +75,35 @@ def script(script_in):
return _parse(script_in)
if inspect.isclass(script_in):
- return HybridClass(script_in)
+ return TVMScriptClass(script_in)
raise TypeError("Only function and class are supported")
-class HybridClass:
+def module(script_in):
+ """Decorate a python function or class as TVM script.
+
+ Alias for tvm.script.tir for now.
+
+ Returns
+ -------
+ output : Union[Function, Module]
+ The Function or Module in IR.
+ """
+ return tir(script_in)
+
+
+class TVMScriptClass:
"""Helper class for decorating a class"""
def __init__(self, script_in):
self.script = script_in
def __call__(self, *args, **kwargs):
- # call the parser to transform hybrid script into TIR
+ # call the parser to transform TVM script into TIR
return _parse(self.script)
def _parse(script_in):
- """Helper function to parse hybrid_script into TIR"""
+ """Helper function to parse TVM script into TIR"""
return from_source(inspect.getsource(script_in),
inspect.getsourcelines(script_in)[1])
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 132b12c..7feb0b5 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -246,27 +246,27 @@ Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
return meta_->InMeta(var) ? meta_->GetMetaNode(var) :
AllocVar(GetRef<Var>(op));
}
-#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
- Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
- Doc doc; \
- doc << "(" << Print(op->a) << OpString; \
- doc << Print(op->b) << ")"; \
- return doc; \
- }
-
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ")
+#define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \
+ Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
+ Doc doc; \
+ doc << "(" << Print(op->a) << OpString; \
+ doc << Print(op->b) << ")"; \
+ return doc; \
+ }
+
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AndNode, " && ")
+TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OrNode, " || ")
Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
Doc doc;
diff --git a/src/printer/tir_hybrid_printer.cc
b/src/printer/tvmscript_printer.cc
similarity index 86%
rename from src/printer/tir_hybrid_printer.cc
rename to src/printer/tvmscript_printer.cc
index 0fadf17..5add7c1 100644
--- a/src/printer/tir_hybrid_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -18,7 +18,7 @@
*/
/*!
- * \file printer/tir_hybrid_printer.cc
+ * \file printer/tvmscript_printer.cc
* \brief Printer class to print Tensor IR to python syntax script
*/
@@ -42,11 +42,11 @@
namespace tvm {
namespace tir {
-class TIRHybridPrinter : public StmtFunctor<Doc(const Stmt&)>,
+class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
public ExprFunctor<Doc(const PrimExpr&)>,
public TypeFunctor<Doc(const Type&)> {
public:
- explicit TIRHybridPrinter(bool show_meta,
+ explicit TVMScriptPrinter(bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)>
annotate = nullptr)
: show_meta_(show_meta), annotate_(std::move(annotate)),
meta_collector_(&meta_) {}
@@ -225,7 +225,7 @@ class TIRHybridPrinter : public StmtFunctor<Doc(const
Stmt&)>,
}
};
-Doc TIRHybridPrinter::GetUniqueName(std::string prefix) {
+Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
std::replace(prefix.begin(), prefix.end(), '.', '_');
std::string unique_prefix = prefix;
auto it = name_alloc_map_.find(prefix);
@@ -237,7 +237,7 @@ Doc TIRHybridPrinter::GetUniqueName(std::string prefix) {
return Doc::Text(unique_prefix);
}
-Doc TIRHybridPrinter::AllocVar(const Var& var) {
+Doc TVMScriptPrinter::AllocVar(const Var& var) {
const auto& it = memo_var_.find(var);
if (it != memo_var_.end()) {
return it->second;
@@ -251,7 +251,7 @@ Doc TIRHybridPrinter::AllocVar(const Var& var) {
return val;
}
-Doc TIRHybridPrinter::AllocBufferDeclaration(const Buffer& buf) {
+Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
Doc doc = Print(buf->shape);
if (!runtime::TypeEqual(buf->dtype, DataType::Float(32))) {
doc << ", dtype=" << PrintDType(buf->dtype);
@@ -293,7 +293,7 @@ Doc TIRHybridPrinter::AllocBufferDeclaration(const Buffer&
buf) {
return doc;
}
-Doc TIRHybridPrinter::AllocBuf(const Buffer& buffer) {
+Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
const auto& it = memo_buf_.find(buffer);
if (it != memo_buf_.end()) {
return it->second;
@@ -308,7 +308,7 @@ Doc TIRHybridPrinter::AllocBuf(const Buffer& buffer) {
return val;
}
-Doc TIRHybridPrinter::Print(const ObjectRef& node) {
+Doc TVMScriptPrinter::Print(const ObjectRef& node) {
if (!node.defined()) return Doc::Text("None");
if (node->IsInstance<StmtNode>()) {
return PrintOptionalInfo(Downcast<Stmt>(node)) <<
VisitStmt(Downcast<Stmt>(node));
@@ -336,27 +336,27 @@ Doc TIRHybridPrinter::Print(const ObjectRef& node) {
}
}
-Doc TIRHybridPrinter::VisitExprDefault_(const Object* op) {
+Doc TVMScriptPrinter::VisitExprDefault_(const Object* op) {
meta_collector_.Collect(GetRef<ObjectRef>(op));
return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
}
-Doc TIRHybridPrinter::VisitStmtDefault_(const Object* op) {
+Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) {
meta_collector_.Collect(GetRef<ObjectRef>(op));
return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
}
-Doc TIRHybridPrinter::VisitExpr_(const IntImmNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op) {
return PrintConstScalar<int64_t>(op->dtype, &(op->value));
}
-Doc TIRHybridPrinter::VisitExpr_(const FloatImmNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const FloatImmNode* op) {
return PrintConstScalar<double>(op->dtype, &(op->value));
}
-Doc TIRHybridPrinter::VisitExpr_(const StringImmNode* op) { return
Doc::StrLiteral(op->value); }
+Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op) { return
Doc::StrLiteral(op->value); }
-Doc TIRHybridPrinter::VisitExpr_(const CastNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const CastNode* op) {
Doc doc;
if (cast(op->dtype, op->value)->IsInstance<CastNode>()) {
doc << Print(op->value) << ".astype(" << PrintDType(op->dtype) << ")";
@@ -366,76 +366,76 @@ Doc TIRHybridPrinter::VisitExpr_(const CastNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const VarNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const VarNode* op) {
const Var& var = GetRef<Var>(op);
return meta_.InMeta(var) ? meta_.GetMetaNode(var) :
AllocVar(GetRef<Var>(op));
}
-#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \
- Doc TIRHybridPrinter::VisitExpr_(const OpName* op) { \
+#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString) \
+ Doc TVMScriptPrinter::VisitExpr_(const OpName* op) { \
Doc doc; \
doc << '(' << Print(op->a) << OpString << Print(op->b) << ")"; \
return doc; \
}
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " and ")
-TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " or ")
-
-Doc TIRHybridPrinter::VisitExpr_(const FloorDivNode* op) {
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ")
+TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ")
+
+Doc TVMScriptPrinter::VisitExpr_(const FloorDivNode* op) {
Doc doc;
doc << "tir.floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const FloorModNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const FloorModNode* op) {
Doc doc;
doc << "tir.floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const MinNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const MinNode* op) {
Doc doc;
doc << "tir.min(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const MaxNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op) {
Doc doc;
doc << "tir.max(" << Print(op->a) << ", " << Print(op->b) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const NotNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const NotNode* op) {
Doc doc;
doc << "not (" << Print(op->a) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const SelectNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op) {
Doc doc;
doc << "tir.select(" << Print(op->condition) << ", " <<
Print(op->true_value) << ", "
<< Print(op->false_value) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const BufferLoadNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op) {
Doc doc;
doc << Print(op->buffer) << Print(op->indices);
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const LoadNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op) {
Doc doc;
if (op->dtype == DataType::Float(32) && is_one(op->predicate) &&
op->buffer_var->dtype == DataType::Float(32)) {
@@ -451,25 +451,25 @@ Doc TIRHybridPrinter::VisitExpr_(const LoadNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const RampNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const RampNode* op) {
Doc doc;
doc << "tir.ramp(" << Print(op->base) << ", " << Print(op->stride) << ", "
<< op->lanes << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const BroadcastNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op) {
Doc doc;
doc << "tir.broadcast(" << Print(op->value) << ", " << op->lanes << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const LetNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const LetNode* op) {
Doc doc;
doc << "tir.let(" << Print(op->var) << ", " << Print(op->value) << ", " <<
Print(op->body) << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const CallNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const CallNode* op) {
Doc doc;
if (auto* ptr_op = op->op.as<OpNode>()) {
doc << Doc::Text(ptr_op->name) << "(";
@@ -487,20 +487,20 @@ Doc TIRHybridPrinter::VisitExpr_(const CallNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const ShuffleNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op) {
Doc doc;
doc << "tir.shuffle(" << Print(op->vectors) << ", " << Print(op->indices) <<
")";
return doc;
}
-Doc TIRHybridPrinter::VisitExpr_(const ReduceNode* op) {
+Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op) {
Doc doc;
doc << "tir.reduce(" << Print(op->combiner) << ", " << Print(op->source) <<
", "
<< Print(op->axis) << ", " << op->value_index << ")";
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const LetStmtNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with tir.let(" << Print(op->var) << ", " << Print(op->value) <<
"):";
@@ -513,7 +513,7 @@ Doc TIRHybridPrinter::VisitStmt_(const LetStmtNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const AttrStmtNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
Doc doc;
// merge attr with allocate when possible
if (op->node->IsInstance<VarNode>() && op->attr_key == "storage_scope" &&
@@ -591,7 +591,7 @@ Doc TIRHybridPrinter::VisitStmt_(const AttrStmtNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const AssertStmtNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) {
Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with tir.Assert(" << Print(op->condition) << ", " <<
Print(op->message) << "):";
@@ -603,7 +603,7 @@ Doc TIRHybridPrinter::VisitStmt_(const AssertStmtNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const StoreNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) {
Doc doc;
if (!is_one(op->predicate) || op->value.dtype().lanes() != 1) {
doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index)
<< ", "
@@ -614,17 +614,18 @@ Doc TIRHybridPrinter::VisitStmt_(const StoreNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const BufferRealizeNode* op) {
- LOG(FATAL) << "Hybrid Printer Internal Error: All the BufferRealize should
be folded with Attr";
+Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
+ LOG(FATAL)
+ << "TVM Script Printer Internal Error: All the BufferRealize should be
folded with Attr";
return Doc();
}
-Doc TIRHybridPrinter::VisitStmt_(const AllocateNode* op) {
- LOG(FATAL) << "Hybrid Printer Internal Error: All the Allocate should be
folded with Attr";
+Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
+ LOG(FATAL) << "TVM Script Printer Internal Error: All the Allocate should be
folded with Attr";
return Doc();
}
-Doc TIRHybridPrinter::VisitStmt_(const IfThenElseNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
@@ -634,7 +635,7 @@ Doc TIRHybridPrinter::VisitStmt_(const IfThenElseNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const SeqStmtNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
std::vector<Doc> stmts;
for (Stmt stmt : op->seq) {
stmts.push_back(Print(stmt));
@@ -642,7 +643,7 @@ Doc TIRHybridPrinter::VisitStmt_(const SeqStmtNode* op) {
return PrintSep(stmts, Doc::NewLine());
}
-Doc TIRHybridPrinter::VisitStmt_(const EvaluateNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
Doc doc;
doc << "tir.evaluate(" << Print(op->value) << ")";
return doc;
@@ -663,7 +664,7 @@ inline const char* ForType2String(ForType t) {
return "Unknown";
}
-Doc TIRHybridPrinter::VisitStmt_(const ForNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
var_not_in_headers.insert(op->loop_var.get());
doc << "for " << Print(op->loop_var)
@@ -673,25 +674,25 @@ Doc TIRHybridPrinter::VisitStmt_(const ForNode* op) {
return doc;
}
-Doc TIRHybridPrinter::VisitStmt_(const PrefetchNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) {
Doc doc;
doc << "tir.prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) <<
")";
return doc;
}
-Doc TIRHybridPrinter::VisitType_(const PrimTypeNode* node) {
+Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
Doc doc;
doc << "ty." << runtime::DLDataType2String(node->dtype);
return doc;
}
-Doc TIRHybridPrinter::VisitType_(const PointerTypeNode* node) {
+Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
Doc doc;
doc << "ty.Ptr[" << Print(node->element_type) << "]";
return doc;
}
-Doc TIRHybridPrinter::VisitType_(const TupleTypeNode* node) {
+Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {
if (node->fields.empty()) {
return Doc::Text("None");
} else {
@@ -703,13 +704,13 @@ Doc TIRHybridPrinter::VisitType_(const TupleTypeNode*
node) {
}
}
-Doc TIRHybridPrinter::VisitStmt_(const BufferStoreNode* op) {
+Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
Doc doc;
doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
return doc;
}
-Doc TIRHybridPrinter::PrintBody(const Stmt& body) {
+Doc TVMScriptPrinter::PrintBody(const Stmt& body) {
int memo_num_child, memo_current_num;
std::swap(memo_num_child, num_child_);
std::swap(memo_current_num, current_num_);
@@ -736,7 +737,7 @@ Doc TIRHybridPrinter::PrintBody(const Stmt& body) {
return doc;
}
-Doc TIRHybridPrinter::PrintIRModule(const IRModule& module) {
+Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) {
auto* op = module.operator->();
Doc doc;
doc << "class Module:";
@@ -750,13 +751,13 @@ Doc TIRHybridPrinter::PrintIRModule(const IRModule&
module) {
functions.push_back(Print((*it).second));
}
}
- body << TIRHybridPrinter::PrintSep(functions, Doc::NewLine() <<
Doc::NewLine());
+ body << TVMScriptPrinter::PrintSep(functions, Doc::NewLine() <<
Doc::NewLine());
body << Doc::NewLine() << DumpMeta();
doc << Doc::Indent(4, body);
return doc;
}
-Doc TIRHybridPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
auto* op = primFunc.operator->();
// clear renaming map
memo_var_.clear();
@@ -851,7 +852,7 @@ Doc TIRHybridPrinter::PrintPrimFunc(const PrimFunc&
primFunc) {
return doc;
}
-Doc TIRHybridPrinter::PrintArray(const ArrayNode* op) {
+Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) {
Doc doc;
doc << '[';
for (size_t i = 0; i < op->size(); ++i) {
@@ -864,7 +865,7 @@ Doc TIRHybridPrinter::PrintArray(const ArrayNode* op) {
return doc;
}
-Doc TIRHybridPrinter::PrintIterVar(const IterVarNode* op) {
+Doc TVMScriptPrinter::PrintIterVar(const IterVarNode* op) {
Doc doc;
doc << "tir.iter_var(" << Print(op->var);
if (op->dom.defined()) {
@@ -877,20 +878,20 @@ Doc TIRHybridPrinter::PrintIterVar(const IterVarNode* op)
{
return doc;
}
-Doc TIRHybridPrinter::PrintRange(const RangeNode* op) {
+Doc TVMScriptPrinter::PrintRange(const RangeNode* op) {
return Print(op->min) << ":" << Print(op->min + op->extent);
}
-Doc TIRHybridPrinter::PrintBuffer(const BufferNode* op) {
+Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
const Buffer& buffer = GetRef<Buffer>(op);
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}
-TVM_REGISTER_GLOBAL("hybrid.AsHybrid")
+TVM_REGISTER_GLOBAL("script.AsTVMScript")
.set_body_typed<std::string(const ObjectRef&, bool)>([](const ObjectRef&
functions,
bool show_meta) {
CHECK(functions.as<PrimFuncNode>() != nullptr ||
functions.as<IRModuleNode>() != nullptr);
- return "@tvm.hybrid.script\n" +
TIRHybridPrinter(show_meta).Print(functions).str() + "\n";
+ return "@tvm.script.tir\n" +
TVMScriptPrinter(show_meta).Print(functions).str() + "\n";
});
} // namespace tir
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d9e1df4..f451177 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -265,7 +265,7 @@ Allocate::Allocate(Var buffer_var, DataType dtype,
Array<PrimExpr> extents, Prim
Stmt body) {
// TODO(tvm-team): Add invariant check to make sure
// IsPointerPType(buffer_var->type_annotation, dtype)
- // once we fix the allocate hybrid script printing.
+ // once we fix the allocate tvm script printing.
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
diff --git a/tests/python/unittest/test_hybrid_error_report.py
b/tests/python/unittest/test_tvmscript_error_report.py
similarity index 91%
rename from tests/python/unittest/test_hybrid_error_report.py
rename to tests/python/unittest/test_tvmscript_error_report.py
index 9b3c5bb..dd8621d 100644
--- a/tests/python/unittest/test_hybrid_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -19,17 +19,17 @@ import pytest
import tvm
from tvm import tir
-from tvm.hybrid import ty
-from tvm.hybrid.parser import HybridParserError
+from tvm.script import ty
+from tvm.script.parser import TVMScriptParserError
[email protected]
[email protected]
class Module1:
def buffer_bind_missing_args(a: ty.handle) -> None:
A = tir.match_buffer((16, 16), "float32")
[email protected]
[email protected]
class Module2:
def range_missing_args(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
@@ -41,7 +41,7 @@ class Module2:
A[i, j] = 0.0
[email protected]
[email protected]
class Module3:
def undefined_buffer(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
@@ -53,14 +53,14 @@ class Module3:
A[i, j] = 0.0
[email protected]
[email protected]
class Module4:
def unsupported_stmt(a: ty.int32) -> None:
if a > 0:
print("I love tvm")
[email protected]
[email protected]
class Module5:
def unsupported_function_call(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
@@ -72,26 +72,26 @@ class Module5:
A[i, j] = 0.0
[email protected]
[email protected]
class Module6:
def missing_type_annotation(a) -> None:
pass
[email protected]
[email protected]
class Module7:
def invalid_concise_scoping() -> None:
tir.Assert(1.0 > 0.0, "aaaa")
tir.evaluate(0.0)
[email protected]
[email protected]
class Module8:
def invalid_expr_stmt() -> None:
tir.max(1, 2)
[email protected]
[email protected]
class Module9:
def invalid_for_function(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
@@ -101,7 +101,7 @@ class Module9:
A[i, j] = 0.0
[email protected]
[email protected]
class Module10:
def invalid_block_function(a: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
@@ -111,7 +111,7 @@ class Module10:
def wrap_error(module, lineno):
- with pytest.raises(HybridParserError) as error:
+ with pytest.raises(TVMScriptParserError) as error:
mod = module()
assert error is not None
e = error.value
diff --git a/tests/python/unittest/test_hybrid_roundtrip.py
b/tests/python/unittest/test_tvmscript_roundtrip.py
similarity index 99%
rename from tests/python/unittest/test_hybrid_roundtrip.py
rename to tests/python/unittest/test_tvmscript_roundtrip.py
index ea67a4e..c7a38cc 100644
--- a/tests/python/unittest/test_hybrid_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -17,10 +17,10 @@
import tvm
from tvm import tir
-from tvm.hybrid import ty
+from tvm.script import ty
[email protected]
[email protected]
class Module1:
def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
# function attr dict
@@ -75,11 +75,11 @@ class Module1:
def test_opt_gemm_normalize():
mod = Module1()
- rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
+ rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True))
tvm.ir.assert_structural_equal(mod, rt_mod, True)
[email protected]
[email protected]
class Module2:
def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
# function attr dict
@@ -254,11 +254,11 @@ class Module2:
def test_opt_gemm_lower():
mod = Module2()
- rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
+ rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True))
tvm.ir.assert_structural_equal(mod, rt_mod, True)
[email protected]
[email protected]
class Module3:
def mmult(
args: ty.handle,
@@ -608,11 +608,11 @@ class Module3:
def test_opt_gemm_mod_host():
mod = Module3()
- rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
+ rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True))
tvm.ir.assert_structural_equal(mod, rt_mod, True)
[email protected]
[email protected]
def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle)
-> None:
# function attr dict
tir.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -1071,11 +1071,11 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W:
ty.handle, Conv: ty.handle) -
def test_opt_conv_tensorcore_normalize():
mod = opt_conv_tensorcore_normalize
- rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
+ rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True))
tvm.ir.assert_structural_equal(mod, rt_mod, True)
[email protected]
[email protected]
def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) ->
None:
# function attr dict
tir.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@@ -2414,11 +2414,11 @@ def opt_conv_tensorcore_lower(A: ty.handle, W:
ty.handle, Conv: ty.handle) -> No
def test_opt_conv_tensorcore_lower():
mod = opt_conv_tensorcore_lower
- rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
+ rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True))
tvm.ir.assert_structural_equal(mod, rt_mod, True)
[email protected]
[email protected]
def opt_conv_tensorcore_mod_host(
args: ty.handle,
arg_type_ids: ty.handle,
@@ -2658,7 +2658,7 @@ def opt_conv_tensorcore_mod_host(
def test_opt_conv_tensorcore_mod_host():
mod = opt_conv_tensorcore_mod_host
- rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
+ rt_mod = tvm.script.from_source(tvm.script.asscript(mod, True))
tvm.ir.assert_structural_equal(mod, rt_mod, True)