This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf4394762a [Unity][TVMScript] Optionally hide StructInfo that can be
inferred (#16356)
bf4394762a is described below
commit bf4394762ab7e577a4fb675f9796f58cf783dadf
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Feb 12 12:06:09 2024 -0600
[Unity][TVMScript] Optionally hide StructInfo that can be inferred (#16356)
* [Unity][TVMScript] Optionally hide StructInfo that can be inferred
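By default, TVMScript prints the struct info of every variable being
bound, which can become quite verbose. This commit adds the
configuration option `show_inferable_type_annotations` (renamed to
`show_all_struct_info` later in this change), which determines
whether struct info annotations are shown in cases where they can be
inferred.
The option defaults to `True`, preserving the current default behavior,
with one exception: bindings of `R.call_tir` and `R.call_dps_packed`
previously never printed an annotation, and now print it whenever it
cannot be inferred from the call.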
* Rename show_inferable_type_annotations to show_all_struct_info
* Add unit test for round-trip of an extern function called via `R.call_dps_packed`
---
include/tvm/node/script_printer.h | 35 ++++++++++++++++
python/tvm/runtime/script_printer.py | 24 +++++++++--
src/node/script_printer.cc | 3 ++
src/script/printer/relax/binding.cc | 5 ++-
src/script/printer/relax/utils.h | 39 +++++++++++++++++
tests/python/relax/test_tvmscript_printer_relax.py | 49 ++++++++++++++++++++++
tests/python/tvmscript/test_tvmscript_roundtrip.py | 42 ++++++++++++++++++-
7 files changed, 191 insertions(+), 6 deletions(-)
diff --git a/include/tvm/node/script_printer.h
b/include/tvm/node/script_printer.h
index 703844d99a..2b812eef32 100644
--- a/include/tvm/node/script_printer.h
+++ b/include/tvm/node/script_printer.h
@@ -72,6 +72,40 @@ class PrinterConfigNode : public Object {
bool syntax_sugar = true;
/*! \brief Whether variable names should include the object's address */
bool show_object_address = false;
+
+ /*! \brief In Relax, whether to show all StructInfo annotations
+ *
+ * If true (default), variable bindings will be annotated with the
+ * struct info of the variable being bound, except for bindings of
+ * R.call_tir and R.call_dps_packed whose struct info is inferable.
+ *
+ * If false, the annotations will only be shown when they are
+ * required for correct parsing of the Relax function. For example,
+ * function parameters must always have struct info annotations, but
+ * the struct info of a variable bound within a function body can often
+ * be inferred from the bound expression, and is then omitted.
+ *
+ * Example:
+ *
+ * # func.show(show_all_struct_info=True)
+ * @R.function
+ * def func(
+ * A: R.Tensor((10, 20), dtype="float32"),
+ * B: R.Tensor((10, 20), dtype="float32"),
+ * ) -> R.Tensor((10, 20), dtype="float32"):
+ * C: R.Tensor((10, 20), dtype="float32") = R.add(A, B)
+ * return C
+ *
+ * # func.show(show_all_struct_info=False)
+ * @R.function
+ * def func(
+ * A: R.Tensor((10, 20), dtype="float32"),
+ * B: R.Tensor((10, 20), dtype="float32"),
+ * ) -> R.Tensor((10, 20), dtype="float32"):
+ * C = R.add(A, B)
+ * return C
+ */
+ bool show_all_struct_info = true;
+
/* \brief Object path to be underlined */
Array<ObjectPath> path_to_underline = Array<ObjectPath>();
/*! \brief Object path to be annotated. */
@@ -97,6 +131,7 @@ class PrinterConfigNode : public Object {
v->Visit("num_context_lines", &num_context_lines);
v->Visit("syntax_sugar", &syntax_sugar);
v->Visit("show_object_address", &show_object_address);
+ v->Visit("show_all_struct_info", &show_all_struct_info);
v->Visit("path_to_underline", &path_to_underline);
v->Visit("path_to_annotate", &path_to_annotate);
v->Visit("obj_to_underline", &obj_to_underline);
diff --git a/python/tvm/runtime/script_printer.py
b/python/tvm/runtime/script_printer.py
index 260d0ead9d..ad3f612c4e 100644
--- a/python/tvm/runtime/script_printer.py
+++ b/python/tvm/runtime/script_printer.py
@@ -44,6 +44,7 @@ class PrinterConfig(Object):
num_context_lines: int
syntax_sugar: bool
show_object_address: bool
+ show_all_struct_info: bool
path_to_underline: Optional[List[ObjectPath]]
path_to_annotate: Optional[Dict[ObjectPath, str]]
obj_to_underline: Optional[List[Object]]
@@ -67,6 +68,7 @@ class PrinterConfig(Object):
num_context_lines: Optional[int] = None,
syntax_sugar: bool = True,
show_object_address: bool = False,
+ show_all_struct_info: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
@@ -89,6 +91,7 @@ class PrinterConfig(Object):
"num_context_lines": num_context_lines,
"syntax_sugar": syntax_sugar,
"show_object_address": show_object_address,
+ "show_all_struct_info": show_all_struct_info,
"path_to_underline": path_to_underline,
"path_to_annotate": path_to_annotate,
"obj_to_underline": obj_to_underline,
@@ -132,6 +135,7 @@ class Scriptable:
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
+ show_all_struct_info: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
@@ -169,9 +173,13 @@ class Scriptable:
num_context_lines : int = -1
The number of lines of context to print before and after the line
to underline.
syntax_sugar: bool = True
- Whether to output with syntax sugar, set false for complete
printing.
+ Whether to output with syntax sugar, set false for complete
printing.
show_object_address: bool = False
- Whether to include the object's address as part of the TVMScript
name
+ Whether to include the object's address as part of the TVMScript
name
+ show_all_struct_info: bool = True
+ If True (default), annotate variable bindings with their struct info,
+ except where it is inferable from R.call_tir/R.call_dps_packed. If
+ False, only add annotations required for unambiguous round-trip of
+ Relax -> TVMScript -> Relax.
path_to_underline : Optional[List[ObjectPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
@@ -185,6 +193,7 @@ class Scriptable:
-------
script : str
The TVM Script of the given TVM IR
+
"""
return _script(
self,
@@ -204,6 +213,7 @@ class Scriptable:
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
+ show_all_struct_info=show_all_struct_info,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
@@ -279,6 +289,7 @@ class Scriptable:
num_context_lines: int = -1,
syntax_sugar: bool = True,
show_object_address: bool = False,
+ show_all_struct_info: bool = True,
path_to_underline: Optional[List[ObjectPath]] = None,
path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
obj_to_underline: Optional[List[Object]] = None,
@@ -339,9 +350,13 @@ class Scriptable:
num_context_lines : int = -1
The number of lines of context to print before and after the line
to underline.
syntax_sugar: bool = True
- Whether to output with syntax sugar, set false for complete
printing.
+ Whether to output with syntax sugar, set false for complete
printing.
show_object_address: bool = False
- Whether to include the object's address as part of the TVMScript
name
+ Whether to include the object's address as part of the TVMScript
name
+ show_all_struct_info: bool = True
+ If True (default), annotate variable bindings with their struct info,
+ except where it is inferable from R.call_tir/R.call_dps_packed. If
+ False, only add annotations required for unambiguous round-trip of
+ Relax -> TVMScript -> Relax.
path_to_underline : Optional[List[ObjectPath]] = None
Object path to be underlined
path_to_annotate : Optional[Dict[ObjectPath, str]] = None
@@ -377,6 +392,7 @@ class Scriptable:
num_context_lines=num_context_lines,
syntax_sugar=syntax_sugar,
show_object_address=show_object_address,
+ show_all_struct_info=show_all_struct_info,
path_to_underline=path_to_underline,
path_to_annotate=path_to_annotate,
obj_to_underline=obj_to_underline,
diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc
index 38334de357..6e7d82ee4a 100644
--- a/src/node/script_printer.cc
+++ b/src/node/script_printer.cc
@@ -112,6 +112,9 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef>
config_dict) {
if (auto v = config_dict.Get("show_object_address")) {
n->show_object_address = Downcast<IntImm>(v)->value;
}
+ // Read as an IntImm, like show_object_address above; nonzero enables it.
+ if (auto v = config_dict.Get("show_all_struct_info")) {
+ n->show_all_struct_info = Downcast<IntImm>(v)->value;
+ }
// Checking prefixes if they are valid Python identifiers.
CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix;
diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc
index 395b4251fb..acf0072c0f 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/script/printer/relax/binding.cc
@@ -44,7 +44,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
"", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc {
using relax::StructInfo;
using relax::MatchStructInfo;
- Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
+ // R.match_cast prints its target struct info as an argument, so the
+ // annotation on the bound variable is redundant only when the two agree.
+ // A variable whose struct info differs keeps its annotation.
+ Optional<ExprDoc> ann = NullOpt;
+ if (d->cfg->show_all_struct_info || !StructuralEqual()(n->struct_info, n->var->struct_info_)) {
+ ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
+ }
ExprDoc rhs = Relax(d, "match_cast")
->Call({d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")),
d->AsDoc<ExprDoc>(n->struct_info, n_p->Attr("struct_info_"))});
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 58b8bf4431..989e9a63b1 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -19,6 +19,8 @@
#ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
#define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/type.h>
#include <tvm/relax/utils.h>
@@ -82,10 +84,58 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
if (!v->struct_info_.defined()) {
return NullOpt;
}
+ // When true, the annotation is omitted if the struct info inferred from
+ // `rhs` is structurally equal to the struct info of `v`. Bindings of
+ // `R.call_tir` and `R.call_dps_packed` always attempt this, even when
+ // `show_all_struct_info` is set.
+ bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info;
+
if (const auto* call = rhs.as<relax::CallNode>()) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
+ attempt_to_hide_struct_info = true;
+ }
+ }
+ if (attempt_to_hide_struct_info) {
+ Optional<relax::StructInfo> inferred_sinfo = NullOpt;
+ if (auto opt_call = rhs.as<relax::Call>()) {
+ auto call = opt_call.value();
+ if (auto opt_op = call->op.as<Op>()) {
+ auto op = opt_op.value();
+
+ static auto op_map_infer_struct_info =
+ Op::GetAttrMap<relax::FInferStructInfo>("FInferStructInfo");
+
+ // Not every operator registers FInferStructInfo. Without it the struct
+ // info cannot be inferred here, so the annotation is kept.
+ if (op_map_infer_struct_info.count(op)) {
+ auto temp_builder = relax::BlockBuilder::Create(NullOpt);
+ inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder);
+ }
+ } else if (auto opt_func_sinfo = call->op->struct_info_.as<relax::FuncStructInfo>()) {
+ // The callee is an expression, so its signature must be read from its
+ // struct info; `call->op` itself is never a FuncStructInfo.
+ auto temp_builder = relax::BlockBuilder::Create(NullOpt);
+ inferred_sinfo = DeriveCallRetStructInfo(opt_func_sinfo.value(), call, temp_builder,
+ temp_builder->GetAnalyzer());
+ }
+
+ } else if (const auto* tuple = rhs.as<relax::TupleNode>()) {
+ inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo));
+
+ } else if (const auto* get_item = rhs.as<relax::TupleGetItemNode>()) {
+ if (auto ptr = get_item->tuple->struct_info_.as<relax::TupleStructInfoNode>();
+ ptr && get_item->index >= 0 &&
+ get_item->index < static_cast<int>(ptr->fields.size())) {
+ inferred_sinfo = ptr->fields[get_item->index];
+ }
+
+ } else if (const auto* trivial_binding = rhs.as<relax::VarNode>()) {
+ inferred_sinfo = trivial_binding->struct_info_.as<relax::StructInfo>();
+ }
+
+ if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) {
return NullOpt;
}
}
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 530e45e610..a75977ff99 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -829,5 +829,54 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype
)
+def test_hide_inferable_struct_info():
+ """Redundant type annotations can be omitted
+
+ When `show_all_struct_info=False`, TVMScript type annotations that
+ provide redundant struct info can be omitted.
+ """
+
+ @R.function
+ def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")):
+ # R.match_cast has the struct info as an argument, so it can
+ # be omitted from the variable annotation.
+ B2 = R.match_cast(B, R.Tensor([10, 20], "float32"))
+
+ # Call nodes may have inferable shapes from their arguments.
+ C = R.add(A, B2)
+
+ # Trivial bindings can be inferred to have the same struct
+ # info as the RHS.
+ D = C
+
+ # Here, the struct info cannot be omitted. `R.add(D, B)` has
+ # struct info `R.Tensor(ndim=2, dtype="float32")`, but the variable
+ # has the shape `R.Tensor([10, 20])`. This is compatible, so it is not
+ # an error to have this annotation, but it is not inferable from
+ # the RHS. Therefore, it must still be printed.
+ E: R.Tensor([10, 20], "float32") = R.add(D, B)
+
+ # The return type can be inferred from the function body, but is
+ # still always printed in the TVMScript. When parsing an
+ # IRModule with functions calling each other, the return type
+ # of each callee must be available for use in the caller's
+ # shape inference.
+ return E
+
+ _assert_print(
+ func.script(show_all_struct_info=False),
+ """
+# from tvm.script import relax as R
+
+@R.function
+def func(A: R.Tensor((10, 20), dtype="float32"), B: R.Tensor(dtype="float32", ndim=2)) -> R.Tensor((10, 20), dtype="float32"):
+ B2 = R.match_cast(B, R.Tensor((10, 20), dtype="float32"))
+ C = R.add(A, B2)
+ D = C
+ E: R.Tensor((10, 20), dtype="float32") = R.add(D, B)
+ return E""",
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 5b3e68e22f..66eef5ad81 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -21,7 +21,7 @@ import pytest
import tvm
import tvm.testing
from tvm import tir
-from tvm.script import tir as T, ir as I
+from tvm.script import tir as T, ir as I, relax as R
import numpy as np
@@ -3996,6 +3996,24 @@ def op_of_literal():
yield make_ir_generator(op, arg)
+def relax_extern_func():
+ # Both bindings call R.call_dps_packed with the same `out_sinfo`.
+ # B's annotation matches it, so the printer may omit it. C's annotation
+ # (ndim=2 only) cannot be inferred from the call, so it must be printed
+ # for the round-trip to reproduce the original struct info.
+ @R.function
+ def func(A: R.Tensor([10, 20], "float32")):
+ # Note: this local `func` is the extern callee, not the enclosing function.
+ func = R.ExternFunc("dummy_func")
+
+ B: R.Tensor([10, 20], "float32") = R.call_dps_packed(
+ func, [A], out_sinfo=R.Tensor([10, 20], "float32")
+ )
+
+ C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed(
+ func, [B], out_sinfo=R.Tensor([10, 20], "float32")
+ )
+
+ return C
+
+ return func
+
+
ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
@@ -4081,6 +4099,17 @@ ir_generator = tvm.testing.parameter(
*op_of_literal(),
)
+relax_ir_generator = tvm.testing.parameter(
+ relax_extern_func,
+)
+
+# Each Relax generator is round-tripped with struct info both shown and hidden.
+show_all_relax_struct_info = tvm.testing.parameter(
+ by_dict={
+ "show_all_struct_info": True,
+ "hide_inferable_struct_info": False,
+ }
+)
+
def test_roundtrip(ir_generator):
original = ir_generator()
@@ -4088,6 +4117,17 @@ def test_roundtrip(ir_generator):
tvm.ir.assert_structural_equal(original, after_roundtrip, True)
+def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info):
+ """Printing a Relax function and re-parsing it yields an equal function."""
+ original = relax_ir_generator()
+ after_roundtrip = tvm.script.from_source(
+ original.script(
+ show_meta=True,
+ show_all_struct_info=show_all_relax_struct_info,
+ )
+ )
+ tvm.ir.assert_structural_equal(original, after_roundtrip, True)
+
+
def test_return_none_no_trailing_type():
func = return_none()
script = func.script()