This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a83f11dde8 [Unity] Relax TVMScript Printer (#13944)
a83f11dde8 is described below
commit a83f11dde88163edbce7cb8e847acc5589a40f52
Author: Junru Shao <[email protected]>
AuthorDate: Fri Feb 10 07:41:11 2023 -0800
[Unity] Relax TVMScript Printer (#13944)
This PR introduces Relax as a dialect supported by the TVMScript
Printer. Some caveats:
- Needs to be rebased onto mainline before merging.
- Some tests are skipped because some operators have not been upstreamed to
the unity branch yet.
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Yuchen Jin <[email protected]>
Co-authored-by: Steven S. Lyubomirsky <[email protected]>
Co-authored-by: Yong Wu <[email protected]>
Co-authored-by: Prakalp Srivastava <[email protected]>
Co-authored-by: Sunghyun Park <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
---
python/tvm/relax/expr.py | 53 +--
src/relax/ir/expr.cc | 45 --
src/relax/ir/struct_info.cc | 43 --
src/script/printer/relax/binding.cc | 87 ++++
src/script/printer/relax/call.cc | 212 +++++++++
src/script/printer/relax/expr.cc | 136 ++++++
src/script/printer/relax/function.cc | 78 ++++
src/script/printer/relax/region.cc | 100 +++++
src/script/printer/relax/struct_info.cc | 149 +++++++
src/script/printer/relax/tir.cc | 89 ++++
src/script/printer/relax/type.cc | 89 ++++
src/script/printer/relax/utils.h | 101 +++++
tests/python/relax/test_tvmscript_printer_relax.py | 489 +++++++++++++++++++++
13 files changed, 1542 insertions(+), 129 deletions(-)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 138724ed06..f1cf815d8e 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -22,16 +22,18 @@ from numbers import Number
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as _np # type: ignore
+
import tvm
import tvm._ffi
-import tvm.relax
import tvm.ir
+import tvm.relax
from tvm import DataType
from tvm._ffi import base as _base
-from tvm.runtime import ndarray as _nd, Object
+from tvm.runtime import Object
+from tvm.runtime import ndarray as _nd
from ..ir import BaseFunc, Node, SourceName, Span
-from ..runtime import String
+from ..runtime import Scriptable, String
from ..tir import PrimExpr
from . import _ffi_api
@@ -55,7 +57,7 @@ class Id(Object):
# NOTE: place base struct info in expr to avoid cyclic dep
# from expr to struct info.
-class StructInfo(Node):
+class StructInfo(Node, Scriptable):
"""The base class of all StructInfo.
StructInfo contains both the static type
@@ -110,7 +112,7 @@ def _binary_rhs_helper(rhs: "ExprWithOp") -> "ExprWithOp":
raise TypeError(f"type {type(rhs)} not supported")
-class ExprWithOp(Expr):
+class ExprWithOp(Expr, Scriptable):
"""Basetype of all relax expressions that defines op overloading."""
def astype(self, dtype: Union[str, DataType]) -> "ExprWithOp":
@@ -436,7 +438,7 @@ class DataflowVar(Var):
@tvm._ffi.register_object("relax.expr.PrimValue")
-class PrimValue(Expr):
+class PrimValue(Expr, Scriptable):
"""The prim expr representing the value."""
value: PrimExpr
@@ -448,7 +450,7 @@ class PrimValue(Expr):
@tvm._ffi.register_object("relax.expr.StringImm")
-class StringImm(Expr):
+class StringImm(Expr, Scriptable):
"""Represent a string literal constant."""
value: str
@@ -458,7 +460,7 @@ class StringImm(Expr):
@tvm._ffi.register_object("relax.expr.DataTypeImm")
-class DataTypeImm(Expr):
+class DataTypeImm(Expr, Scriptable):
"""Represent a data type constant."""
value: DataType
@@ -468,11 +470,9 @@ class DataTypeImm(Expr):
@tvm._ffi.register_object("relax.expr.Binding")
-class Binding(Node):
+class Binding(Node, Scriptable):
"""The base class of a binding in Relax."""
- ...
-
@tvm._ffi.register_object("relax.expr.MatchCast")
class MatchCast(Binding):
@@ -548,7 +548,7 @@ class SeqExpr(ExprWithOp):
@tvm._ffi.register_object("relax.expr.Function")
-class Function(BaseFunc):
+class Function(BaseFunc, Scriptable):
"""A Relax function."""
params: List[Var]
@@ -588,35 +588,6 @@ class Function(BaseFunc):
"""
return Call(self, args, None, None)
- def script(self, show_meta: bool = False) -> str:
- """Print relax.Function into TVMScript
-
- Parameters
- ----------
- show_meta : bool
- Whether to show meta information
-
- Returns
- -------
- script : str
- The TVM Script of the relax.Function
- """
- return tvm._ffi.get_global_func("script.AsRelaxScript")(self,
show_meta) # type: ignore
-
- def show(self, style: str = "light") -> None:
- """
- A sugar for print highlighted TVM script.
-
- Parameters
- ----------
- style : str, optional
- Pygments styles extended by "light" (default) and "dark", by
default "light"
- """
- from tvm.script.highlight import cprint # pylint:
disable=import-outside-toplevel
-
- # Use deferred import to avoid circular import while keeping cprint
under tvm/script
- cprint(self, style=style)
-
@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 45868a488a..a0aaea886d 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -94,13 +94,6 @@ TVM_REGISTER_GLOBAL("relax.Call")
.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs,
Array<StructInfo> sinfo_args,
Span span) { return Call(op, args, attrs, sinfo_args,
span); });
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const CallNode*>(ref.get());
- p->stream << "CallNode(" << node->op << ", " << node->args << ", " <<
node->attrs << ", "
- << node->sinfo_args << ")";
- });
-
If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
ObjectPtr<IfNode> n = make_object<IfNode>();
n->cond = std::move(cond);
@@ -137,13 +130,6 @@ TVM_REGISTER_GLOBAL("relax.If")
return If(cond, true_branch, false_branch, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const IfNode*>(ref.get());
- p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", "
- << node->false_branch << ")";
- });
-
Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
@@ -179,12 +165,6 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>>
opt_fields, Optional<Span> o
return tuple;
}
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TupleNode*>(ref.get());
- p->stream << "Tuple(" << node->fields << ")";
- });
-
TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) {
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
n->tuple = std::move(tuple);
@@ -216,12 +196,6 @@
TVM_REGISTER_GLOBAL("relax.TupleGetItem").set_body_typed([](Expr tuple, int inde
return TupleGetItem(tuple, index);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TupleGetItemNode*>(ref.get());
- p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index
<< ")";
- });
-
TVM_REGISTER_NODE_TYPE(ShapeExprNode);
ShapeExpr::ShapeExpr(Array<PrimExpr> values, Span span) {
@@ -245,19 +219,6 @@
TVM_REGISTER_GLOBAL("relax.ShapeExpr").set_body_typed([](Array<PrimExpr> values,
return ShapeExpr(values, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ShapeExprNode>([](const ObjectRef& ref, ReprPrinter* p) {
- const ShapeExprNode* node = static_cast<const ShapeExprNode*>(ref.get());
- p->stream << "ShapeExpr(";
- for (auto it = node->values.begin(); it != node->values.end(); it++) {
- if (it != node->values.begin()) {
- p->stream << ", ";
- }
- p->stream << *it;
- }
- p->stream << ")";
- });
-
TVM_REGISTER_NODE_TYPE(VarNode);
Var::Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span) {
@@ -572,12 +533,6 @@
TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol,
return ExternFunc(global_symbol, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ExternFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
- const auto* node = static_cast<const ExternFuncNode*>(ref.get());
- p->stream << "ExternFunc(\"" << node->global_symbol << "\")";
- });
-
Expr GetShapeOf(const Expr& expr) {
// default case, to be normalized.
ICHECK(expr->struct_info_.defined()) << "GetShapeOf can only be applied to
normalized expr";
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index 9db7cea672..4004ad28d5 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -41,11 +41,6 @@
TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) {
return ObjectStructInfo(span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ObjectStructInfoNode>([](const ObjectRef& ref, ReprPrinter*
p) {
- p->stream << "ObjectStructInfo()";
- });
-
// Prim
PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
@@ -60,12 +55,6 @@
TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Sp
return PrimStructInfo(dtype, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<PrimStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p)
{
- const auto* node = static_cast<const PrimStructInfoNode*>(ref.get());
- p->stream << "PrimStructInfo(" << node->dtype << ")";
- });
-
// Shape
ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
@@ -102,16 +91,6 @@ TVM_REGISTER_GLOBAL("relax.ShapeStructInfo")
}
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ShapeStructInfoNode>([](const ObjectRef& ref, ReprPrinter*
p) {
- const auto* node = static_cast<const ShapeStructInfoNode*>(ref.get());
- if (node->values.defined()) {
- p->stream << "ShapeStructInfo(" << node->values.value() << ")";
- } else {
- p->stream << "ShapeStructInfo(ndim=" << node->ndim << ")";
- }
- });
-
// Tensor
TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
@@ -150,16 +129,6 @@ TVM_REGISTER_GLOBAL("relax.TensorStructInfo")
}
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<TensorStructInfoNode>([](const ObjectRef& ref, ReprPrinter*
p) {
- const auto* node = static_cast<const TensorStructInfoNode*>(ref.get());
- if (node->shape.defined()) {
- p->stream << "TensorStructInfo(" << node->shape.value() << ", " <<
node->dtype << ")";
- } else {
- p->stream << "TensorStructInfo(" << node->dtype << ", ndim=" <<
node->ndim << ")";
- }
- });
-
// Tuple
TupleStructInfo::TupleStructInfo(Array<StructInfo> fields, Span span) {
ObjectPtr<TupleStructInfoNode> n = make_object<TupleStructInfoNode>();
@@ -175,12 +144,6 @@ TVM_REGISTER_GLOBAL("relax.TupleStructInfo")
return TupleStructInfo(fields, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<TupleStructInfoNode>([](const ObjectRef& ref, ReprPrinter*
p) {
- const auto* node = static_cast<const TupleStructInfoNode*>(ref.get());
- p->stream << "TupleStructInfo(" << node->fields << ")";
- });
-
// Func
FuncStructInfo::FuncStructInfo(Array<StructInfo> params, StructInfo ret, Span
span) {
ObjectPtr<FuncStructInfoNode> n = make_object<FuncStructInfoNode>();
@@ -223,12 +186,6 @@ TVM_REGISTER_GLOBAL("relax.FuncStructInfoOpaqueFunc")
}
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<FuncStructInfoNode>([](const ObjectRef& ref, ReprPrinter* p)
{
- const auto* node = static_cast<const FuncStructInfoNode*>(ref.get());
- p->stream << "FuncStructInfo(" << node->params << ", " << node->ret <<
")";
- });
-
// Helper functions
void UpdateStructInfo(Expr expr, StructInfo struct_info) {
ICHECK(!expr->struct_info_.defined())
diff --git a/src/script/printer/relax/binding.cc
b/src/script/printer/relax/binding.cc
new file mode 100644
index 0000000000..8a50fe9698
--- /dev/null
+++ b/src/script/printer/relax/binding.cc
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const
IRDocsifier& d, //
+ const Optional<ExprDoc>& var, const Optional<ExprDoc>& ann) {
+ using relax::SeqExpr;
+ ExprDoc cond = d->AsDoc<ExprDoc>(n->cond, n_p->Attr("cond"));
+ std::vector<Array<StmtDoc>> branches{
+ PrintSeqExpr(Downcast<SeqExpr>(n->true_branch),
n_p->Attr("true_branch"), d, false),
+ PrintSeqExpr(Downcast<SeqExpr>(n->false_branch),
n_p->Attr("false_branch"), d, false),
+ };
+ if (var.defined()) {
+ for (Array<StmtDoc>& stmts : branches) {
+ ExprDoc ret = Downcast<ExprStmtDoc>(stmts.back())->expr;
+ stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann));
+ }
+ }
+ return IfDoc(cond, branches[0], branches[1]);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::MatchCast>(
+ "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ using relax::StructInfo;
+ using relax::MatchStructInfo;
+ Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d,
n->value);
+ ExprDoc rhs = Relax(d, "match_cast")
+ ->Call({d->AsDoc<ExprDoc>(n->value,
n_p->Attr("value")),
+ d->AsDoc<ExprDoc>(n->struct_info,
n_p->Attr("struct_info_"))});
+ ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
+ return AssignDoc(lhs, rhs, ann);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::VarBinding>( //
+ "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ if (const auto if_ = n->value.as<relax::IfNode>()) {
+ Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"),
d, n->value);
+ ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
+ return PrintIfExpr(GetRef<relax::If>(if_), n_p->Attr("value"), d,
lhs, ann);
+ } else if (n->value->IsInstance<tvm::BaseFuncNode>()) {
+ IdDoc lhs = DefineVar(n->var, d->frames.back(), d);
+ d->cfg->binding_names.push_back(lhs->name);
+ Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
+ d->cfg->binding_names.pop_back();
+ return ret;
+ } else {
+ ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
+ Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"),
d, n->value);
+ ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
+ return AssignDoc(lhs, rhs, ann);
+ }
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::If>("", [](relax::If n, ObjectPath n_p, IRDocsifier
d) -> Doc {
+ return PrintIfExpr(n, n_p, d, NullOpt, NullOpt);
+ });
+
+TVM_SCRIPT_REPR(relax::MatchCastNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::VarBindingNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::IfNode, ReprPrintRelax);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
new file mode 100644
index 0000000000..2feb2082c5
--- /dev/null
+++ b/src/script/printer/relax/call.cc
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+class AttrPrinter : public tvm::AttrVisitor {
+ public:
+ explicit AttrPrinter(const ObjectPath& p, const IRDocsifier& d,
Array<String>* keys,
+ Array<ExprDoc>* values)
+ : p(p), d(d), keys(keys), values(values) {}
+
+ void Visit(const char* key, double* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::Float(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, int64_t* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::Int(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, uint64_t* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::Int(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, int* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::Int(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, bool* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::Boolean(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, std::string* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::Str(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, DataType* value) final {
+ keys->push_back(key);
+ values->push_back(LiteralDoc::DataType(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, runtime::ObjectRef* value) final {
+ keys->push_back(key);
+ values->push_back(d->AsDoc<ExprDoc>(*value, p->Attr(key)));
+ }
+
+ void Visit(const char* key, void** value) final {
+ LOG(FATAL) << "TypeError: void is not allowed in Attrs";
+ }
+
+ void Visit(const char* key, runtime::NDArray* value) final {
+ LOG(FATAL) << "TypeError: NDArray is not allowed in Attrs";
+ }
+
+ const ObjectPath& p;
+ const IRDocsifier& d;
+ Array<String>* keys;
+ Array<ExprDoc>* values;
+};
+
+ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const
IRDocsifier& d) {
+ // TODO(@junrushao): handle callee better
+ if (const auto* ext = n.as<relax::ExternFuncNode>()) {
+ return LiteralDoc::Str(ext->global_symbol, n_p);
+ } else if (const auto* gv = n.as<tvm::GlobalVarNode>()) {
+ IdDoc callee(gv->name_hint);
+ callee->source_paths.push_back(n_p);
+ return callee;
+ } else {
+ return d->AsDoc<ExprDoc>(n, n_p);
+ }
+}
+
+Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p,
const IRDocsifier& d) {
+ static const Op& call_tir_op = Op::Get("relax.call_tir");
+ if (!n->op.same_as(call_tir_op)) {
+ return NullOpt;
+ }
+ ICHECK(n->args.size() == 2 || n->args.size() == 3);
+ ICHECK(n->sinfo_args.size() == 1);
+ Array<ExprDoc> args;
+ Array<String> kwargs_keys;
+ Array<ExprDoc> kwargs_values;
+ // Step 1. Print n->args[0], the callee
+ args.push_back(PrintCallee(n->args[0], n_p->Attr("args")->ArrayIndex(0), d));
+ // Step 2. Print n->args[1], the input arguments
+ args.push_back(d->AsDoc<ExprDoc>(n->args[1],
n_p->Attr("args")->ArrayIndex(1)));
+ // Step 3. Print n->sinfo_args, the output struct info
+ relax::StructInfo o_sinfo = n->sinfo_args[0];
+ ObjectPath o_sinfo_p = n_p->Attr("sinfo_args")->ArrayIndex(0);
+ kwargs_keys.push_back("out_sinfo");
+ if (const auto* o = o_sinfo.as<relax::TupleStructInfoNode>()) {
+ Array<ExprDoc> fields;
+ ObjectPath fields_p = o_sinfo_p->Attr("fields");
+ for (int i = 0, l = o->fields.size(); i < l; ++i) {
+ fields.push_back(d->AsDoc<ExprDoc>(o->fields[i],
fields_p->ArrayIndex(i)));
+ }
+ kwargs_values.push_back(ListDoc(fields));
+ } else {
+ kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
+ }
+ // Step 4. Print n->args[2], the tir variables
+ if (n->args.size() == 3) {
+ kwargs_keys.push_back("tir_vars");
+ kwargs_values.push_back(d->AsDoc<ExprDoc>(n->args[2],
n_p->Attr("args")->ArrayIndex(2)));
+ }
+ return Relax(d, "call_tir")->Call(args, kwargs_keys, kwargs_values);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::Call>( //
+ "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ // Special case: call_tir
+ if (Optional<ExprDoc> doc = PrintCallTIR(n, n_p, d)) {
+ return doc.value();
+ }
+ ExprDoc prefix{nullptr};
+ Array<ExprDoc> args;
+ Array<String> kwargs_keys;
+ Array<ExprDoc> kwargs_values;
+ // Step 1. Print op
+ if (const auto* op = n->op.as<relax::ExternFuncNode>()) {
+ prefix = Relax(d, "call_packed");
+ args.push_back(LiteralDoc::Str(op->global_symbol,
n_p->Attr("op")));
+ } else if (const auto* op = n->op.as<tvm::GlobalVarNode>()) {
+ prefix = IdDoc(op->name_hint);
+ prefix->source_paths.push_back(n_p->Attr("op"));
+ } else if (const auto* op = n->op.as<tvm::OpNode>()) {
+ std::string name = op->name;
+ if (name.rfind("relax.", 0) == 0) {
+ prefix = Relax(d, name.substr(6));
+ } else {
+ prefix = IdDoc(name);
+ }
+ prefix->source_paths.push_back(n_p->Attr("op"));
+ } else if (n->op->IsInstance<relax::VarNode>()) {
+ prefix = d->AsDoc<ExprDoc>(n->op, n_p->Attr("op"));
+ } else {
+ LOG(FATAL) << "TypeError: Unsupported op: " << n->op->GetTypeKey();
+ }
+ // Step 2. Print args
+ if (!n->args.empty()) {
+ args.push_back(PrintCallee(n->args[0],
n_p->Attr("args")->ArrayIndex(0), d));
+ }
+ for (int i = 1, l = n->args.size(); i < l; ++i) {
+ args.push_back(d->AsDoc<ExprDoc>(n->args[i],
n_p->Attr("args")->ArrayIndex(i)));
+ }
+ // Step 3. Print attrs
+ if (n->attrs.defined()) {
+ if (n->op->IsInstance<relax::ExternFuncNode>()) {
+ kwargs_keys.push_back("attrs_type_key");
+ kwargs_values.push_back(LiteralDoc::Str(n->attrs->GetTypeKey(),
n_p->Attr("attrs")));
+ }
+ if (const auto* attrs = n->attrs.as<tvm::DictAttrsNode>()) {
+ std::vector<std::pair<String, ObjectRef>> sorted;
+ for (const auto& kv : attrs->dict) {
+ sorted.push_back(kv);
+ }
+ std::sort(sorted.begin(), sorted.end());
+ for (const auto& kv : sorted) {
+ kwargs_keys.push_back(kv.first);
+ kwargs_values.push_back(
+ d->AsDoc<ExprDoc>(kv.second,
n_p->Attr("attrs")->Attr(kv.first)));
+ }
+ } else {
+ AttrPrinter printer(n_p->Attr("attrs"), d, &kwargs_keys,
&kwargs_values);
+ const_cast<BaseAttrsNode*>(n->attrs.get())->VisitAttrs(&printer);
+ }
+ }
+ // Step 4. Print type_args
+ if (n->sinfo_args.size() > 0) {
+ ObjectPath sinfo_args_p = n_p->Attr("sinfo_args");
+ Array<ExprDoc> sinfo_args;
+ for (int i = 0, l = n->sinfo_args.size(); i < l; ++i) {
+ sinfo_args.push_back(
+ d->AsDoc<ExprDoc>(n->sinfo_args[i],
sinfo_args_p->ArrayIndex(i)));
+ }
+ kwargs_keys.push_back("sinfo_args");
+ kwargs_values.push_back(TupleDoc(sinfo_args));
+ }
+ return prefix->Call(args, kwargs_keys, kwargs_values);
+ });
+
+TVM_SCRIPT_REPR(relax::CallNode, ReprPrintRelax);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc
new file mode 100644
index 0000000000..a786932fc3
--- /dev/null
+++ b/src/script/printer/relax/expr.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::PrimValue>( //
+ "", [](relax::PrimValue n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ // TODO(@junrushao): float numbers
+ return Relax(d, "prim_value")->Call({d->AsDoc<ExprDoc>(n->value,
n_p->Attr("value"))});
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::StringImm>( //
+ "", [](relax::StringImm n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ return Relax(d, "str")->Call({LiteralDoc::Str(n->value,
n_p->Attr("value"))});
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::DataTypeImm>( //
+ "", [](relax::DataTypeImm n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ return Relax(d, "dtype")->Call({LiteralDoc::DataType(n->value,
n_p->Attr("value"))});
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::Tuple>( //
+ "", [](relax::Tuple n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ // TODO(@junrushao): revisit tuple printing
+ if (n->fields.empty()) {
+ return Relax(d, "tuple")->Call({});
+ }
+ Array<ExprDoc> fields_doc;
+ ObjectPath fields_p = n_p->Attr("fields");
+ for (int i = 0, l = n->fields.size(); i < l; ++i) {
+ fields_doc.push_back(d->AsDoc<ExprDoc>(n->fields[i],
fields_p->ArrayIndex(i)));
+ }
+ return TupleDoc(fields_doc);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::TupleGetItem>( //
+ "", [](relax::TupleGetItem n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ ExprDoc idx = LiteralDoc::Int(n->index, n_p->Attr("index"));
+ return d->AsDoc<ExprDoc>(n->tuple, n_p->Attr("tuple"))[{idx}];
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::ShapeExpr>( //
+ "", [](relax::ShapeExpr n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ Array<ExprDoc> values_doc;
+ ObjectPath values_p = n_p->Attr("values");
+ for (int i = 0, l = n->values.size(); i < l; ++i) {
+ values_doc.push_back(PrintShapeVar(n->values[i],
values_p->ArrayIndex(i), d));
+ }
+ return TupleDoc(values_doc);
+ });
+
+Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath&
p) {
+ DataType dtype = n.DataType();
+ const void* data = n->data;
+ if (n->ndim != 0 || n->device.device_type != kDLCPU) {
+ return NullOpt;
+ }
+ if (dtype == DataType::Int(32)) {
+ return LiteralDoc::Int(*reinterpret_cast<const int32_t*>(data), p);
+ } else if (dtype == DataType::Int(64)) {
+ return LiteralDoc::Int(*reinterpret_cast<const int64_t*>(data), p);
+ } else if (dtype == DataType::Float(32)) {
+ return LiteralDoc::Float(*reinterpret_cast<const float*>(data), p);
+ } else if (dtype == DataType::Float(64)) {
+ return LiteralDoc::Float(*reinterpret_cast<const double*>(data), p);
+ } else if (dtype == DataType::Bool()) {
+ return LiteralDoc::Boolean(*reinterpret_cast<const uint8_t*>(data), p);
+ } else {
+ return NullOpt;
+ }
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::Constant>( //
+ "", [](relax::Constant n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ if (Optional<ExprDoc> s = SpecialScalar(n->data, n_p->Attr("data")))
{
+ return Relax(d, "const")
+ ->Call({
+ s.value(),
+ LiteralDoc::DataType(n->data.DataType(),
n_p->Attr("data")->Attr("dtype")),
+ });
+ }
+ return d->AddMetadata(n);
+ });
+
+Doc PrintRelaxVar(relax::Var n, ObjectPath p, IRDocsifier d) {
+ if (!d->IsVarDefined(n)) {
+ ExprDoc ann = d->AsDoc<ExprDoc>(n->struct_info_, p->Attr("struct_info_"));
+ Frame f = d->frames.back();
+ ExprDoc var = DefineVar(n, f, d);
+ f->stmts.push_back(AssignDoc(var, NullOpt, ann));
+ }
+ return d->GetVarDoc(n).value();
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<relax::Var>("",
PrintRelaxVar);
+TVM_STATIC_IR_FUNCTOR(IRDocsifier,
vtable).set_dispatch<relax::DataflowVar>("", PrintRelaxVar);
+
+TVM_SCRIPT_REPR(relax::PrimValueNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::StringImmNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::DataTypeImmNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::TupleNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::TupleGetItemNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::ShapeExprNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::VarNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::DataflowVarNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::ConstantNode, ReprPrintRelax);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/relax/function.cc
b/src/script/printer/relax/function.cc
new file mode 100644
index 0000000000..fa085fcad4
--- /dev/null
+++ b/src/script/printer/relax/function.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+TVM_REGISTER_NODE_TYPE(RelaxFrameNode);
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::Function>("", [](relax::Function n, ObjectPath n_p,
IRDocsifier d) -> Doc {
+ std::unordered_set<const tir::VarNode*> func_vars;
+ With<RelaxFrame> f(d);
+ (*f)->AddDispatchToken(d, "relax");
+ (*f)->is_func = true;
+ (*f)->func_vars = &func_vars;
+ // Step 1. Print the return type
+ Optional<ExprDoc> ret_type = NullOpt;
+ if (const auto& func_sinfo =
relax::MatchStructInfo<relax::FuncStructInfo>(n)) {
+ ret_type = d->AsDoc<ExprDoc>(func_sinfo.value()->ret, //
+ n_p->Attr("struct_info_")->Attr("ret"));
+ }
+ // Step 2. Print params
+ Array<AssignDoc> params;
+ {
+ ObjectPath params_p = n_p->Attr("params");
+ for (int i = 0, l = n->params.size(); i < l; ++i) {
+ params.push_back(AssignDoc(
+ /*lhs=*/DefineVar(n->params[i], *f, d),
+ /*rhs=*/NullOpt, StructInfoAsAnn(n->params[i],
params_p->ArrayIndex(i), d, NullOpt)));
+ }
+ }
+ // Step 3. Clean up func variables
+ (*f)->func_vars = nullptr;
+ // Step 4. Print attributes
+ if (n->attrs.defined() && !n->attrs->dict.empty()) {
+ (*f)->stmts.push_back(
+ ExprStmtDoc(Relax(d, "func_attr") //
+ ->Call({d->AsDoc<ExprDoc>(n->attrs,
n_p->Attr("attrs"))})));
+ }
+ // Step 5. Print body
+ Array<StmtDoc> body =
+ PrintSeqExpr(Downcast<relax::SeqExpr>(n->body), n_p->Attr("body"),
d, /*use_ret=*/true);
+ (*f)->stmts.insert((*f)->stmts.end(), body.begin(), body.end());
+ return HeaderWrapper(d, FunctionDoc(IdDoc(FindFunctionName(d,
n).value_or("main")), params,
+ {Relax(d, "function")}, ret_type,
(*f)->stmts));
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<relax::ExternFunc>( //
+ "", [](relax::ExternFunc n, ObjectPath n_p, IRDocsifier d) -> Doc {
+ // TODO(@junrushao): print more information out of extern function.
+ return ExprStmtDoc(LiteralDoc::Str(n->global_symbol, n_p));
+ });
+
+TVM_SCRIPT_REPR(relax::FunctionNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::ExternFuncNode, ReprPrintRelax);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/relax/region.cc
b/src/script/printer/relax/region.cc
new file mode 100644
index 0000000000..1ac0b5ba14
--- /dev/null
+++ b/src/script/printer/relax/region.cc
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Print a SeqExpr as a flat list of statements.
+ *
+ * Every binding block is printed in order: a StmtBlockDoc result is spliced in and a single
+ * StmtDoc result is appended. The body expression is printed last.
+ * \param n The SeqExpr to print.
+ * \param n_p The object path of `n`.
+ * \param d The IRDocsifier.
+ * \param use_ret If true, the body is emitted as `return <body>`; otherwise as a bare expression
+ *        statement.
+ * \return The statements accumulated in the RelaxFrame opened for this call.
+ */
+Array<StmtDoc> PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d,
+                            bool use_ret) {
+  // Statements are accumulated in the `stmts` of a RelaxFrame opened for this call.
+  With<RelaxFrame> f(d);
+  const Array<relax::BindingBlock>& blocks = n->blocks;
+  ObjectPath blocks_p = n_p->Attr("blocks");
+  Array<StmtDoc>* stmts = &(*f)->stmts;
+  for (int i = 0, l = blocks.size(); i < l; ++i) {
+    Doc block = d->AsDoc(blocks[i], blocks_p->ArrayIndex(i));
+    // A block may print either to a group of statements (spliced in) or to one statement.
+    if (const auto* stmt_block = block.as<StmtBlockDocNode>()) {
+      stmts->insert(stmts->end(), stmt_block->stmts.begin(), stmt_block->stmts.end());
+    } else if (const auto* stmt = block.as<StmtDocNode>()) {
+      stmts->push_back(GetRef<StmtDoc>(stmt));
+    } else {
+      LOG(FATAL) << "TypeError: Unknown type: " << block->GetTypeKey();
+    }
+  }
+  // The result expression is printed after all the bindings.
+  ExprDoc ret = d->AsDoc<ExprDoc>(n->body, n_p->Attr("body"));
+  if (use_ret) {
+    stmts->push_back(ReturnDoc(ret));
+  } else {
+    stmts->push_back(ExprStmtDoc(ret));
+  }
+  return *stmts;
+}
+
+// A standalone SeqExpr prints as a block of statements ending in the bare body expression
+// (no `return`).
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::SeqExpr>("", [](relax::SeqExpr n, ObjectPath n_p, IRDocsifier d) -> Doc {
+      return StmtBlockDoc(PrintSeqExpr(n, n_p, d, false));
+    });
+
+/*!
+ * \brief Print every binding of a binding block, in order.
+ * \param n The binding block.
+ * \param n_p The object path of `n`.
+ * \param d The IRDocsifier.
+ * \param non_dataflow_vars If not null, the doc of every bound variable that is not a
+ *        DataflowVar is appended to it (the dataflow-block printer turns these into
+ *        `R.output(...)`).
+ * \return The printed statements.
+ */
+Array<StmtDoc> PrintBindingBlock(const relax::BindingBlock& n, const ObjectPath& n_p,
+                                 const IRDocsifier& d, Array<ExprDoc>* non_dataflow_vars) {
+  const Array<relax::Binding>& bindings = n->bindings;
+  ObjectPath bindings_p = n_p->Attr("bindings");
+  Array<StmtDoc> stmts;
+  for (int i = 0, l = bindings.size(); i < l; ++i) {
+    const relax::Binding& binding = bindings[i];
+    ObjectPath binding_p = bindings_p->ArrayIndex(i);
+    ICHECK(binding->var.defined());
+    Doc binding_doc = d->AsDoc(binding, binding_p);
+    if (const auto* stmt = binding_doc.as<StmtDocNode>()) {
+      stmts.push_back(GetRef<StmtDoc>(stmt));
+    } else if (const auto* stmt_block = binding_doc.as<StmtBlockDocNode>()) {
+      stmts.insert(stmts.end(), stmt_block->stmts.begin(), stmt_block->stmts.end());
+    } else {
+      LOG(FATAL) << "TypeError: Unknown type: " << binding_doc->GetTypeKey();
+    }
+    // DataflowVars are skipped; every other bound var is collected when requested.
+    if (non_dataflow_vars != nullptr && !binding->var->IsInstance<relax::DataflowVarNode>()) {
+      non_dataflow_vars->push_back(d->AsDoc<ExprDoc>(binding->var, binding_p->Attr("var")));
+    }
+  }
+  return stmts;
+}
+
+// A plain binding block prints as its bindings, one after another.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::BindingBlock>(  //
+        "", [](relax::BindingBlock n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return StmtBlockDoc(PrintBindingBlock(n, n_p, d, nullptr));
+        });
+
+// A dataflow block prints as
+//   with R.dataflow():
+//       <bindings>
+//       R.output(<non-dataflow vars>)
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::DataflowBlock>(  //
+        "", [](relax::DataflowBlock n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          Array<ExprDoc> non_dataflow_vars;
+          Array<StmtDoc> stmts = PrintBindingBlock(n, n_p, d, &non_dataflow_vars);
+          stmts.push_back(ExprStmtDoc(Relax(d, "output")->Call(non_dataflow_vars)));
+          return ScopeDoc(NullOpt, Relax(d, "dataflow")->Call({}), stmts);
+        });
+
+TVM_SCRIPT_REPR(relax::SeqExprNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::BindingBlockNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::DataflowBlockNode, ReprPrintRelax);
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
new file mode 100644
index 0000000000..6f4a66c991
--- /dev/null
+++ b/src/script/printer/relax/struct_info.cc
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/stmt_functor.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// ObjectStructInfo prints as the bare `R.Object`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::ObjectStructInfo>(  //
+        "", [](relax::ObjectStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return Relax(d, "Object");
+        });
+
+// PrimStructInfo prints as `R.Prim("<dtype>")`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::PrimStructInfo>(
+        "", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))});
+        });
+
+/*!
+ * \brief Print a PrimExpr that appears as a dimension of a shape.
+ *
+ * If some RelaxFrame on the stack is collecting `func_vars` (the function printer does this while
+ * printing parameters), and `e` references at least one of those vars, then `e` is emitted as a
+ * string literal of its script text. Otherwise the plain expression doc is returned.
+ * \param e The expression.
+ * \param e_p The object path of `e`.
+ * \param d The IRDocsifier.
+ * \return The doc of `e`, possibly stringified.
+ */
+ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) {
+  // Print first: under the "relax" token an undefined TIR var is defined by PrintTIRVar, which
+  // also records it in the collecting frame's `func_vars` that Step 2 inspects.
+  ExprDoc expr_doc = d->AsDoc<ExprDoc>(e, e_p);
+  // Step 1. Find if `func_vars` are being collected
+  const RelaxFrameNode* f = nullptr;
+  for (const Frame& frame : d->frames) {
+    if (const auto* relax_frame = frame.as<RelaxFrameNode>()) {
+      if (relax_frame->func_vars) {
+        f = relax_frame;
+        break;
+      }
+    }
+  }
+  // Step 2. Figure out if the PrimExpr contains at least a func var
+  bool func_var_mode = false;
+  if (f != nullptr) {
+    tir::PostOrderVisit(e, [f, &func_var_mode](const ObjectRef& obj) -> void {
+      if (const auto* var = obj.as<tir::VarNode>()) {
+        if (f->func_vars->count(var)) {
+          func_var_mode = true;
+        }
+      }
+    });
+  }
+  // Step 3. Stringify the PrimExpr if func var exists
+  if (func_var_mode) {
+    return LiteralDoc::Str(DocToPythonScript(expr_doc, d->cfg), e_p);
+  }
+  return expr_doc;
+}
+
+// ShapeStructInfo prints as `R.Shape([<dims>])` when the values are known, and as
+// `R.Shape(ndim=<ndim>)` otherwise.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::ShapeStructInfo>(
+        "", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          if (n->values.defined()) {
+            Array<PrimExpr> shape = n->values.value();
+            ObjectPath shape_p = n_p->Attr("values");
+            Array<ExprDoc> shape_docs;
+            for (int i = 0, ndim = shape.size(); i < ndim; ++i) {
+              shape_docs.push_back(PrintShapeVar(shape[i], shape_p->ArrayIndex(i), d));
+            }
+            return Relax(d, "Shape")->Call({ListDoc(shape_docs)});
+          }
+          return Relax(d, "Shape")
+              ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))});
+        });
+
+// TensorStructInfo prints as `R.Tensor(<shape>, dtype=..., ndim=...)`, omitting what is unknown.
+// `ndim` is printed only when there is no shape. With nothing known it prints as the bare
+// `R.Tensor`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::TensorStructInfo>(  //
+        "", [](relax::TensorStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          Array<ExprDoc> args;
+          Array<String> kwargs_keys;
+          Array<ExprDoc> kwargs_values;
+          if (n->shape.defined()) {
+            args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
+          }
+          if (!n->IsUnknownDtype()) {
+            kwargs_keys.push_back("dtype");
+            kwargs_values.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
+          }
+          if (!n->shape.defined() && !n->IsUnknownNdim()) {
+            kwargs_keys.push_back("ndim");
+            kwargs_values.push_back(LiteralDoc::Int(n->ndim, n_p->Attr("ndim")));
+          }
+          if (args.empty() && kwargs_keys.empty()) {
+            return Relax(d, "Tensor");
+          }
+          return Relax(d, "Tensor")->Call(args, kwargs_keys, kwargs_values);
+        });
+
+// TupleStructInfo prints as `R.Tuple(<field>, ...)`, or as the bare `R.Tuple` when it is empty.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::TupleStructInfo>(  //
+        "", [](relax::TupleStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          if (n->fields.empty()) {
+            return Relax(d, "Tuple");
+          }
+          Array<ExprDoc> fields_doc;
+          ObjectPath fields_p = n_p->Attr("fields");
+          for (int i = 0, l = n->fields.size(); i < l; ++i) {
+            fields_doc.push_back(d->AsDoc<ExprDoc>(n->fields[i], fields_p->ArrayIndex(i)));
+          }
+          return Relax(d, "Tuple")->Call(fields_doc);
+        });
+
+// FuncStructInfo prints as `R.Callable((<param>, ...), <ret>)`, or as the bare `R.Callable` when
+// it is opaque.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::FuncStructInfo>(  //
+        "", [](relax::FuncStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          if (n->IsOpaque()) {
+            return Relax(d, "Callable");
+          }
+          // TODO(@junrushao): track symbolic shape relation
+          Array<ExprDoc> params_doc;
+          // NOTE(review): assumes a non-opaque FuncStructInfo always has `params` defined.
+          Array<relax::StructInfo> params = n->params.value();
+          ObjectPath params_p = n_p->Attr("params");
+          for (int i = 0, n_params = params.size(); i < n_params; ++i) {
+            params_doc.push_back(d->AsDoc<ExprDoc>(params[i], params_p->ArrayIndex(i)));
+          }
+          return Relax(d, "Callable")
+              ->Call({TupleDoc(params_doc),  //
+                      d->AsDoc<ExprDoc>(n->ret, n_p->Attr("ret"))});
+        });
+
+TVM_SCRIPT_REPR(relax::ObjectStructInfoNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::PrimStructInfoNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::ShapeStructInfoNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::TensorStructInfoNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::TupleStructInfoNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::FuncStructInfoNode, ReprPrintRelax);
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc
new file mode 100644
index 0000000000..2c8bb0f1da
--- /dev/null
+++ b/src/script/printer/relax/tir.cc
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/ir/expr.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Print a TIR variable referenced from Relax, e.g. a symbolic shape dimension.
+ *
+ * On first use, the var is defined on a Relax frame by emitting `name = T.Var("name", dtype)`
+ * into that frame's statements. The frame is the outermost function frame if one exists, and
+ * otherwise the outermost Relax frame. If that frame is collecting `func_vars`, the var is also
+ * recorded there. Later uses print the already-defined name.
+ * \param n The TIR var; it must be a scalar integer.
+ * \param n_p The object path of `n`.
+ * \param d The IRDocsifier.
+ * \return The doc naming the var.
+ */
+Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
+  ICHECK(n->dtype.is_int() && n->dtype.is_scalar()) << "TypeError: Relax only uses "
+                                                       "scalar integer TIR variables, but gets: "
+                                                    << n;
+  if (!d->IsVarDefined(n)) {
+    // Find the outmost Relax function frame. If not exist, the outmost Relax frame.
+    // `const_cast` is needed because the chosen frame's `stmts`/`func_vars` are mutated below.
+    RelaxFrameNode* f = nullptr;
+    for (const Frame& frame : d->frames) {
+      if (const auto* relax_frame = frame.as<RelaxFrameNode>()) {
+        if (relax_frame->is_func) {
+          f = const_cast<RelaxFrameNode*>(relax_frame);
+          break;
+        } else if (f == nullptr) {
+          f = const_cast<RelaxFrameNode*>(relax_frame);
+        }
+      }
+    }
+    // There should be at least one Relax frame
+    if (f == nullptr) {
+      LOG(FATAL) << "IndexError: No relax environment is found when printing a TIR var under "
+                    "relax's dispatch token";
+    }
+    // If the Relax function frame is collecting func vars
+    if (f->func_vars) {
+      ICHECK(f->is_func);
+      f->func_vars->insert(n.get());
+    }
+    IdDoc var = d->Define(n, GetRef<Frame>(f), n->name_hint.empty() ? "v" : n->name_hint);
+    var->source_paths.push_back(n_p);
+    // Emit the definition statement `<name> = T.Var("<name>", "<dtype>")` into the frame.
+    f->stmts.push_back(AssignDoc(var,
+                                 TIR(d, "Var")->Call({
+                                     LiteralDoc::Str(var->name, n_p->Attr("name_hint")),
+                                     LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")),
+                                 }),
+                                 NullOpt));
+  }
+  if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
+    return doc.value();
+  }
+  // Reached only if the lookup fails even after the definition step above.
+  LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n;
+}
+
+// Under the "relax" dispatch token, both tir::Var and tir::SizeVar print via PrintTIRVar.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<tir::Var>("relax", PrintTIRVar);
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<tir::SizeVar>("relax", PrintTIRVar);
+
+// Under the "relax" token, an IntImm prints as a plain integer literal; its dtype is not printed.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tvm::IntImm>(                                            //
+        "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc {  //
+          // TODO(@junrushao): support non-int64 cases
+          return LiteralDoc::Int(n->value, n_p);
+        });
+
+// Under the "relax" token, a GlobalVar prints as a bare identifier of its name hint.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tvm::GlobalVar>(                                            //
+        "relax", [](tvm::GlobalVar n, ObjectPath n_p, IRDocsifier d) -> Doc {  //
+          IdDoc ret(n->name_hint);
+          ret->source_paths.push_back(n_p);
+          return ret;
+        });
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/src/script/printer/relax/type.cc b/src/script/printer/relax/type.cc
new file mode 100644
index 0000000000..d13d90b1d5
--- /dev/null
+++ b/src/script/printer/relax/type.cc
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// ShapeType prints as `R.Shape(ndim=<ndim>)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::ShapeType>(  //
+        "", [](relax::ShapeType n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return Relax(d, "Shape")
+              ->Call({}, {"ndim"}, {LiteralDoc::Int(n->ndim, n_p->Attr("ndim"))});
+        });
+
+// ObjectType prints as the bare `R.Object`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::ObjectType>(  //
+        "", [](relax::ObjectType n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return Relax(d, "Object");
+        });
+
+// DynTensorType always prints both keywords: `R.Tensor(ndim=<ndim>, dtype="<dtype>")`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::DynTensorType>(
+        "", [](relax::DynTensorType n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return Relax(d, "Tensor")
+              ->Call({}, {"ndim", "dtype"},
+                     {LiteralDoc::Int(n->ndim, n_p->Attr("ndim")),
+                      LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))});
+        });
+
+// PackedFuncType prints as the bare `R.PackedFunc`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<relax::PackedFuncType>(
+        "", [](relax::PackedFuncType n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          return Relax(d, "PackedFunc");  // TODO(@junrushao): verify if this is correct
+        });
+
+// Under the "relax" token, the generic tvm::TupleType prints as `R.Tuple(<field>, ...)`, or as
+// the bare `R.Tuple` when it is empty.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tvm::TupleType>(  //
+        "relax", [](tvm::TupleType n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          if (n->fields.empty()) {
+            return Relax(d, "Tuple");
+          }
+          Array<ExprDoc> fields_doc;
+          ObjectPath fields_p = n_p->Attr("fields");
+          for (int i = 0, l = n->fields.size(); i < l; ++i) {
+            fields_doc.push_back(d->AsDoc<ExprDoc>(n->fields[i], fields_p->ArrayIndex(i)));
+          }
+          return Relax(d, "Tuple")->Call(fields_doc);
+        });
+
+// Under the "relax" token, the generic tvm::FuncType prints as
+// `R.Callable((<arg type>, ...), <ret type>)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tvm::FuncType>(
+        "relax", [](tvm::FuncType n, ObjectPath n_p, IRDocsifier d) -> Doc {
+          Array<ExprDoc> arg_types_doc;
+          Array<Type> arg_types = n->arg_types;
+          ObjectPath arg_types_p = n_p->Attr("arg_types");
+          for (int i = 0, n_params = arg_types.size(); i < n_params; ++i) {
+            arg_types_doc.push_back(d->AsDoc<ExprDoc>(arg_types[i], arg_types_p->ArrayIndex(i)));
+          }
+          return Relax(d, "Callable")
+              ->Call({TupleDoc(arg_types_doc),  //
+                      d->AsDoc<ExprDoc>(n->ret_type, n_p->Attr("ret_type"))});
+        });
+
+// Only the Relax-specific types get a Relax repr here; the shared tvm::TupleType/tvm::FuncType
+// are not registered with TVM_SCRIPT_REPR in this file.
+TVM_SCRIPT_REPR(relax::ShapeTypeNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::ObjectTypeNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::DynTensorTypeNode, ReprPrintRelax);
+TVM_SCRIPT_REPR(relax::PackedFuncTypeNode, ReprPrintRelax);
+// Register ReprPrintRelax as the global function "script.printer.ReprPrintRelax".
+TVM_REGISTER_GLOBAL("script.printer.ReprPrintRelax").set_body_typed(ReprPrintRelax);
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
new file mode 100644
index 0000000000..7702f7b22d
--- /dev/null
+++ b/src/script/printer/relax/utils.h
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
+#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
+
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/type.h>
+#include <tvm/script/printer/ir_docsifier.h>
+
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "../utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*! \brief Frame used while printing Relax IR. */
+class RelaxFrameNode : public FrameNode {
+ public:
+  /*! \brief Whether this frame was opened for a Relax function. */
+  bool is_func = false;
+  /*!
+   * \brief When non-null, the set into which TIR vars defined while this frame is active are
+   *        recorded. It points at a set owned by the code that installs it; the frame does not
+   *        own it.
+   */
+  std::unordered_set<const tir::VarNode*>* func_vars = nullptr;
+
+  void VisitAttrs(AttrVisitor* v) {
+    FrameNode::VisitAttrs(v);
+    // NOTE: `is_func` is reflected under the key "is_global_func".
+    v->Visit("is_global_func", &is_func);
+    // `func_vars` is a raw pointer and is not visited
+  }
+
+  static constexpr const char* _type_key = "script.printer.RelaxFrame";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RelaxFrameNode, FrameNode);
+};
+
+/*! \brief Managed reference to RelaxFrameNode. */
+class RelaxFrame : public Frame {
+ public:
+  /*! \brief Create an empty, non-function frame bound to `d`. */
+  explicit RelaxFrame(const IRDocsifier& d) {
+    ObjectPtr<RelaxFrameNode> n = make_object<RelaxFrameNode>();
+    n->stmts.clear();
+    n->d = d.get();
+    n->is_func = false;
+    n->func_vars = nullptr;
+    data_ = std::move(n);
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, Frame, RelaxFrameNode);
+};
+
+/*! \brief Redirected method for the ReprPrinter */
+inline std::string ReprPrintRelax(const ObjectRef& obj, const PrinterConfig& cfg) {
+  IRDocsifier d(cfg);
+  // Open a top-level Relax frame and enable the "relax" dispatch token, so that the
+  // relax-specific printers (e.g. for tir::Var) are selected.
+  With<RelaxFrame> f(d);
+  (*f)->AddDispatchToken(d, "relax");
+  return Docsify(obj, d, *f, cfg);
+}
+
+/*! \brief Define `var` in `frame`, named by its name hint, or "v" if the hint is empty. */
+inline IdDoc DefineVar(const relax::Var& var, const Frame& frame, const IRDocsifier& d) {
+  return d->Define(var, frame, var->name_hint().empty() ? "v" : var->name_hint());
+}
+
+/*!
+ * \brief Compute the type annotation to print for `v`, i.e. its struct info.
+ * \param v The variable.
+ * \param v_p The object path of `v`.
+ * \param d The IRDocsifier.
+ * \param rhs The value bound to `v`, if any.
+ * \return NullOpt if `v` has no struct info, or if `rhs` is a call to `relax.call_tir`;
+ *         otherwise the printed struct info.
+ */
+inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath& v_p,
+                                         const IRDocsifier& d, const Optional<relax::Expr>& rhs) {
+  if (!v->struct_info_.defined()) {
+    return NullOpt;
+  }
+  // Presumably skipped because `R.call_tir(...)` already spells out the result struct info via
+  // `out_sinfo=` -- TODO confirm.
+  if (const auto* call = rhs.as<relax::CallNode>()) {
+    static const Op& call_tir_op = Op::Get("relax.call_tir");
+    if (call->op.same_as(call_tir_op)) {
+      return NullOpt;
+    }
+  }
+  return d->AsDoc<ExprDoc>(v->struct_info_, v_p->Attr("struct_info_"));
+}
+
+/*! \brief Print a SeqExpr as statements; its body becomes a `return` iff `use_ret` is true. */
+Array<StmtDoc> PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, const IRDocsifier& d,
+                            bool use_ret);
+
+/*! \brief Print one shape dimension, stringifying it if it refers to collected `func_vars`. */
+ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d);
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
new file mode 100644
index 0000000000..75fc4d1429
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -0,0 +1,489 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+from tvm import IRModule, relax, tir
+from tvm.script import relax as R
+
+
+def _assert_print(obj, expected):
+ if not isinstance(obj, str):
+ obj = obj.script(verbose_expr=True)
+ obj = obj.strip()
+ assert obj == expected.strip(), "\n" + obj
+
+
+def test_function():
+ @R.function
+ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore
+ return a
+
+ _assert_print(
+ func,
+ """
+# from tvm.script import relax as R
+
[email protected]
+def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+ return a""",
+ )
+
+
+def test_extern_func():
+ @R.function
+ def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type:
ignore
+ return a
+
+ obj = IRModule(
+ {
+ "func": relax_func,
+ "my_ext": relax.ExternFunc("my_ext"),
+ }
+ )
+ _assert_print(
+ obj,
+ """
+# from tvm.script import ir as I
+# from tvm.script import relax as R
+
[email protected]_module
+class Module:
+ @R.function
+ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
+ return a
+
+ "my_ext"
+""",
+ )
+
+
+def test_object_struct_info():
+ obj = relax.ObjectStructInfo()
+ _assert_print(
+ obj,
+ "R.Object",
+ )
+
+
+def test_prim_struct_info():
+ obj = relax.PrimStructInfo("float32")
+ _assert_print(obj, 'R.Prim("float32")')
+
+
+def test_shape_struct_info_0():
+ obj = relax.ShapeStructInfo(ndim=-1)
+ _assert_print(obj, "R.Shape(ndim=-1)")
+
+
+def test_shape_struct_info_1():
+ obj = relax.ShapeStructInfo([1, 2, 3])
+ _assert_print(obj, "R.Shape([1, 2, 3])")
+
+
+def test_shape_struct_info_2():
+    """A symbolic dimension is defined by a ``T.Var`` statement before it is used."""
+    obj = relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3])
+    _assert_print(
+        obj,
+        """
+a = T.Var("a", "int64")
+R.Shape([1, a, 3])""",
+    )
+
+
+def test_tensor_struct_info():
+    """TensorStructInfo prints its shape positionally and its dtype as a keyword."""
+    obj = relax.TensorStructInfo(
+        shape=relax.ShapeExpr([1, tir.Var("a", "int64"), 3]),
+        dtype="float32",
+    )
+    _assert_print(
+        obj,
+        """
+a = T.Var("a", "int64")
+R.Tensor((1, a, 3), dtype="float32")
+""",
+    )
+
+
+def test_tuple_struct_info_empty():
+    """An empty TupleStructInfo prints as the bare ``R.Tuple``."""
+    obj = relax.TupleStructInfo([])
+    _assert_print(obj, "R.Tuple")
+
+
+def test_tuple_struct_info():
+    """TupleStructInfo prints every field; symbolic dims are defined up front."""
+    obj = relax.TupleStructInfo(
+        [
+            relax.PrimStructInfo("float32"),
+            relax.ObjectStructInfo(),
+            relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]),
+        ]
+    )
+    _assert_print(
+        obj,
+        """
+a = T.Var("a", "int64")
+R.Tuple(R.Prim("float32"), R.Object, R.Shape([1, a, 3]))
+""",
+    )
+
+
+def test_func_struct_info():
+    """FuncStructInfo prints as ``R.Callable((<params>), <ret>)``."""
+    obj = relax.FuncStructInfo(
+        params=[
+            relax.PrimStructInfo("float32"),
+            relax.ObjectStructInfo(),
+            relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]),
+        ],
+        ret=relax.TensorStructInfo(
+            shape=relax.ShapeExpr([1, 2, 3]),
+            dtype="float32",
+        ),
+    )
+    _assert_print(
+        obj,
+        """
+a = T.Var("a", "int64")
+R.Callable((R.Prim("float32"), R.Object, R.Shape([1, a, 3])), R.Tensor((1, 2, 3), dtype="float32"))
+""",
+    )
+
+
+def test_shape_type():
+ obj = relax.ShapeType(ndim=3)
+ _assert_print(obj, "R.Shape(ndim=3)")
+
+
+def test_object_type():
+ obj = relax.ObjectType()
+ _assert_print(obj, "R.Object")
+
+
+def test_dyn_tensor_type():
+ obj = relax.DynTensorType()
+ _assert_print(obj, 'R.Tensor(ndim=-1, dtype="float32")')
+
+
+def test_packed_func_type():
+ obj = relax.PackedFuncType()
+ _assert_print(obj, "R.PackedFunc")
+
+
+def test_tuple_type():
+    """TupleType prints as ``R.Tuple(...)`` of its field types.
+
+    Printed via ``_relax_script()``; presumably this is needed because
+    TupleType is a shared ``tvm.ir`` type whose default repr is not the
+    Relax printer (TODO confirm).
+    """
+    obj = relax.TupleType([relax.ShapeType(ndim=3), relax.ObjectType()])
+    _assert_print(
+        obj._relax_script(),  # pylint: disable=protected-access
+        "R.Tuple(R.Shape(ndim=3), R.Object)",
+    )
+
+
+def test_func_type():
+    """FuncType prints as ``R.Callable((<arg types>), <ret type>)``."""
+    obj = relax.FuncType(
+        arg_types=[
+            relax.ObjectType(),
+            relax.ShapeType(ndim=3),
+        ],
+        ret_type=relax.DynTensorType(
+            ndim=3,
+            dtype="float32",
+        ),
+    )
+    _assert_print(
+        obj._relax_script(),  # pylint: disable=protected-access
+        'R.Callable((R.Object, R.Shape(ndim=3)), R.Tensor(ndim=3, dtype="float32"))',
+    )
+
+
+def test_prim_value():
+ obj = relax.PrimValue(1)
+ _assert_print(obj, "R.prim_value(1)")
+
+
+def test_string_imm():
+ obj = relax.StringImm("hello")
+ _assert_print(obj, 'R.str("hello")')
+
+
+def test_data_type_imm():
+ obj = relax.DataTypeImm("float32")
+ _assert_print(obj, 'R.dtype("float32")')
+
+
+def test_var():
+    """A free Var prints its annotated declaration before its name."""
+    obj = relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32"))
+    _assert_print(
+        obj,
+        """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+a""",
+    )
+
+
+def test_dataflow_var():
+    """A free DataflowVar prints exactly like a Var."""
+    obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32"))
+    _assert_print(
+        obj,
+        """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+a""",
+    )
+
+
+def test_tuple():
+    """A Tuple declares each free var (and its symbolic dims) before printing ``(a, b, c)``."""
+    obj = relax.Tuple(
+        [
+            relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")),
+            relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")),
+            relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")),
+        ]
+    )
+    _assert_print(
+        obj,
+        """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+y = T.Var("y", "int64")
+b: R.Tensor((1, y, 3), dtype="float32")
+z = T.Var("z", "int64")
+c: R.Tensor((1, z, 3), dtype="float32")
+(a, b, c)
+""",
+    )
+
+
+def test_tuple_get_item():
+    """TupleGetItem prints as Python subscription of the printed tuple."""
+    obj = relax.TupleGetItem(
+        relax.Tuple(
+            [
+                relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")),
+                relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")),
+                relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")),
+            ]
+        ),
+        0,
+    )
+    _assert_print(
+        obj,
+        """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+y = T.Var("y", "int64")
+b: R.Tensor((1, y, 3), dtype="float32")
+z = T.Var("z", "int64")
+c: R.Tensor((1, z, 3), dtype="float32")
+(a, b, c)[0]
+""",
+    )
+
+
+def test_shape_expr():
+    """A ShapeExpr prints as a plain Python tuple."""
+    obj = relax.ShapeExpr([1, 2, 3])
+    _assert_print(obj, "(1, 2, 3)")
+
+
+def test_call():
+    """``call_tir`` prints ``out_sinfo`` and ``tir_vars`` as keyword arguments."""
+    x = tir.Var("x", "int64")
+    a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
+    obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info, tir_vars=[x])
+    _assert_print(
+        obj,
+        """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,))
+""",
+    )
+
+
[email protected](reason="`relax.op.sin` is not upstreamed yet")
+def test_seq_expr():
+ x = tir.Var("x", "int64")
+ a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
+ b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32"))
+ c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32"))
+
+ obj = relax.SeqExpr(
+ blocks=[
+ relax.DataflowBlock(
+ bindings=[
+ relax.VarBinding(b, relax.op.sin(a)),
+ relax.VarBinding(c, relax.op.sin(b)),
+ ]
+ ),
+ ],
+ body=c,
+ )
+ _assert_print(
+ obj,
+ """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+with R.dataflow():
+ b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a)
+ c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b)
+ R.output(c)
+c
+""",
+ )
+
+
[email protected](reason="`relax.op.sin` is not upstreamed yet")
+def test_binding_block():
+ x = tir.Var("x", "int64")
+ a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
+ b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32"))
+ c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32"))
+ obj = relax.BindingBlock(
+ bindings=[
+ relax.VarBinding(b, relax.op.sin(a)),
+ relax.VarBinding(c, relax.op.sin(b)),
+ ]
+ )
+ _assert_print(
+ obj,
+ """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a)
+c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b)
+""",
+ )
+
+
[email protected](reason="`relax.op.sin` is not upstreamed yet")
+def test_dataflow_block():
+ x = tir.Var("x", "int64")
+ a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
+ b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32"))
+ c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32"))
+ obj = relax.DataflowBlock(
+ bindings=[
+ relax.VarBinding(b, relax.op.sin(a)),
+ relax.VarBinding(c, relax.op.sin(b)),
+ ]
+ )
+ _assert_print(
+ obj,
+ """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+with R.dataflow():
+ b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a)
+ c: R.Tensor((1, x, 3), dtype="float32") = R.sin(b)
+ R.output(c)
+""",
+ )
+
+
+def test_match_cast():
+ x = tir.Var("x", "int64")
+ a = relax.Var("a", relax.TensorStructInfo([1, x, 3]))
+ b = relax.Var("b", relax.TensorStructInfo([1, 5, 3]))
+ obj = relax.MatchCast(
+ var=b,
+ value=a,
+ struct_info=b.struct_info,
+ )
+ _assert_print(
+ obj,
+ """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+b: R.Tensor((1, 5, 3), dtype="float32") = R.match_cast(a, R.Tensor((1, 5, 3),
dtype="float32"))
+""",
+ )
+
+
[email protected](reason="`relax.op.sin` is not upstreamed yet")
+def test_var_binding():
+ x = tir.Var("x", "int64")
+ a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
+ b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32"))
+ obj = relax.VarBinding(b, relax.op.sin(a))
+ _assert_print(
+ obj,
+ """
+x = T.Var("x", "int64")
+a: R.Tensor((1, x, 3), dtype="float32")
+b: R.Tensor((1, x, 3), dtype="float32") = R.sin(a)
+""",
+ )
+
+
+def test_if():
+ a = relax.Var("a", relax.TensorStructInfo([], "bool"))
+ b = relax.Var("b", relax.TensorStructInfo([1, 2, 3], "float32"))
+ c = relax.Var("c", relax.TensorStructInfo([1, 2, 3], "float32"))
+ obj = relax.If(
+ a,
+ relax.SeqExpr([], b),
+ relax.SeqExpr([], c),
+ )
+ _assert_print(
+ obj,
+ """
+a: R.Tensor((), dtype="bool")
+if a:
+ b: R.Tensor((1, 2, 3), dtype="float32")
+ b
+else:
+ c: R.Tensor((1, 2, 3), dtype="float32")
+ c
+""",
+ )
+
+
+if __name__ == "__main__":
+ test_function()
+ test_extern_func()
+
+ test_object_struct_info()
+ test_prim_struct_info()
+ test_shape_struct_info_0()
+ test_shape_struct_info_1()
+ test_shape_struct_info_2()
+ test_tensor_struct_info()
+ test_tuple_struct_info_empty()
+ test_tuple_struct_info()
+ test_func_struct_info()
+
+ test_shape_type()
+ test_object_type()
+ test_dyn_tensor_type()
+ test_packed_func_type()
+ test_tuple_type()
+ test_func_type()
+
+ test_prim_value()
+ test_string_imm()
+ test_data_type_imm()
+
+ test_var()
+ test_dataflow_var()
+ #
+ test_tuple()
+ test_tuple_get_item()
+ test_shape_expr()
+ test_call()
+
+ test_seq_expr()
+ test_binding_block()
+ test_dataflow_block()
+
+ test_match_cast()
+ test_var_binding()
+ test_if()