This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7950271ceb [TVMScript] TIR parser (#13190)
7950271ceb is described below
commit 7950271ceb16be076208bd2d144eb475f956c982
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Oct 25 06:04:44 2022 -0700
[TVMScript] TIR parser (#13190)
---
python/tvm/script/_parser/__init__.py | 3 +-
python/tvm/script/_parser/core/parser.py | 15 +
python/tvm/script/_parser/{ => tir}/__init__.py | 14 +-
python/tvm/script/_parser/tir/entry.py | 108 +++++
python/tvm/script/_parser/tir/operation.py | 85 ++++
python/tvm/script/_parser/tir/parser.py | 468 +++++++++++++++++++++
tests/python/unittest/test_tvmscript_parser_tir.py | 63 +++
7 files changed, 751 insertions(+), 5 deletions(-)
diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/__init__.py
index fd4e45818c..38c8b88cc7 100644
--- a/python/tvm/script/_parser/__init__.py
+++ b/python/tvm/script/_parser/__init__.py
@@ -15,5 +15,6 @@
# specific language governing permissions and limitations
# under the Licens.
"""The parser"""
-from . import _core, ir
+from . import _core, ir, tir
from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py
index daf95cb3cd..c6d43f11cb 100644
--- a/python/tvm/script/_parser/core/parser.py
+++ b/python/tvm/script/_parser/core/parser.py
@@ -571,6 +571,21 @@ class Parser(doc.NodeVisitor):
"""
return _dispatch(self, "Assign")(self, node)
+ def visit_AnnAssign(self, node: doc.AnnAssign) -> Any: # pylint:
disable=invalid-name
+ """The general annotated assign visiting method.
+
+ Parameters
+ ----------
+ node : doc.Assign
+ The doc AST annotated assign node.
+
+ Returns
+ -------
+ res : Any
+ The visiting result.
+ """
+ return _dispatch(self, "AnnAssign")(self, node)
+
def visit_Expr(self, node: doc.Expr) -> Any: # pylint:
disable=invalid-name
"""The general expression visiting method.
diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/tir/__init__.py
similarity index 70%
copy from python/tvm/script/_parser/__init__.py
copy to python/tvm/script/_parser/tir/__init__.py
index fd4e45818c..7754baf087 100644
--- a/python/tvm/script/_parser/__init__.py
+++ b/python/tvm/script/_parser/tir/__init__.py
@@ -13,7 +13,13 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
-# under the Licens.
-"""The parser"""
-from . import _core, ir
-from .ir import ir_module
+# under the License.
+"""The tir parser"""
+
+from ...ir_builder.tir import * # pylint: disable=redefined-builtin
+from ...ir_builder.tir import ir as _tir
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
+
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]
diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/_parser/tir/entry.py
new file mode 100644
index 0000000000..632b87aa24
--- /dev/null
+++ b/python/tvm/script/_parser/tir/entry.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The entry point of TVM parser for tir."""
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+
+from ...ir_builder.tir import buffer_decl, ptr
+from .._core import parse, utils
+from ..ir import is_defined_in_class
+
+
+def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
+ """The parsing method for tir prim func, by using `@prim_func` as
decorator.
+
+ Parameters
+ ----------
+ func : Callable
+ The function to be parsed as prim func.
+
+ Returns
+ -------
+ res : Union[PrimFunc, Callable]
+ The parsed tir prim func.
+ """
+ if not inspect.isfunction(func):
+ raise TypeError(f"Expect a function, but got: {func}")
+ if is_defined_in_class(inspect.stack()):
+ return func
+ return parse(func, utils.inspect_function_capture(func))
+
+
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+ """Buffer proxy class for constructing tir buffer.
+ Overload __call__ and __getitem__ to support syntax as T.Buffer() and
T.Buffer[].
+ """
+
+ def __call__(
+ self,
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="global",
+ align=0,
+ offset_factor=0,
+ buffer_type="",
+ axis_separators=None,
+ ) -> Buffer:
+ return buffer_decl(
+ shape,
+ dtype=dtype,
+ data=data,
+ strides=strides,
+ elem_offset=elem_offset,
+ scope=scope,
+ align=align,
+ offset_factor=offset_factor,
+ buffer_type=buffer_type,
+ axis_separators=axis_separators,
+ )
+
+ def __getitem__(self, keys) -> Buffer:
+ if not isinstance(keys, tuple):
+ return self(keys)
+ if len(keys) >= 2 and not isinstance(keys[1], str):
+ return self(keys)
+ return self(*keys) # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+ """Ptr proxy class for constructing tir pointer.
+ Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
+ """
+
+ def __call__(self, dtype, storage_scope="global"):
+ if callable(dtype):
+ dtype = dtype().dtype
+ return ptr(dtype, storage_scope) # pylint: disable=no-member # type:
ignore
+
+ def __getitem__(self, keys):
+ if not isinstance(keys, tuple):
+ return self(keys)
+ return self(*keys)
+
+
+Buffer = BufferProxy() # pylint: disable=invalid-name
+Ptr = PtrProxy() # pylint: disable=invalid-name
diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/_parser/tir/operation.py
new file mode 100644
index 0000000000..ed8f07a063
--- /dev/null
+++ b/python/tvm/script/_parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The tir expression operation registration"""
+
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .._core import OpMethod, doc, register_op
+
+
+def _register_expr_op(ty: Type): # pylint: disable=invalid-name
+ ty._dispatch_type = ty # pylint: disable=protected-access
+
+ def _and(a, b):
+ if isinstance(a, bool):
+ a = IntImm("bool", a)
+ if isinstance(b, bool):
+ b = IntImm("bool", b)
+ return tir.And(a, b)
+
+ def _or(a, b):
+ if isinstance(a, bool):
+ a = IntImm("bool", a)
+ if isinstance(b, bool):
+ b = IntImm("bool", b)
+ return tir.Or(a, b)
+
+ def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name
+ register_op(ty, op, i)(m)
+
+ for i in [0, 1]:
+ # Case 1. binop
+ r(doc.Add, i, lambda a, b: a + b)
+ r(doc.Sub, i, lambda a, b: a - b)
+ r(doc.Mult, i, lambda a, b: a * b)
+ r(doc.Div, i, lambda a, b: a / b)
+ r(doc.FloorDiv, i, lambda a, b: a // b)
+ r(doc.Mod, i, lambda a, b: a % b)
+ r(doc.LShift, i, lambda a, b: a << b)
+ r(doc.RShift, i, lambda a, b: a >> b)
+ r(doc.BitOr, i, lambda a, b: a | b)
+ r(doc.BitXor, i, lambda a, b: a ^ b)
+ r(doc.BitAnd, i, lambda a, b: a & b)
+ # doc.MatMult <-- not implemented
+ # doc.Pow <-- not implemented
+ # Case 2. cmpop
+ r(doc.Eq, i, tir.EQ)
+ r(doc.NotEq, i, tir.NE)
+ r(doc.Lt, i, tir.LT)
+ r(doc.LtE, i, tir.LE)
+ r(doc.Gt, i, tir.GT)
+ r(doc.GtE, i, tir.GE)
+ # doc.Is <-- not implemented
+ # doc.IsNot <-- not implemented
+ # doc.In <-- not implemented
+ # doc.NotIn <-- not implemented
+ # Case 3. boolop
+ r(doc.And, i, _and)
+ r(doc.Or, i, _or)
+ for i in [0]:
+ # Case 4. unaryop
+ r(doc.Invert, i, lambda a: ~a)
+ r(doc.Not, i, tir.Not)
+ r(doc.UAdd, i, lambda a: +a)
+ r(doc.USub, i, lambda a: -a)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/_parser/tir/parser.py
new file mode 100644
index 0000000000..909238563f
--- /dev/null
+++ b/python/tvm/script/_parser/tir/parser.py
@@ -0,0 +1,468 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The base parser for tir"""
+
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+
+from ...ir_builder import tir as T
+from ...ir_builder.base import IRBuilder
+from ...ir_builder.base import IRBuilderFrame as Frame
+from .._core import Parser, dispatch, doc
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any)
-> Any:
+ """Value binding methods when parsing with statement.
+ e.g. binding i, j, k with T.grid(128, 128, 128), when parsing
+ with T.grid(128, 128, 18) as i, j, k.
+
+ Parameters
+ ----------
+ self : Parser
+ The current parser.
+
+ node : doc.expr
+ The doc AST expression node for error reporting.
+
+ var_name : str
+ The variable name.
+
+ value : Any
+ The value to be bound with.
+
+ Returns
+ -------
+ res : Any
+ The bound value.
+ """
+ if isinstance(value, (list, tuple)):
+ for i, v in enumerate(value):
+ bind_with_value(self, node, f"{var_name}_{i}", v)
+ return value
+ elif isinstance(value, (Buffer, Var)):
+ IRBuilder.name(var_name, value)
+ return value
+ else:
+ self.report_error(node, f"Do not know how to bind type: {type(value)}
in with statement")
+ raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) ->
Any:
+ """Value binding methods when parsing for statement.
+ e.g. binding i, j, k with T.grid(128, 128, 128), when parsing
+ for i, j, k in T.grid(128, 128, 128).
+
+ Parameters
+ ----------
+ self : Parser
+ The current parser.
+
+ node : doc.expr
+ The doc AST expression node for error reporting.
+
+ var_name : str
+ The variable name.
+
+ value : Any
+ The value to be bound with.
+
+ Returns
+ -------
+ res : Any
+ The bound value.
+ """
+ if isinstance(value, (list, tuple)):
+ for i, v in enumerate(value):
+ bind_for_value(self, node, f"{var_name}_{i}", v)
+ return value
+ elif isinstance(value, Var):
+ IRBuilder.name(var_name, value)
+ return value
+ else:
+ self.report_error(node, f"Do not know how to bind type: {type(value)}
in for statement")
+ raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
+    """Value binding methods when parsing assign statement.
+
+    e.g. binding vi, vj, vk with T.axis.remap("SSR", [i, j, k]), when parsing
+    vi, vj, vk = T.axis.remap("SSR", [i, j, k]).
+
+    How ``value`` is handled, in the order of the checks below:
+
+    * ``T.inline``: its ``.value`` is returned and nothing is named.
+    * list / tuple: each element is bound recursively as ``{var_name}_{i}``;
+      the original sequence is returned.
+    * IRBuilder frame: ``value.__exit__`` is registered via ``add_callback``,
+      the frame is entered, and the result of ``__enter__`` is named and returned.
+    * ``Buffer``, ``IterVar``, or a ``Var`` not yet in the var table: named
+      ``var_name`` in place and returned.
+    * any other ``PrimExpr``: bound to a fresh ``T.var`` of the same dtype
+      through a ``T.let`` frame (exit registered via ``add_callback``, then
+      entered); the new var is returned.
+    * anything else: returned unchanged without being named.
+
+    Parameters
+    ----------
+    self : Parser
+        The current parser.
+
+    node : doc.expr
+        The doc AST expression node for error reporting.
+
+    var_name : str
+        The variable name.
+
+    value : Any
+        The value to be bound with.
+
+    Returns
+    -------
+    res : Any
+        The bound value.
+    """
+    if isinstance(value, T.inline):
+        # NOTE(review): presumably T.inline marks values to be used directly
+        # rather than bound to a named variable -- confirm in ir_builder.tir.
+        return value.value
+    elif isinstance(value, (list, tuple)):
+        # NOTE(review): the recursive results are discarded and the original
+        # sequence is returned, so elements rebound to new vars (PrimExpr,
+        # T.inline, Frame cases) are not reflected in the result -- confirm intended.
+        for i, v in enumerate(value):
+            bind_assign_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        # The exit is registered before entering; presumably add_callback attaches
+        # it to the enclosing frame so this frame stays open until that one exits.
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        IRBuilder.name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        IRBuilder.name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        # Also reached by a Var already in the var table (assuming Var subclasses
+        # PrimExpr), so an existing variable is let-bound rather than renamed.
+        var = T.var(value.dtype)
+        IRBuilder.name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
[email protected](token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+ """The for visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.For
+ The doc AST for node.
+ """
+ for_frame = self.eval_expr(node.iter)
+ if not isinstance(for_frame, T.frame.ForFrame):
+ self.report_error(
+ node.iter,
+ "Expect the for loop to be one of the following: "
+ "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll,
T.thread_binding",
+ )
+ with self.var_table.with_frame():
+ with for_frame as iters:
+ self.eval_assign(target=node.target, source=iters,
bind_value=bind_for_value)
+ self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+ """The while visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.While
+ The doc AST while node.
+ """
+ with self.var_table.with_frame():
+ cond = self.eval_expr(node.test)
+ with T.While(cond):
+ self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+ """The assign visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.Assign
+ The doc AST assign node.
+ """
+ if len(node.targets) != 1:
+ self.report_error(node, "Consequential assignments like 'a = b = c'
are not supported.")
+ lhs = node.targets[0]
+ rhs = self.eval_expr(node.value)
+ if isinstance(lhs, doc.Subscript):
+ if isinstance(lhs.slice, doc.Tuple):
+ indices = []
+ for index in lhs.slice.elts:
+ indices.append(self.eval_expr(index))
+ else:
+ indices = [self.eval_expr(lhs.slice)]
+ T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+ else:
+ self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
[email protected](token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+ """The augmented assign visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.AugAssign
+ The doc AST augmented assign node.
+ """
+ lhs_pos = (
+ node.target.lineno,
+ node.target.col_offset,
+ node.target.end_lineno,
+ node.target.end_col_offset,
+ )
+ rhs_pos = (
+ node.value.lineno,
+ node.value.col_offset,
+ node.value.end_lineno,
+ node.value.end_col_offset,
+ )
+ node.target.ctx = doc.Load(*lhs_pos)
+ with self.var_table.with_frame():
+ lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+ rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+ lhs_expr = self.eval_expr(node.target)
+ rhs_expr = self.eval_expr(node.value)
+ self.var_table.add(lhs_name, lhs_expr)
+ self.var_table.add(rhs_name, rhs_expr)
+ op = doc.BinOp(
+ doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+ node.op,
+ doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+ *lhs_pos,
+ )
+ rhs = self.eval_expr(op)
+ lhs = node.target
+ lhs.ctx = doc.Store(*lhs_pos)
+ if isinstance(lhs, doc.Subscript):
+ if isinstance(lhs.slice, doc.Tuple):
+ indices = []
+ for index in lhs.slice.elts:
+ indices.append(self.eval_expr(index))
+ else:
+ indices = [self.eval_expr(lhs.slice)]
+ T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+ else:
+ self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
[email protected](token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+ """The annotated assign visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.AnnAssign
+ The doc AST annotated assign node.
+ """
+ lhs = node.target
+ rhs = self.eval_expr(node.value)
+ ann_var = self.visit_tvm_annotation(node.annotation)
+ if not isinstance(ann_var, Var):
+ self.report_error(node.annotation, "Annotation should be Var")
+ self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
+ frame = T.let(ann_var, rhs)
+ frame.add_callback(partial(frame.__exit__, None, None, None))
+ frame.__enter__()
+
+
[email protected](token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+ """The with visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.With
+ The doc AST with node.
+ """
+ with contextlib.ExitStack() as stack:
+ stack.enter_context(self.var_table.with_frame())
+ for item in node.items:
+ frame = self.eval_expr(item.context_expr)
+ if not isinstance(frame, Frame):
+ self.report_error(
+ item.context_expr, "Invalid context expression in the
with-statement."
+ )
+ rhs = stack.enter_context(frame)
+ if item.optional_vars is not None:
+ self.eval_assign(target=item.optional_vars, source=rhs,
bind_value=bind_with_value)
+ self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+ """The function definition visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.FunctionDef
+ The doc AST function definition node.
+ """
+ with self.var_table.with_frame():
+ self.var_table.add("range", T.serial)
+ with T.prim_func():
+ T.func_name(node.name)
+ if node.returns is not None:
+ ret_type = self.eval_expr(node.returns)
+ if callable(ret_type):
+ ret_type = PrimType(ret_type().dtype)
+ T.func_ret(ret_type)
+ with self.with_dispatch_token("tir"):
+ self.visit(node.args)
+ self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+ """The arguments visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.arguments
+ The doc AST arguments node.
+ """
+ # TODO: handle different types of arguments:
+ # - vararg: arg | None
+ # - kwonlyargs: list[arg]
+ # - kw_defaults: list[expr | None]
+ # - kwarg: arg | None
+ # - defaults: list[expr]
+ # - posonlyargs: list[arg]
+ arg: doc.arg
+ for arg in node.args:
+ if arg.annotation is None:
+ self.report_error(arg, "Type annotation is required for function
parameters.")
+ param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation))
+ self.var_table.add(arg.arg, param)
+
+
[email protected](token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+ """The TVM annotation visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.expr
+ The doc AST expr node.
+ """
+ annotation = self.eval_expr(node)
+ if callable(annotation):
+ annotation = annotation()
+ return annotation
+
+
[email protected](token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+ """The expr statement visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.Expr
+ The doc AST Expr node.
+ """
+ res = self.eval_expr(node.value)
+ if isinstance(res, Frame):
+ res.add_callback(partial(res.__exit__, None, None, None))
+ res.__enter__()
+
+
[email protected](token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+ """The if visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.If
+ The doc AST if node.
+ """
+ with self.var_table.with_frame():
+ with T.If(self.eval_expr(node.test)):
+ with T.Then():
+ self.visit_body(node.body)
+ if node.orelse:
+ with T.Else():
+ self.visit_body(node.orelse)
+
+
[email protected](token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+ """The assert visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.Assert
+ The doc AST assert node.
+ """
+ cond = self.eval_expr(node.test)
+ msg = self.eval_expr(node.msg)
+ frame = T.Assert(cond, msg)
+ frame.add_callback(partial(frame.__exit__, None, None, None))
+ frame.__enter__()
+
+
[email protected](token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+ """The return visiting method for tir.
+
+ Parameters
+ ----------
+ self : Parser
+ The visiting parser.
+
+ node : doc.Return
+ The doc AST return node.
+ """
+ self.report_error(node, "Return is not allowed.")
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py
new file mode 100644
index 0000000000..cfa1dc62b3
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unittests for tvm.script.parser.tir"""
+
+import pytest
+import inspect
+import tvm.testing
+from tvm.script._parser import tir as T
+from tvm import ir, tir
+
+
+def test_tir_buffer_proxy():
+ buffer_0 = T.Buffer((128, 128), "float32")
+ assert (
+ isinstance(buffer_0, tir.Buffer)
+ and list(buffer_0.shape) == [128, 128]
+ and buffer_0.dtype == "float32"
+ )
+
+ buffer_1 = T.Buffer[(64, 64, 64), "int32"]
+ assert (
+ isinstance(buffer_1, tir.Buffer)
+ and list(buffer_1.shape) == [64, 64, 64]
+ and buffer_1.dtype == "int32"
+ )
+
+
+def test_tir_ptr_proxy():
+ ptr_0 = T.Ptr("int32", "global")
+ assert (
+ isinstance(ptr_0, tir.Var)
+ and ptr_0.dtype == "handle"
+ and isinstance(ptr_0.type_annotation, ir.PointerType)
+ and ptr_0.type_annotation.element_type == ir.PrimType("int32")
+ and ptr_0.type_annotation.storage_scope == "global"
+ )
+
+ ptr_1 = T.Ptr["float32", "shared"]
+ assert (
+ isinstance(ptr_1, tir.Var)
+ and ptr_1.dtype == "handle"
+ and isinstance(ptr_1.type_annotation, ir.PointerType)
+ and ptr_1.type_annotation.element_type == ir.PrimType("float32")
+ and ptr_1.type_annotation.storage_scope == "shared"
+ )
+
+
+if __name__ == "__main__":
+    # Allow running this file directly; tvm.testing.main presumably delegates to pytest.
+    tvm.testing.main()