This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3effa45b1f [Unity][Debugging] AST printer (#14152)
3effa45b1f is described below
commit 3effa45b1fe7f68d2a57db0743c5051b850ecdc2
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 28 23:24:02 2023 -0500
[Unity][Debugging] AST printer (#14152)
This PR transfers over the AST printer from tlc-pack/relax. The AST printer
is a debugging tool that prints out a Relax AST in a precise and human-readable
format, which can be helpful for debugging the parser or various passes.
Co-authored-by: Yuchen Jin <[email protected]>
Co-authored-by: Lesheng Jin <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Tianqi Chen <[email protected]>
---
python/tvm/relax/testing/__init__.py | 1 +
python/tvm/relax/testing/ast_printer.py | 372 +++++++++++++++++++
tests/python/relax/test_ast_printer.py | 636 ++++++++++++++++++++++++++++++++
3 files changed, 1009 insertions(+)
diff --git a/python/tvm/relax/testing/__init__.py
b/python/tvm/relax/testing/__init__.py
index 7344798f70..a6e3a94251 100644
--- a/python/tvm/relax/testing/__init__.py
+++ b/python/tvm/relax/testing/__init__.py
@@ -19,3 +19,4 @@
from .nn import *
from .relay_translator import *
+from .ast_printer import dump_ast
diff --git a/python/tvm/relax/testing/ast_printer.py
b/python/tvm/relax/testing/ast_printer.py
new file mode 100644
index 0000000000..6727b24292
--- /dev/null
+++ b/python/tvm/relax/testing/ast_printer.py
@@ -0,0 +1,372 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin, abstract-method, arguments-differ
+"""
+Utility script for printing Relax modules as AST diagrams,
+only intended to show how the AST is put together.
+It is not a pretty-printer and, in fact, is more of an ugly-printer,
+but it can be useful for tutorials and debugging.
+"""
+from typing import Iterable
+import tvm
+from tvm import relax
+from tvm.ir.expr import PrimExpr
+from tvm.relax import ExprFunctor
+
+
def wrap_quotes(text: str) -> str:
    """
    Return ``text`` surrounded by double quotes.

    No escaping is performed; any embedded quotes are left as-is.
    """
    quote = '"'
    return f"{quote}{text}{quote}"
+
+
class ASTPrinter(ExprFunctor):
    """
    Class for recursing down ASTs and printing them in a very simple format,
    mainly for instructive purposes and, perhaps, debugging.

    Every node is rendered as ``NodeName(field=value, ...)``, with nested nodes
    rendered recursively. Members that span several lines are placed on their
    own lines and indented by ``indent_str`` (see `build_list`).

    Parameters
    ----------
    indent_str : str
        Prefix added to each line of a nested member.
    include_struct_info_annotations : bool
        Whether expressions print their ``struct_info`` field (when it is set).
    include_type_annotations : bool
        Whether expressions print their ``checked_type_`` field (when it is set).
    include_call_attrs : bool
        Whether ``Call`` nodes print their ``attrs``.
    """

    def __init__(
        self,
        indent_str=" ",
        include_struct_info_annotations=True,
        include_type_annotations=False,
        include_call_attrs=True,
    ):
        self.indent_str = indent_str
        self.include_type_annotations = include_type_annotations
        self.include_struct_info_annotations = include_struct_info_annotations
        self.include_call_attrs = include_call_attrs

    def visit_expr(self, expr: relax.Expr) -> str:
        """
        Dispatch on the node kind and return its rendering.

        Extended so that binding blocks and bindings can also be printed
        directly; this is done by hand because IRFunctor has not been ported
        to Python.
        """
        # DataflowBlock is checked before BindingBlock because it is the more
        # specific block kind.
        if isinstance(expr, relax.DataflowBlock):
            return self.visit_dataflow_block_(expr)
        if isinstance(expr, relax.BindingBlock):
            return self.visit_binding_block_(expr)
        if isinstance(expr, relax.Binding):
            return self.visit_binding_(expr)
        return super().visit_expr(expr)

    def indent(self, text: str) -> str:
        """
        Indent all lines of the input by ``self.indent_str``.
        The empty string is returned unchanged.
        """
        if text == "":
            return ""
        lines = text.split("\n")
        return self.indent_str + f"\n{self.indent_str}".join(lines)

    def build_ast_node(self, nodename: str, force_newline=False, **kwargs: str) -> str:
        """
        Returns 'nodename(..., fields[i][0]=fields[i][1], ...)'
        with appropriate indentation; fields are taken from ``kwargs`` in order.
        """
        return self.build_list(
            (f"{key}={value}" for key, value in kwargs.items()),
            open_tok=f"{nodename}(",
            close_tok=")",
            force_newline=force_newline,
        )

    def build_expr(self, node: relax.Expr, nodename: str, force_newline=False, **kwargs: str):
        """
        Renders a Relax expression as a string using `build_ast_node`.
        Appends the struct_info and checked_type_ fields when the node has them
        and the corresponding ``include_*`` option is enabled.
        """
        fields = kwargs.copy()
        if node.struct_info_ and self.include_struct_info_annotations:
            fields["struct_info"] = self.visit_struct_info_(node.struct_info)
        if node._checked_type_ and self.include_type_annotations:
            fields["checked_type_"] = self.visit_type_(node.checked_type)
        return self.build_ast_node(nodename, force_newline=force_newline, **fields)

    def build_list(
        self, members: Iterable[str], open_tok="[", close_tok="]", force_newline=False
    ) -> str:
        """
        Builds a list of the members given, appropriately indented,
        with each member on its own line.
        (Special case: a single member stays on the same line as the tokens
        unless it contains a newline or `force_newline` is set.)
        `open_tok` and `close_tok` are used to open and close the list,
        respectively; an empty list renders as just the two tokens.
        """
        mem_list = list(members)
        if not mem_list:
            return f"{open_tok}{close_tok}"
        if len(mem_list) == 1 and not force_newline and "\n" not in mem_list[0]:
            return f"{open_tok}{mem_list[0]}{close_tok}"
        member_lines = ",\n".join(map(self.indent, mem_list))
        return f"{open_tok}\n{member_lines}\n{close_tok}"

    def visit_constant_(self, op: relax.Constant) -> str:
        # simple rule of thumb: keep scalars inline, but put anything larger on a new line
        force_newline = len(op.data.shape) > 0
        return self.build_expr(op, "Constant", force_newline=force_newline, data=str(op.data))

    def visit_tuple_(self, op: relax.Tuple) -> str:
        return self.build_expr(op, "Tuple", fields=self.build_list(map(self.visit_expr, op.fields)))

    def visit_dataflow_var_(self, op: relax.DataflowVar) -> str:
        return self.build_expr(op, "DataflowVar", name_hint=wrap_quotes(op.name_hint))

    def visit_var_(self, op: relax.Var) -> str:
        return self.build_expr(op, "Var", name_hint=wrap_quotes(op.name_hint))

    def visit_shape_expr_(self, op: relax.ShapeExpr) -> str:
        return self.build_expr(
            op, "ShapeExpr", values=self.build_list(map(self.visit_prim_expr_, op.values))
        )

    def visit_extern_func_(self, op: relax.ExternFunc) -> str:
        # ExternFunc does not inherit from relax.Expr either,
        # so it doesn't have checked_type_ or struct_info fields and we don't use build_expr
        return self.build_ast_node("ExternFunc", global_symbol=wrap_quotes(op.global_symbol))

    def visit_global_var_(self, op: relax.GlobalVar) -> str:
        return self.build_expr(op, "GlobalVar", name_hint=wrap_quotes(op.name_hint))

    def visit_function_(self, op: relax.Function) -> str:
        fields = {
            "params": self.build_list(map(self.visit_expr, op.params)),
            "body": self.visit_expr(op.body),
            "ret_struct_info": self.visit_struct_info_(op.ret_struct_info),
        }
        if op.attrs:
            # function attributes are printed as a {"key": "value"} mapping
            fields["attrs"] = self.build_list(
                (
                    f"{wrap_quotes(str(key))}: {wrap_quotes(str(value))}"
                    for key, value in op.attrs.items()
                ),
                open_tok="{",
                close_tok="}",
            )
        return self.build_expr(op, "Function", **fields)

    def visit_call_(self, op: relax.Call) -> str:
        fields = {
            "op": self.visit_expr(op.op),
            "args": self.build_list(map(self.visit_expr, op.args)),
        }
        if op.sinfo_args:
            fields["sinfo_args"] = self.build_list(map(self.visit_struct_info_, op.sinfo_args))
        if op.attrs and self.include_call_attrs:

            def display_attrs(attr_key):
                attr_val = op.attrs[attr_key]
                # attrs can be strings but also other types;
                # we want to wrap strings in quotes
                # (__repr__ would work but it uses single quotes)
                attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val)
                return f"{wrap_quotes(attr_key)}: {attr_str}"

            fields["attrs"] = self.build_list(
                map(display_attrs, op.attrs.keys()),
                open_tok="{",
                close_tok="}",
            )
        return self.build_expr(op, "Call", **fields)

    def visit_seq_expr_(self, op: relax.SeqExpr) -> str:
        # visit_binding_block_ dispatches DataflowBlocks to visit_dataflow_block_,
        # so each block is printed with its real node name.
        return self.build_expr(
            op,
            "SeqExpr",
            blocks=self.build_list(map(self.visit_binding_block_, op.blocks)),
            body=self.visit_expr(op.body),
        )

    def visit_if_(self, op: relax.If) -> str:
        return self.build_expr(
            op,
            "If",
            cond=self.visit_expr(op.cond),
            true_branch=self.visit_expr(op.true_branch),
            false_branch=self.visit_expr(op.false_branch),
        )

    def visit_prim_value_(self, op: relax.PrimValue) -> str:
        return self.build_expr(op, "PrimValue", value=self.visit_prim_expr_(op.value))

    def visit_string_imm_(self, op: relax.StringImm) -> str:
        return self.build_expr(op, "StringImm", value=wrap_quotes(op.value))

    def visit_data_type_imm_(self, op: relax.DataTypeImm) -> str:
        # convert explicitly so every field handed to build_ast_node is a str
        return self.build_expr(op, "DataTypeImm", value=str(op.value))

    def visit_op_(self, op: tvm.ir.Op) -> str:
        # TODO: List other attributes?
        # op is not actually a Relax expr and does not have checked_type_
        # or struct_info fields, so we don't use build_expr here
        return self.build_ast_node("Op", name=wrap_quotes(op.name))

    def visit_prim_expr_(self, prim_expr: PrimExpr) -> str:
        # TODO: We may want to print PrimExpr ASTs, but this is a simplification for now
        return self.build_ast_node("PrimExpr", value=f"`{str(prim_expr)}`")

    def visit_tuple_getitem_(self, op: relax.TupleGetItem) -> str:
        return self.build_expr(
            op,
            "TupleGetItem",
            tuple_value=self.visit_expr(op.tuple_value),
            index=str(op.index),
        )

    def visit_type_(self, type_node: relax.Type) -> str:
        """
        Recurse down types and print their ASTs too.

        Raises
        ------
        ValueError
            If ``type_node`` is not one of the handled type classes.
        """
        if isinstance(type_node, relax.ShapeType):
            return self.build_ast_node("ShapeType", ndim=str(type_node.ndim))
        if isinstance(type_node, relax.ObjectType):
            return self.build_ast_node("ObjectType")
        if isinstance(type_node, relax.PackedFuncType):
            return self.build_ast_node("PackedFuncType")
        if isinstance(type_node, tvm.ir.PrimType):
            return self.build_ast_node("PrimType", dtype=type_node.dtype)
        if isinstance(type_node, relax.DynTensorType):
            fields = {}
            if type_node.ndim is not None:
                fields["ndim"] = str(type_node.ndim)
            if type_node.dtype != "":
                fields["dtype"] = type_node.dtype
            return self.build_ast_node("DynTensorType", **fields)
        if isinstance(type_node, relax.TupleType):
            return self.build_ast_node(
                "TupleType", fields=self.build_list(map(self.visit_type_, type_node.fields))
            )
        if isinstance(type_node, relax.FuncType):
            return self.build_ast_node(
                "FuncType",
                arg_types=self.build_list(map(self.visit_type_, type_node.arg_types)),
                ret_type=self.visit_type_(type_node.ret_type),
                # TODO: skipping type params and type constraints
            )
        raise ValueError(f"Invalid Relax Type {type_node} ({type(type_node)})")

    def visit_struct_info_(self, struct_info_node: relax.StructInfo) -> str:
        """
        Recurse down struct info and print their ASTs too.

        Raises
        ------
        ValueError
            If ``struct_info_node`` is not one of the handled StructInfo classes.
        """
        if isinstance(struct_info_node, relax.ShapeStructInfo):
            fields = {"ndim": str(struct_info_node.ndim)}
            if struct_info_node.values is not None:
                fields["values"] = self.build_list(
                    map(self.visit_prim_expr_, struct_info_node.values)
                )
            return self.build_ast_node("ShapeStructInfo", **fields)
        if isinstance(struct_info_node, relax.ObjectStructInfo):
            return self.build_ast_node("ObjectStructInfo")
        if isinstance(struct_info_node, relax.PrimStructInfo):
            return self.build_ast_node("PrimStructInfo", dtype=struct_info_node.dtype)
        if isinstance(struct_info_node, relax.TensorStructInfo):
            fields = {"dtype": struct_info_node.dtype}
            # print the shape expression when known, otherwise only the rank
            if struct_info_node.shape is not None:
                fields["shape"] = self.visit_expr(struct_info_node.shape)
            else:
                fields["ndim"] = str(struct_info_node.ndim)
            return self.build_ast_node("TensorStructInfo", **fields)
        if isinstance(struct_info_node, relax.TupleStructInfo):
            return self.build_ast_node(
                "TupleStructInfo",
                fields=self.build_list(map(self.visit_struct_info_, struct_info_node.fields)),
            )
        if isinstance(struct_info_node, relax.FuncStructInfo):
            fields = {}
            if struct_info_node.params is not None:
                fields["params"] = self.build_list(
                    map(self.visit_struct_info_, struct_info_node.params)
                )
            fields["ret"] = self.visit_struct_info_(struct_info_node.ret)
            return self.build_ast_node("FuncStructInfo", **fields)
        raise ValueError(
            f"Invalid Relax StructInfo {struct_info_node} ({type(struct_info_node)})"
        )

    def visit_binding_block_(self, block: relax.BindingBlock) -> str:
        """
        Recurse down binding blocks.

        Dataflow blocks are delegated to `visit_dataflow_block_`, so callers
        that iterate over a mix of block kinds (e.g. `visit_seq_expr_`) print
        each block under its real node name.
        """
        if isinstance(block, relax.DataflowBlock):
            return self.visit_dataflow_block_(block)
        return self.build_ast_node(
            "BindingBlock",
            bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True),
        )

    def visit_dataflow_block_(self, block: relax.DataflowBlock) -> str:
        """
        Recurse down a dataflow block.
        """
        return self.build_ast_node(
            "DataflowBlock",
            bindings=self.build_list(map(self.visit_binding_, block.bindings), force_newline=True),
        )

    def visit_binding_(self, binding: relax.Binding) -> str:
        """
        Distinguish between binding types.

        Raises
        ------
        ValueError
            If ``binding`` is neither a MatchCast nor a VarBinding.
        """
        if isinstance(binding, relax.MatchCast):
            return self.visit_match_cast_(binding)
        if isinstance(binding, relax.VarBinding):
            return self.visit_var_binding_(binding)
        raise ValueError(f"Invalid binding type in {binding}: {type(binding)}")

    def visit_match_cast_(self, match_cast: relax.MatchCast) -> str:
        """
        Handle match_cast bindings (var, value and the struct info matched against).
        """
        fields = {
            "var": self.visit_expr(match_cast.var),
            "value": self.visit_expr(match_cast.value),
            "struct_info": self.visit_struct_info_(match_cast.struct_info),
        }
        return self.build_ast_node("MatchCast", **fields)

    def visit_var_binding_(self, var_binding: relax.VarBinding) -> str:
        """
        Handle ordinary var bindings.
        """
        return self.build_ast_node(
            "VarBinding",
            var=self.visit_expr(var_binding.var),
            value=self.visit_expr(var_binding.value),
        )
+
+
def dump_ast(
    exp: relax.Expr,
    indent_str=" ",
    include_struct_info_annotations=True,
    include_type_annotations=False,
    include_call_attrs=True,
) -> str:
    """
    Render ``exp`` as a textual AST dump using `ASTPrinter`.

    Parameters
    ----------
    exp : relax.Expr
        The node to print (bindings and binding blocks are also accepted).
    indent_str : str
        Prefix used for each level of indentation.
    include_struct_info_annotations : bool
        Whether to print ``struct_info`` fields.
    include_type_annotations : bool
        Whether to print ``checked_type_`` fields.
    include_call_attrs : bool
        Whether to print ``Call`` attributes.

    Returns
    -------
    str
        The rendered AST.
    """
    options = {
        "indent_str": indent_str,
        "include_struct_info_annotations": include_struct_info_annotations,
        "include_type_annotations": include_type_annotations,
        "include_call_attrs": include_call_attrs,
    }
    return ASTPrinter(**options).visit_expr(exp)
diff --git a/tests/python/relax/test_ast_printer.py
b/tests/python/relax/test_ast_printer.py
new file mode 100644
index 0000000000..ba3c930a45
--- /dev/null
+++ b/tests/python/relax/test_ast_printer.py
@@ -0,0 +1,636 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import re
+from functools import partial
+from typing import Dict
+
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import relax as rx
+from tvm import tir
+from tvm.relax.testing import dump_ast
+from tvm.relax.testing.ast_printer import ASTPrinter
+from tvm.script import relax as R
+from tvm.script import tir as T
+
# Re-bind dump_ast so that, by default, every call in this file prints both
# struct info and checked_type_ annotations; individual tests override these.
dump_ast = partial(
    dump_ast,
    include_struct_info_annotations=True,
    include_type_annotations=True,
)
+
+
def strip_whitespace(text: str) -> str:
    """
    Delete every whitespace character from ``text`` so comparisons need not
    reason about newlines and indentation.
    """
    whitespace = re.compile(r"\s")
    return whitespace.sub("", text)
+
+
def normalize(func: rx.Function) -> rx.Function:
    """
    Run ``func`` through a do-nothing mutator so that the checked_type_ and
    struct_info fields are filled in everywhere.
    """

    # A mutator that overrides nothing: visiting with it goes through the
    # BlockBuilder's normalizer, which (oddly) differs from the Normalize pass.
    @rx.expr_functor.mutator
    class DefaultMutator(rx.PyExprMutator):
        pass

    module = tvm.IRModule()
    module["main"] = func
    normalized = DefaultMutator(module).visit_expr(func)
    module["main"] = normalized
    return module["main"]
+
+
def assert_fields(nodename: str, fields: Dict[str, str], target: str) -> None:
    """
    Given a target string, ensure that the string defines the specified node
    and that the given mappings of fields to values are present in the string.
    Strips all whitespace in the target and fields.
    Does not assume any particular ordering for the fields.

    On failure the assertion message names the missing node or field and
    shows the (whitespace-stripped) target, so mismatches can be diagnosed
    without re-running the test under a debugger.
    """
    stripped_target = strip_whitespace(target)
    assert stripped_target.startswith(
        f"{nodename}("
    ), f"expected a {nodename} node, got: {stripped_target}"
    for field, value in fields.items():
        expected = f"{field}={strip_whitespace(value)}"
        assert expected in stripped_target, f"missing {expected!r} in: {stripped_target}"
+
+
# Test cases are mostly adapted from test_expr and only check very basic
# properties.
+
+
def test_var() -> None:
    """A Var prints its name hint, plus annotations only when requested."""
    plain_var = rx.Var("v0")
    assert dump_ast(plain_var) == 'Var(name_hint="v0")'

    annotated_var = rx.Var("v1", R.Tensor([54, 96], "float32"))
    bare = dump_ast(
        annotated_var, include_struct_info_annotations=False, include_type_annotations=False
    )
    assert bare == 'Var(name_hint="v1")'

    full = dump_ast(annotated_var)
    assert full != bare
    for fragment in ("PrimExpr", "struct_info", "checked_type_"):
        assert fragment in full
+
+
def test_dataflow_var() -> None:
    """A DataflowVar prints like a Var but under its own node name."""
    plain_var = rx.DataflowVar("v0")
    assert dump_ast(plain_var) == 'DataflowVar(name_hint="v0")'

    annotated_var = rx.DataflowVar("v1", R.Tensor([54, 96], "float16"))
    bare = dump_ast(
        annotated_var, include_struct_info_annotations=False, include_type_annotations=False
    )
    assert bare == 'DataflowVar(name_hint="v1")'

    full = dump_ast(annotated_var)
    assert full != bare
    for fragment in ("PrimExpr", "struct_info", "checked_type_"):
        assert fragment in full
+
+
def test_match_cast() -> None:
    """MatchCast bindings print their var, value and target struct info."""
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")

    # match_cast([16, 8], [m, n])
    const_shape = rx.const([16, 8], "int32")
    shape_var = rx.Var("v0", R.Shape())
    first_cast = rx.MatchCast(shape_var, const_shape, R.Tensor([m, n], "int32"))
    first_text = dump_ast(first_cast)
    assert first_text.startswith("MatchCast(")
    for fragment in ("Constant", "PrimExpr(value=`m", "PrimExpr(value=`n", "16", "8"):
        assert fragment in first_text
    assert first_text != dump_ast(first_cast, include_type_annotations=False)

    # var1: Tensor((m, n), "float32") =
    #   match_cast(var0: R.Tensor("float32"), [m, n])
    source_var = rx.Var("value", R.Tensor("float32"))
    target_var = rx.Var("v1", R.Tensor([m, n], "float32"))
    second_cast = rx.MatchCast(target_var, source_var, R.Tensor([m, n], "float32"))
    second_text = dump_ast(second_cast)
    assert second_text.startswith("MatchCast(")
    for fragment in ("PrimExpr(value=`m", "PrimExpr(value=`n"):
        assert fragment in second_text
    assert second_text != dump_ast(
        second_cast, include_type_annotations=False, include_struct_info_annotations=False
    )
+
+
def test_var_binding() -> None:
    """A VarBinding prints its bound var and its value."""
    target = rx.Var("v0")
    random_const = rx.const(np.random.rand(24, 56))
    binding = rx.VarBinding(target, random_const)
    text = dump_ast(
        binding, include_type_annotations=False, include_struct_info_annotations=False
    )
    assert text.startswith("VarBinding(")
    for fragment in ('var=Var(name_hint="v0")', "value=", "Constant("):
        assert fragment in text
+
+
def test_binding_block() -> None:
    """A BindingBlock prints every binding it contains."""
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    const_shape = rx.const([16, 8], "int32")
    cast = rx.MatchCast(rx.Var("v0"), const_shape, R.Tensor([m, n], "int32"))
    bound = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))

    block_text = dump_ast(rx.BindingBlock([cast, bound]))
    assert block_text.startswith("BindingBlock(")
    for fragment in ("bindings=", "VarBinding(", "MatchCast(", '"v0"'):
        assert fragment in block_text
+
+
def test_dataflow_block() -> None:
    """A DataflowBlock prints every binding it contains under its own name."""
    m = tir.Var("m", dtype="int64")
    n = tir.Var("n", dtype="int64")
    const_shape = rx.const([16, 8], "int32")
    cast = rx.MatchCast(rx.Var("v0"), const_shape, R.Tensor([m, n], "int32"))
    bound = rx.VarBinding(rx.Var("v0"), rx.const(np.random.rand(24, 56)))

    block_text = dump_ast(rx.DataflowBlock([cast, bound]))
    assert block_text.startswith("DataflowBlock(")
    for fragment in ("bindings=", "VarBinding(", "MatchCast(", '"v0"'):
        assert fragment in block_text
+
+
def test_seq_expr() -> None:
    """A SeqExpr prints its blocks and its body."""
    foo = rx.Var("foo")
    block = rx.BindingBlock([rx.VarBinding(foo, rx.const(1))])
    seq_text = dump_ast(rx.SeqExpr([block], foo))
    assert seq_text.startswith("SeqExpr(")
    for fragment in (
        "blocks=",
        "BindingBlock(",
        "VarBinding(",
        "Constant(",
        'var=Var(name_hint="foo")',
    ):
        assert fragment in seq_text
    assert "value=Constant(data" in strip_whitespace(seq_text)
    assert "body=" in seq_text
+
+
def test_shape_expr() -> None:
    """A ShapeExpr over symbolic int32 vars prints each dimension as a PrimExpr."""
    dims = [tir.Var("m", dtype="int32"), tir.Var("n", dtype="int32")]
    shape_text = dump_ast(rx.ShapeExpr(dims))
    assert shape_text.startswith("ShapeExpr(")
    for fragment in ("values=", "PrimExpr(value=`m: int32`)", "PrimExpr(value=`n: int32`)"):
        assert fragment in shape_text
+
+
def test_func():
    """A Function prints params, body, return struct info and attributes."""
    foo = rx.Var("foo", R.Tensor("float32", ndim=2))
    body = rx.SeqExpr([rx.BindingBlock([rx.VarBinding(foo, rx.const(1))])], foo)
    func = rx.Function([foo], body, R.Tensor("float32")).with_attr("global_symbol", "func")

    func_text = dump_ast(func)
    assert func_text.startswith("Function(")
    expected_fragments = (
        "params=",
        "body=",
        "ret_struct_info=",
        "attrs=",
        '"global_symbol": "func"',
        "SeqExpr(",
        "blocks=",
        "VarBinding(",
    )
    for fragment in expected_fragments:
        assert fragment in func_text
    assert func_text != dump_ast(func, include_type_annotations=False)
+
+
def test_shape_of():
    """get_shape_of yields a shape_of Call or a ShapeExpr depending on the input."""
    # unknown shape (only ndim given): expect a Call to the relax.shape_of op
    unknown = rx.Var("v0", R.Tensor(ndim=2))
    unknown_text = dump_ast(rx.get_shape_of(unknown))
    assert unknown_text.startswith("Call(")
    for fragment in ('op=Op(name="relax.shape_of")', "args=", 'name_hint="v0"'):
        assert fragment in unknown_text

    # static shape: expect a ShapeExpr holding the dimensions
    static = rx.Var("v1", R.Tensor([96, 54]))
    static_text = dump_ast(rx.get_shape_of(static))
    assert static_text.startswith("ShapeExpr("), static_text
    for fragment in (
        "values=",
        "PrimExpr(value=`T.int64(96)`)",
        "PrimExpr(value=`T.int64(54)`)",
    ):
        assert fragment in static_text
+
+
def test_shape_expr_int_dims():
    """
    A ShapeExpr built from plain Python ints prints each dimension as an int64 PrimExpr.

    This test must not share the name ``test_shape_expr`` with the symbolic-dimension
    test defined earlier in this module: a second ``def`` with the same name rebinds
    it at import time, so pytest would only collect one of the two tests.
    """
    shape_expr = rx.ShapeExpr([10, 20])
    shape_expr_str = dump_ast(shape_expr)
    assert shape_expr_str.startswith("ShapeExpr(")
    assert "values" in shape_expr_str
    assert "PrimExpr(value=`T.int64(10)`)" in shape_expr_str
    assert "PrimExpr(value=`T.int64(20)`)" in shape_expr_str
+
+
def test_types():
    """Relax types print as simple AST nodes."""
    printer = ASTPrinter()

    def render(type_node):
        return strip_whitespace(printer.visit_type_(type_node))

    object_type = rx.ObjectType()
    tensor_type = rx.DynTensorType(ndim=2, dtype="int32")
    unit_type = rx.TupleType([])

    simple_cases = [
        (rx.ShapeType(), "ShapeType(ndim=-1)"),
        (rx.ShapeType(ndim=1), "ShapeType(ndim=1)"),
        (object_type, "ObjectType()"),
        (rx.PackedFuncType(), "PackedFuncType()"),
        (tensor_type, "DynTensorType(ndim=2,dtype=int32)"),
        (unit_type, "TupleType(fields=[])"),
    ]
    for type_node, expected in simple_cases:
        assert render(type_node) == expected

    tuple_type = rx.TupleType([rx.ShapeType(), object_type])
    assert_fields(
        "TupleType",
        {"fields": "[ShapeType(ndim=-1),ObjectType()]"},
        render(tuple_type),
    )

    func_type = rx.FuncType([tensor_type], unit_type)
    assert_fields(
        "FuncType",
        {
            "arg_types": "[DynTensorType(ndim=2, dtype=int32)]",
            "ret_type": "TupleType(fields=[])",
        },
        printer.visit_type_(func_type),
    )
+
+
def test_struct_info():
    """StructInfo nodes print as AST nodes, recursing into nested exprs."""
    printer = ASTPrinter(include_type_annotations=True)

    def stripped(sinfo):
        return strip_whitespace(printer.visit_struct_info_(sinfo))

    empty_ssi = rx.ShapeStructInfo()

    # single-line renderings are compared verbatim
    verbatim_cases = [
        (rx.ObjectStructInfo(), "ObjectStructInfo()"),
        (rx.PrimStructInfo("int32"), "PrimStructInfo(dtype=int32)"),
        (empty_ssi, "ShapeStructInfo(ndim=-1)"),  # empty shape
        (rx.TupleStructInfo([]), "TupleStructInfo(fields=[])"),
    ]
    for sinfo, expected in verbatim_cases:
        assert printer.visit_struct_info_(sinfo) == expected

    # include some dimensions
    shape_info = rx.ShapeStructInfo([tir.IntImm("int64", 1), tir.IntImm("int64", 2)])
    assert stripped(shape_info) == strip_whitespace(
        """
        ShapeStructInfo(
            ndim=2,
            values=[
                PrimExpr(value=`T.int64(1)`),
                PrimExpr(value=`T.int64(2)`)
            ]
        )
        """
    )

    # tensor struct info
    assert stripped(rx.TensorStructInfo()) == "TensorStructInfo(dtype=float32,ndim=-1)"

    # use a var as the shape
    shape_var = rx.Var("x", struct_info=rx.ShapeStructInfo(values=[]))
    var_tsi = rx.TensorStructInfo(shape=shape_var, dtype="int32")
    assert stripped(var_tsi) == strip_whitespace(
        """
        TensorStructInfo(
            dtype=int32,
            shape=Var(
                name_hint="x",
                struct_info=ShapeStructInfo(ndim=0, values=[]),
                checked_type_=ShapeType(ndim=0)
            )
        )
        """
    )

    tuple_of_shape = rx.TupleStructInfo([empty_ssi])
    assert stripped(tuple_of_shape) == strip_whitespace(
        """
        TupleStructInfo(fields=[
            ShapeStructInfo(ndim=-1)
        ])
        """
    )

    simple_func = rx.FuncStructInfo([], rx.ObjectStructInfo())
    assert stripped(simple_func) == "FuncStructInfo(params=[],ret=ObjectStructInfo())"
+
+
def test_call_packed():
    """call_packed becomes a Call to an ExternFunc; op calls print as Op."""

    # test case from test_parser
    @R.function
    def f(
        x: R.Tensor((32, "m"), "float32"),
        y: R.Tensor(("m",), "float32"),
        r: R.Tensor(dtype="int64"),
    ) -> R.Object:
        m = T.var("int64")
        z: R.Tensor((32, m), "float32") = R.multiply(x, y)
        w: R.Tensor = R.multiply(z, z)
        q: R.Tensor(ndim=2) = R.add(w, w)
        t = R.add(w, z)
        sh: R.Shape = R.shape_of(t)
        o: R.Object = R.call_packed(
            "contrib.tensor_array_stack", x, y, sinfo_args=R.Object(), test_attr=True
        )
        return o

    # every dump in this test omits annotations but keeps call attributes
    dump_plain = partial(
        dump_ast,
        include_type_annotations=False,
        include_struct_info_annotations=False,
        include_call_attrs=True,
    )

    # checking that the call_packed call is turned into a call to an extern func
    f_text = strip_whitespace(dump_plain(f))

    # the function has an annotated return type
    assert "ret_struct_info=ObjectStructInfo()" in f_text

    assert isinstance(f.body, rx.SeqExpr)
    bindings = f.body.blocks[0].bindings

    extern_call_text = dump_plain(bindings[-1].value)
    assert strip_whitespace(extern_call_text) in f_text
    assert_fields(
        "Call",
        {
            "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")',
            "args": '[Var(name_hint="x"), Var(name_hint="y")]',
            "sinfo_args": "[ObjectStructInfo()]",
            "attrs": '{"test_attr": 1}',
        },
        extern_call_text,
    )

    # check that the op call is there too
    op_call_text = dump_plain(bindings[0].value)
    assert strip_whitespace(op_call_text) in f_text
    assert_fields(
        "Call",
        {
            "op": 'Op(name="relax.multiply")',
            "args": '[Var(name_hint="x"), Var(name_hint="y")]',
        },
        op_call_text,
    )

    # TODO: add testcase for op attrs
+
+
def test_call_tir():
    """call_tir is an Op whose first argument is an ExternFunc."""

    # also from test_parser
    @R.function
    def foo(x: R.Tensor(("m", "n"), "float32")):
        m, n = T.var("int64"), T.var("int64")
        gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32"))
        return gv0

    # every dump in this test omits annotations and call attributes
    dump_plain = partial(
        dump_ast,
        include_type_annotations=False,
        include_struct_info_annotations=False,
        include_call_attrs=False,
    )

    foo_text = strip_whitespace(dump_plain(foo))
    assert foo_text.startswith('Function(params=[Var(name_hint="x")]')

    # call_tir is an op in Relax and it takes an extern func as an argument
    assert isinstance(foo.body, rx.SeqExpr)
    tir_call_text = dump_plain(foo.body.blocks[0].bindings[0].value)
    assert_fields(
        "Call",
        {
            "op": 'Op(name="relax.call_tir")',
            "args": """[
                ExternFunc(global_symbol="test.op.identity"),
                Tuple(fields=[Var(name_hint="x")])
            ]""",
            "sinfo_args": """[
                TensorStructInfo(
                    dtype=float32,
                    shape=ShapeExpr(
                        values=[
                            PrimExpr(value=`m`),
                            PrimExpr(value=`n`)
                        ]
                    )
                )
            ]""",
        },
        tir_call_text,
    )
    assert strip_whitespace(tir_call_text) in foo_text
+
+
def test_operators():
    """Operator calls print their Op and their non-tensor arguments."""

    @R.function
    def foo(x: R.Tensor):
        return R.unique(x, sorted=True, axis=-1)

    @R.function
    def bar(x: R.Tensor):
        return R.print(x, format="{}")

    def dump_stripped(func):
        return strip_whitespace(
            dump_ast(func, include_type_annotations=False, include_struct_info_annotations=False)
        )

    foo_text = dump_stripped(foo)
    assert 'Op(name="relax.unique")' in foo_text
    # the sorted argument is true, so it will be a PrimValue of 1
    assert "PrimExpr(value=`T.int64(1)`)" in foo_text
    # axis is -1
    assert "PrimExpr(value=`T.int64(-1)`)" in foo_text

    # the format string is a StringImm argument
    assert 'StringImm(value="{}")' in dump_stripped(bar)
+
+
def test_print_struct_info_annotation_non_var():
    """Non-Var expressions (here a Constant) also print their struct_info."""

    @R.function
    def f() -> R.Tensor:
        return R.const([1, 2])

    normalized_body = normalize(f).body
    # the constant has a shape of (2,)
    expected = strip_whitespace(
        """
        struct_info=TensorStructInfo(
            dtype=int32,
            shape=ShapeExpr(
                values=[PrimExpr(value=`T.int64(2)`)],
                struct_info=ShapeStructInfo(
                    ndim=1,
                    values=[PrimExpr(value=`T.int64(2)`)]
                ),
                checked_type_=ShapeType(ndim=1)
            )
        )
        """
    )
    assert expected in strip_whitespace(dump_ast(normalized_body))
+
+
def test_print_type_annotation_non_var():
    """Non-Var expressions (a Constant and a Call) also print checked_type_."""

    @R.function
    def f() -> R.Shape:
        return R.shape_of(R.const(1))

    body = normalize(f).body
    assert isinstance(body, rx.SeqExpr)
    shape_of_call = body.blocks[-1].bindings[-1].value
    assert isinstance(shape_of_call, rx.Call)

    # the constant should have a tensor type
    const_text = strip_whitespace(dump_ast(shape_of_call.args[0]))
    assert "checked_type_=DynTensorType(ndim=0" in const_text

    # we expect the shape_of call to have a checked_type_ of ShapeType
    call_text = strip_whitespace(dump_ast(shape_of_call))
    assert "checked_type_=ShapeType(ndim=-1)" in call_text
+
+
def test_if():
    """An If prints its condition and both branches."""

    @R.function
    def f(cond: R.Tensor((), dtype="bool")) -> R.Tensor((), dtype="int32"):
        if cond:
            x = R.const(1)
        else:
            x = R.const(2)
        return x

    body = normalize(f).body
    assert isinstance(body, rx.SeqExpr)
    dumped = strip_whitespace(dump_ast(body))
    # we expect both branches to be seq exprs
    for fragment in ("If", "true_branch=SeqExpr(", "false_branch=SeqExpr("):
        assert fragment in dumped
+
+
def test_tuple_get_item():
    """TupleGetItem prints the tuple expression and the integer index."""

    @R.function
    def f(x: R.Tuple(R.Tensor((), dtype="int32"))) -> R.Tensor((), dtype="int32"):
        return x[0]

    body = normalize(f).body
    assert isinstance(body, rx.SeqExpr)
    dumped = strip_whitespace(dump_ast(body))
    for fragment in ("TupleGetItem", 'tuple_value=Var(name_hint="x"', "index=0"):
        assert fragment in dumped
+
+
def test_prim_value():
    """A PrimValue prints its PrimExpr plus struct info and type."""
    expected = strip_whitespace(
        """
        PrimValue(
            value=PrimExpr(value=`T.int64(1)`),
            struct_info=PrimStructInfo(dtype=int64),
            checked_type_=PrimType(dtype=int64)
        )
        """
    )
    prim_value = rx.PrimValue(tir.IntImm("int64", 1))
    assert strip_whitespace(dump_ast(prim_value)) == expected
+
+
def test_string_imm():
    """A StringImm prints its quoted value plus struct info and type."""
    expected = strip_whitespace(
        """
        StringImm(
            value="test",
            struct_info=ObjectStructInfo(),
            checked_type_=ObjectType()
        )
        """
    )
    assert strip_whitespace(dump_ast(rx.StringImm("test"))) == expected
+
+
def test_datatype_imm():
    """A DataTypeImm prints its dtype plus struct info and type."""
    expected = strip_whitespace(
        """
        DataTypeImm(
            value=int32,
            struct_info=ObjectStructInfo(),
            checked_type_=ObjectType()
        )
        """
    )
    assert strip_whitespace(dump_ast(rx.DataTypeImm("int32"))) == expected
+
+
if __name__ == "__main__":
    # Entry point when this file is executed directly rather than collected by
    # a test runner; tvm.testing.main() presumably runs this module's tests.
    tvm.testing.main()