This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7950271ceb [TVMScript] TIR parser (#13190)
7950271ceb is described below

commit 7950271ceb16be076208bd2d144eb475f956c982
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Oct 25 06:04:44 2022 -0700

    [TVMScript] TIR parser (#13190)
---
 python/tvm/script/_parser/__init__.py              |   3 +-
 python/tvm/script/_parser/core/parser.py           |  15 +
 python/tvm/script/_parser/{ => tir}/__init__.py    |  14 +-
 python/tvm/script/_parser/tir/entry.py             | 108 +++++
 python/tvm/script/_parser/tir/operation.py         |  85 ++++
 python/tvm/script/_parser/tir/parser.py            | 468 +++++++++++++++++++++
 tests/python/unittest/test_tvmscript_parser_tir.py |  63 +++
 7 files changed, 751 insertions(+), 5 deletions(-)

diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/__init__.py
index fd4e45818c..38c8b88cc7 100644
--- a/python/tvm/script/_parser/__init__.py
+++ b/python/tvm/script/_parser/__init__.py
@@ -15,5 +15,6 @@
 # specific language governing permissions and limitations
 # under the Licens.
 """The parser"""
-from . import _core, ir
+from . import _core, ir, tir
 from .ir import ir_module
+from .tir import prim_func
diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py
index daf95cb3cd..c6d43f11cb 100644
--- a/python/tvm/script/_parser/core/parser.py
+++ b/python/tvm/script/_parser/core/parser.py
@@ -571,6 +571,21 @@ class Parser(doc.NodeVisitor):
         """
         return _dispatch(self, "Assign")(self, node)
 
+    def visit_AnnAssign(self, node: doc.AnnAssign) -> Any:  # pylint: 
disable=invalid-name
+        """The general annotated assign visiting method.
+
+        Parameters
+        ----------
+        node : doc.Assign
+            The doc AST annotated assign node.
+
+        Returns
+        -------
+        res : Any
+            The visiting result.
+        """
+        return _dispatch(self, "AnnAssign")(self, node)
+
     def visit_Expr(self, node: doc.Expr) -> Any:  # pylint: 
disable=invalid-name
         """The general expression visiting method.
 
diff --git a/python/tvm/script/_parser/__init__.py b/python/tvm/script/_parser/tir/__init__.py
similarity index 70%
copy from python/tvm/script/_parser/__init__.py
copy to python/tvm/script/_parser/tir/__init__.py
index fd4e45818c..7754baf087 100644
--- a/python/tvm/script/_parser/__init__.py
+++ b/python/tvm/script/_parser/tir/__init__.py
@@ -13,7 +13,13 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
-# under the Licens.
-"""The parser"""
-from . import _core, ir
-from .ir import ir_module
+# under the License.
+"""The tir parser"""
+
+from ...ir_builder.tir import *  # pylint: disable=redefined-builtin
+from ...ir_builder.tir import ir as _tir
+from . import operation as _operation
+from . import parser as _parser
+from .entry import Buffer, Ptr, prim_func
+
+__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"]  # ir_builder.tir API plus parser entries
diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/_parser/tir/entry.py
new file mode 100644
index 0000000000..632b87aa24
--- /dev/null
+++ b/python/tvm/script/_parser/tir/entry.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The entry point of TVM parser for tir."""
+
+import inspect
+from typing import Callable, Union
+
+from tvm.tir import Buffer, PrimFunc
+
+from ...ir_builder.tir import buffer_decl, ptr
+from .._core import parse, utils
+from ..ir import is_defined_in_class
+
+
+def prim_func(func: Callable) -> Union[PrimFunc, Callable]:
+    """The parsing method for tir prim func, by using `@prim_func` as 
decorator.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to be parsed as prim func.
+
+    Returns
+    -------
+    res : Union[PrimFunc, Callable]
+        The parsed tir prim func.
+    """
+    if not inspect.isfunction(func):
+        raise TypeError(f"Expect a function, but got: {func}")
+    if is_defined_in_class(inspect.stack()):
+        return func
+    return parse(func, utils.inspect_function_capture(func))
+
+
+setattr(prim_func, "dispatch_token", "tir")
+
+
+class BufferProxy:
+    """Buffer proxy class for constructing tir buffer.
+    Overload __call__ and __getitem__ to support syntax as T.Buffer() and 
T.Buffer[].
+    """
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=0,
+        offset_factor=0,
+        buffer_type="",
+        axis_separators=None,
+    ) -> Buffer:
+        return buffer_decl(
+            shape,
+            dtype=dtype,
+            data=data,
+            strides=strides,
+            elem_offset=elem_offset,
+            scope=scope,
+            align=align,
+            offset_factor=offset_factor,
+            buffer_type=buffer_type,
+            axis_separators=axis_separators,
+        )
+
+    def __getitem__(self, keys) -> Buffer:
+        if not isinstance(keys, tuple):
+            return self(keys)
+        if len(keys) >= 2 and not isinstance(keys[1], str):
+            return self(keys)
+        return self(*keys)  # pylint: disable=no-member # type: ignore
+
+
+class PtrProxy:
+    """Ptr proxy class for constructing tir pointer.
+    Overload __call__ and __getitem__ to support syntax as T.Ptr() and T.Ptr[].
+    """
+
+    def __call__(self, dtype, storage_scope="global"):
+        if callable(dtype):
+            dtype = dtype().dtype
+        return ptr(dtype, storage_scope)  # pylint: disable=no-member # type: 
ignore
+
+    def __getitem__(self, keys):
+        if not isinstance(keys, tuple):
+            return self(keys)
+        return self(*keys)
+
+
+Buffer = BufferProxy()  # pylint: disable=invalid-name # backs T.Buffer() and T.Buffer[]
+Ptr = PtrProxy()  # pylint: disable=invalid-name # backs T.Ptr() and T.Ptr[]
diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/_parser/tir/operation.py
new file mode 100644
index 0000000000..ed8f07a063
--- /dev/null
+++ b/python/tvm/script/_parser/tir/operation.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The tir expression operation registration"""
+
+from typing import Type
+
+from tvm import tir
+from tvm.tir import IntImm
+
+from .._core import OpMethod, doc, register_op
+
+
+def _register_expr_op(ty: Type):  # pylint: disable=invalid-name # operator handlers for ty
+    ty._dispatch_type = ty  # pylint: disable=protected-access
+
+    def _and(a, b):  # Python bools are wrapped as IntImm("bool", ...) first
+        if isinstance(a, bool):
+            a = IntImm("bool", a)
+        if isinstance(b, bool):
+            b = IntImm("bool", b)
+        return tir.And(a, b)
+
+    def _or(a, b):  # same bool wrapping as _and, then builds tir.Or
+        if isinstance(a, bool):
+            a = IntImm("bool", a)
+        if isinstance(b, bool):
+            b = IntImm("bool", b)
+        return tir.Or(a, b)
+
+    def r(op: Type, i: int, m: OpMethod):  # pylint: disable=invalid-name # register_op shorthand
+        register_op(ty, op, i)(m)
+
+    for i in [0, 1]:  # i is forwarded to register_op (presumably the operand position)
+        # Case 1. binop: delegate to the operands' Python operator overloads
+        r(doc.Add, i, lambda a, b: a + b)
+        r(doc.Sub, i, lambda a, b: a - b)
+        r(doc.Mult, i, lambda a, b: a * b)
+        r(doc.Div, i, lambda a, b: a / b)
+        r(doc.FloorDiv, i, lambda a, b: a // b)
+        r(doc.Mod, i, lambda a, b: a % b)
+        r(doc.LShift, i, lambda a, b: a << b)
+        r(doc.RShift, i, lambda a, b: a >> b)
+        r(doc.BitOr, i, lambda a, b: a | b)
+        r(doc.BitXor, i, lambda a, b: a ^ b)
+        r(doc.BitAnd, i, lambda a, b: a & b)
+        # doc.MatMult <-- not implemented
+        # doc.Pow <-- not implemented
+        # Case 2. cmpop: construct the tir comparison node directly
+        r(doc.Eq, i, tir.EQ)
+        r(doc.NotEq, i, tir.NE)
+        r(doc.Lt, i, tir.LT)
+        r(doc.LtE, i, tir.LE)
+        r(doc.Gt, i, tir.GT)
+        r(doc.GtE, i, tir.GE)
+        # doc.Is <-- not implemented
+        # doc.IsNot <-- not implemented
+        # doc.In <-- not implemented
+        # doc.NotIn <-- not implemented
+        # Case 3. boolop: _and/_or also accept plain Python bools as operands
+        r(doc.And, i, _and)
+        r(doc.Or, i, _or)
+    for i in [0]:  # unary operators are registered with index 0 only
+        # Case 4. unaryop
+        r(doc.Invert, i, lambda a: ~a)
+        r(doc.Not, i, tir.Not)
+        r(doc.UAdd, i, lambda a: +a)
+        r(doc.USub, i, lambda a: -a)
+
+
+_register_expr_op(tir.PrimExpr)
+_register_expr_op(tir.IterVar)
diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/_parser/tir/parser.py
new file mode 100644
index 0000000000..909238563f
--- /dev/null
+++ b/python/tvm/script/_parser/tir/parser.py
@@ -0,0 +1,468 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The base parser for tir"""
+
+import contextlib
+from functools import partial
+from typing import Any
+
+from tvm.ir import PrimType
+from tvm.tir import Buffer, IterVar, PrimExpr, Var
+
+from ...ir_builder import tir as T
+from ...ir_builder.base import IRBuilder
+from ...ir_builder.base import IRBuilderFrame as Frame
+from .._core import Parser, dispatch, doc
+
+
+def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) 
-> Any:
+    """Value binding methods when parsing with statement.
+    e.g. binding i, j, k with T.grid(128, 128, 128), when parsing
+        with T.grid(128, 128, 18) as i, j, k.
+
+    Parameters
+    ----------
+    self : Parser
+        The current parser.
+
+    node : doc.expr
+        The doc AST expression node for error reporting.
+
+    var_name : str
+        The variable name.
+
+    value : Any
+        The value to be bound with.
+
+    Returns
+    -------
+    res : Any
+        The bound value.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_with_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, (Buffer, Var)):
+        IRBuilder.name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} 
in with statement")
+        raise NotImplementedError
+
+
+def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> 
Any:
+    """Value binding methods when parsing for statement.
+    e.g. binding i, j, k with T.grid(128, 128, 128), when parsing
+        for i, j, k in T.grid(128, 128, 128).
+
+    Parameters
+    ----------
+    self : Parser
+        The current parser.
+
+    node : doc.expr
+        The doc AST expression node for error reporting.
+
+    var_name : str
+        The variable name.
+
+    value : Any
+        The value to be bound with.
+
+    Returns
+    -------
+    res : Any
+        The bound value.
+    """
+    if isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_for_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Var):
+        IRBuilder.name(var_name, value)
+        return value
+    else:
+        self.report_error(node, f"Do not know how to bind type: {type(value)} 
in for statement")
+        raise NotImplementedError
+
+
+def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) 
-> Any:
+    """Value binding methods when parsing assign statement.
+    e.g. binding vi, vj, vk with T.axis.remap("SSR", [i, j, k]), when parsing
+        vi, vj, vk = T.axis.remap("SSR", [i, j, k]).
+
+    Parameters
+    ----------
+    self : Parser
+        The current parser.
+
+    node : doc.expr
+        The doc AST expression node for error reporting.
+
+    var_name : str
+        The variable name.
+
+    value : Any
+        The value to be bound with.
+
+    Returns
+    -------
+    res : Any
+        The bound value.
+    """
+    if isinstance(value, T.inline):
+        return value.value
+    elif isinstance(value, (list, tuple)):
+        for i, v in enumerate(value):
+            bind_assign_value(self, node, f"{var_name}_{i}", v)
+        return value
+    elif isinstance(value, Frame):
+        value.add_callback(partial(value.__exit__, None, None, None))
+        res = value.__enter__()
+        IRBuilder.name(var_name, res)
+        return res
+    elif isinstance(value, (Buffer, IterVar)) or (
+        isinstance(value, Var) and not self.var_table.exist(value)
+    ):
+        IRBuilder.name(var_name, value)
+        return value
+    elif isinstance(value, PrimExpr):
+        var = T.var(value.dtype)
+        IRBuilder.name(var_name, var)
+        frame = T.let(var, value)
+        frame.add_callback(partial(frame.__exit__, None, None, None))
+        frame.__enter__()
+        return var
+    return value
+
+
[email protected](token="tir", type_name="For")
+def visit_for(self: Parser, node: doc.For) -> None:
+    """The for visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.For
+        The doc AST for node.
+    """
+    for_frame = self.eval_expr(node.iter)
+    if not isinstance(for_frame, T.frame.ForFrame):
+        self.report_error(
+            node.iter,
+            "Expect the for loop to be one of the following: "
+            "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, 
T.thread_binding",
+        )
+    with self.var_table.with_frame():
+        with for_frame as iters:
+            self.eval_assign(target=node.target, source=iters, 
bind_value=bind_for_value)
+            self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="While")
+def visit_while(self: Parser, node: doc.While) -> None:
+    """The while visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.While
+        The doc AST while node.
+    """
+    with self.var_table.with_frame():
+        cond = self.eval_expr(node.test)
+        with T.While(cond):
+            self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    """The assign visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.Assign
+        The doc AST assign node.
+    """
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' 
are not supported.")
+    lhs = node.targets[0]
+    rhs = self.eval_expr(node.value)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
[email protected](token="tir", type_name="AugAssign")
+def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None:
+    """The augmented assign visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.AugAssign
+        The doc AST augmented assign node.
+    """
+    lhs_pos = (
+        node.target.lineno,
+        node.target.col_offset,
+        node.target.end_lineno,
+        node.target.end_col_offset,
+    )
+    rhs_pos = (
+        node.value.lineno,
+        node.value.col_offset,
+        node.value.end_lineno,
+        node.value.end_col_offset,
+    )
+    node.target.ctx = doc.Load(*lhs_pos)
+    with self.var_table.with_frame():
+        lhs_name = "__tvm_tmp_value_aug_assign_lhs"
+        rhs_name = "__tvm_tmp_value_aug_assign_rhs"
+        lhs_expr = self.eval_expr(node.target)
+        rhs_expr = self.eval_expr(node.value)
+        self.var_table.add(lhs_name, lhs_expr)
+        self.var_table.add(rhs_name, rhs_expr)
+        op = doc.BinOp(
+            doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos),
+            node.op,
+            doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos),
+            *lhs_pos,
+        )
+        rhs = self.eval_expr(op)
+    lhs = node.target
+    lhs.ctx = doc.Store(*lhs_pos)
+    if isinstance(lhs, doc.Subscript):
+        if isinstance(lhs.slice, doc.Tuple):
+            indices = []
+            for index in lhs.slice.elts:
+                indices.append(self.eval_expr(index))
+        else:
+            indices = [self.eval_expr(lhs.slice)]
+        T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
+    else:
+        self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
+
+
[email protected](token="tir", type_name="AnnAssign")
+def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
+    """The annotated assign visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.AnnAssign
+        The doc AST annotated assign node.
+    """
+    lhs = node.target
+    rhs = self.eval_expr(node.value)
+    ann_var = self.visit_tvm_annotation(node.annotation)
+    if not isinstance(ann_var, Var):
+        self.report_error(node.annotation, "Annotation should be Var")
+    self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
+    frame = T.let(ann_var, rhs)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
[email protected](token="tir", type_name="With")
+def visit_with(self: Parser, node: doc.With) -> None:
+    """The with visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.With
+        The doc AST with node.
+    """
+    with contextlib.ExitStack() as stack:
+        stack.enter_context(self.var_table.with_frame())
+        for item in node.items:
+            frame = self.eval_expr(item.context_expr)
+            if not isinstance(frame, Frame):
+                self.report_error(
+                    item.context_expr, "Invalid context expression in the 
with-statement."
+                )
+            rhs = stack.enter_context(frame)
+            if item.optional_vars is not None:
+                self.eval_assign(target=item.optional_vars, source=rhs, 
bind_value=bind_with_value)
+        self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="FunctionDef")
+def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
+    """The function definition visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.FunctionDef
+        The doc AST function definition node.
+    """
+    with self.var_table.with_frame():
+        self.var_table.add("range", T.serial)
+        with T.prim_func():
+            T.func_name(node.name)
+            if node.returns is not None:
+                ret_type = self.eval_expr(node.returns)
+                if callable(ret_type):
+                    ret_type = PrimType(ret_type().dtype)
+                T.func_ret(ret_type)
+            with self.with_dispatch_token("tir"):
+                self.visit(node.args)
+                self.visit_body(node.body)
+
+
[email protected](token="tir", type_name="arguments")
+def visit_arguments(self: Parser, node: doc.arguments) -> None:
+    """The arguments visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.arguments
+        The doc AST arguments node.
+    """
+    # TODO: handle different types of arguments:
+    # - vararg: arg | None
+    # - kwonlyargs: list[arg]
+    # - kw_defaults: list[expr | None]
+    # - kwarg: arg | None
+    # - defaults: list[expr]
+    # - posonlyargs: list[arg]
+    arg: doc.arg
+    for arg in node.args:
+        if arg.annotation is None:
+            self.report_error(arg, "Type annotation is required for function 
parameters.")
+        param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation))
+        self.var_table.add(arg.arg, param)
+
+
[email protected](token="tir", type_name="tvm_annotation")
+def visit_tvm_annotation(self: Parser, node: doc.expr):
+    """The TVM annotation visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.expr
+        The doc AST expr node.
+    """
+    annotation = self.eval_expr(node)
+    if callable(annotation):
+        annotation = annotation()
+    return annotation
+
+
[email protected](token="tir", type_name="Expr")
+def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
+    """The expr statement visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.Expr
+        The doc AST Expr node.
+    """
+    res = self.eval_expr(node.value)
+    if isinstance(res, Frame):
+        res.add_callback(partial(res.__exit__, None, None, None))
+        res.__enter__()
+
+
[email protected](token="tir", type_name="If")
+def visit_if(self: Parser, node: doc.If) -> None:
+    """The if visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.If
+        The doc AST if node.
+    """
+    with self.var_table.with_frame():
+        with T.If(self.eval_expr(node.test)):
+            with T.Then():
+                self.visit_body(node.body)
+            if node.orelse:
+                with T.Else():
+                    self.visit_body(node.orelse)
+
+
[email protected](token="tir", type_name="Assert")
+def visit_assert(self: Parser, node: doc.Assert) -> None:
+    """The assert visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.Assert
+        The doc AST assert node.
+    """
+    cond = self.eval_expr(node.test)
+    msg = self.eval_expr(node.msg)
+    frame = T.Assert(cond, msg)
+    frame.add_callback(partial(frame.__exit__, None, None, None))
+    frame.__enter__()
+
+
[email protected](token="tir", type_name="Return")
+def visit_return(self: Parser, node: doc.Return) -> None:
+    """The return visiting method for tir.
+
+    Parameters
+    ----------
+    self : Parser
+        The visiting parser.
+
+    node : doc.Return
+        The doc AST return node.
+    """
+    self.report_error(node, "Return is not allowed.")
diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py
new file mode 100644
index 0000000000..cfa1dc62b3
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_parser_tir.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unittests for tvm.script.parser.tir"""
+
+import pytest
+import inspect
+import tvm.testing
+from tvm.script._parser import tir as T
+from tvm import ir, tir
+
+
+def test_tir_buffer_proxy():
+    buffer_0 = T.Buffer((128, 128), "float32")
+    assert (
+        isinstance(buffer_0, tir.Buffer)
+        and list(buffer_0.shape) == [128, 128]
+        and buffer_0.dtype == "float32"
+    )
+
+    buffer_1 = T.Buffer[(64, 64, 64), "int32"]
+    assert (
+        isinstance(buffer_1, tir.Buffer)
+        and list(buffer_1.shape) == [64, 64, 64]
+        and buffer_1.dtype == "int32"
+    )
+
+
+def test_tir_ptr_proxy():
+    ptr_0 = T.Ptr("int32", "global")
+    assert (
+        isinstance(ptr_0, tir.Var)
+        and ptr_0.dtype == "handle"
+        and isinstance(ptr_0.type_annotation, ir.PointerType)
+        and ptr_0.type_annotation.element_type == ir.PrimType("int32")
+        and ptr_0.type_annotation.storage_scope == "global"
+    )
+
+    ptr_1 = T.Ptr["float32", "shared"]
+    assert (
+        isinstance(ptr_1, tir.Var)
+        and ptr_1.dtype == "handle"
+        and isinstance(ptr_1.type_annotation, ir.PointerType)
+        and ptr_1.type_annotation.element_type == ir.PrimType("float32")
+        and ptr_1.type_annotation.storage_scope == "shared"
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # presumably runs this file's tests when executed directly

Reply via email to