This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf4394762a [Unity][TVMScript] Optionally hide StructInfo that can be inferred (#16356)
bf4394762a is described below

commit bf4394762ab7e577a4fb675f9796f58cf783dadf
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Feb 12 12:06:09 2024 -0600

    [Unity][TVMScript] Optionally hide StructInfo that can be inferred (#16356)
    
    * [Unity][TVMScript] Optionally hide StructInfo that can be inferred
    
    By default, TVMScript prints the struct info of every variable being
    bound, which can become quite verbose.  This commit adds the
    configuration option `show_inferable_type_annotations`, which
    determines whether struct info annotations are shown in cases where
    they can be inferred.
    
    The `show_inferable_type_annotations` option defaults to `True`,
    preserving the current default behavior.
    
    * Rename show_inferable_type_annotations to show_all_struct_info
    
    * Add unit test for round-trip of opaque function
---
 include/tvm/node/script_printer.h                  | 35 ++++++++++++++++
 python/tvm/runtime/script_printer.py               | 24 +++++++++--
 src/node/script_printer.cc                         |  3 ++
 src/script/printer/relax/binding.cc                |  5 ++-
 src/script/printer/relax/utils.h                   | 39 +++++++++++++++++
 tests/python/relax/test_tvmscript_printer_relax.py | 49 ++++++++++++++++++++++
 tests/python/tvmscript/test_tvmscript_roundtrip.py | 42 ++++++++++++++++++-
 7 files changed, 191 insertions(+), 6 deletions(-)

diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h
index 703844d99a..2b812eef32 100644
--- a/include/tvm/node/script_printer.h
+++ b/include/tvm/node/script_printer.h
@@ -72,6 +72,40 @@ class PrinterConfigNode : public Object {
   bool syntax_sugar = true;
   /*! \brief Whether variable names should include the object's address */
   bool show_object_address = false;
+
+  /*! \brief In Relax, whether to show all StructInfo annotations
+   *
+   * If true (default), all variable bindings will be annotated with
+   * the struct info of the variable being bound.
+   *
+   * If false, the annotations will only be shown when they are
+   * required for correct parsing of the Relax function.  For example,
+   * function parameters must always have struct info annotations, but
+   * the struct info for expressions within a function body may be inferred
+   * from their arguments, and can therefore be omitted.
+   *
+   * Example:
+   *
+   *     # func.show(show_all_struct_info=True)
+   *     @R.function
+   *     def func(
+   *         A: R.Tensor((10, 20), dtype="float32"),
+   *         B: R.Tensor((10, 20), dtype="float32"),
+   *     ) -> R.Tensor((10, 20), dtype="float32"):
+   *         C: R.Tensor((10, 20), dtype="float32") = R.add(A, B)
+   *         return C
+   *
+   *     # func.show(show_all_struct_info=False)
+   *     @R.function
+   *     def func(
+   *         A: R.Tensor((10, 20), dtype="float32"),
+   *         B: R.Tensor((10, 20), dtype="float32"),
+   *     ) -> R.Tensor((10, 20), dtype="float32"):
+   *         C = R.add(A, B)
+   *         return C
+   */
+  bool show_all_struct_info = true;
+
   /* \brief Object path to be underlined */
   Array<ObjectPath> path_to_underline = Array<ObjectPath>();
   /*! \brief Object path to be annotated. */
@@ -97,6 +131,7 @@ class PrinterConfigNode : public Object {
     v->Visit("num_context_lines", &num_context_lines);
     v->Visit("syntax_sugar", &syntax_sugar);
     v->Visit("show_object_address", &show_object_address);
+    v->Visit("show_all_struct_info", &show_all_struct_info);
     v->Visit("path_to_underline", &path_to_underline);
     v->Visit("path_to_annotate", &path_to_annotate);
     v->Visit("obj_to_underline", &obj_to_underline);
diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py
index 260d0ead9d..ad3f612c4e 100644
--- a/python/tvm/runtime/script_printer.py
+++ b/python/tvm/runtime/script_printer.py
@@ -44,6 +44,7 @@ class PrinterConfig(Object):
     num_context_lines: int
     syntax_sugar: bool
     show_object_address: bool
+    show_all_struct_info: bool
     path_to_underline: Optional[List[ObjectPath]]
     path_to_annotate: Optional[Dict[ObjectPath, str]]
     obj_to_underline: Optional[List[Object]]
@@ -67,6 +68,7 @@ class PrinterConfig(Object):
         num_context_lines: Optional[int] = None,
         syntax_sugar: bool = True,
         show_object_address: bool = False,
+        show_all_struct_info: bool = True,
         path_to_underline: Optional[List[ObjectPath]] = None,
         path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
         obj_to_underline: Optional[List[Object]] = None,
@@ -89,6 +91,7 @@ class PrinterConfig(Object):
             "num_context_lines": num_context_lines,
             "syntax_sugar": syntax_sugar,
             "show_object_address": show_object_address,
+            "show_all_struct_info": show_all_struct_info,
             "path_to_underline": path_to_underline,
             "path_to_annotate": path_to_annotate,
             "obj_to_underline": obj_to_underline,
@@ -132,6 +135,7 @@ class Scriptable:
         num_context_lines: int = -1,
         syntax_sugar: bool = True,
         show_object_address: bool = False,
+        show_all_struct_info: bool = True,
         path_to_underline: Optional[List[ObjectPath]] = None,
         path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
         obj_to_underline: Optional[List[Object]] = None,
@@ -169,9 +173,13 @@ class Scriptable:
         num_context_lines : int = -1
             The number of lines of context to print before and after the line to underline.
         syntax_sugar: bool = True
-             Whether to output with syntax sugar, set false for complete printing.
+            Whether to output with syntax sugar, set false for complete printing.
         show_object_address: bool = False
-             Whether to include the object's address as part of the TVMScript name
+            Whether to include the object's address as part of the TVMScript name
+        show_all_struct_info: bool = True
+            If True (default), annotate all variable bindings with the struct
+            info of that variable.  If False, only add annotations where
+            required for unambiguous round-trip of Relax -> TVMScript -> Relax.
         path_to_underline : Optional[List[ObjectPath]] = None
             Object path to be underlined
         path_to_annotate : Optional[Dict[ObjectPath, str]] = None
@@ -185,6 +193,7 @@ class Scriptable:
         -------
         script : str
             The TVM Script of the given TVM IR
+
         """
         return _script(
             self,
@@ -204,6 +213,7 @@ class Scriptable:
                 num_context_lines=num_context_lines,
                 syntax_sugar=syntax_sugar,
                 show_object_address=show_object_address,
+                show_all_struct_info=show_all_struct_info,
                 path_to_underline=path_to_underline,
                 path_to_annotate=path_to_annotate,
                 obj_to_underline=obj_to_underline,
@@ -279,6 +289,7 @@ class Scriptable:
         num_context_lines: int = -1,
         syntax_sugar: bool = True,
         show_object_address: bool = False,
+        show_all_struct_info: bool = True,
         path_to_underline: Optional[List[ObjectPath]] = None,
         path_to_annotate: Optional[Dict[ObjectPath, str]] = None,
         obj_to_underline: Optional[List[Object]] = None,
@@ -339,9 +350,13 @@ class Scriptable:
         num_context_lines : int = -1
             The number of lines of context to print before and after the line to underline.
         syntax_sugar: bool = True
-             Whether to output with syntax sugar, set false for complete printing.
+            Whether to output with syntax sugar, set false for complete printing.
         show_object_address: bool = False
-             Whether to include the object's address as part of the TVMScript name
+            Whether to include the object's address as part of the TVMScript name
+        show_all_struct_info: bool = True
+            If True (default), annotate all variable bindings with the struct
+            info of that variable.  If False, only add annotations where
+            required for unambiguous round-trip of Relax -> TVMScript -> Relax.
         path_to_underline : Optional[List[ObjectPath]] = None
             Object path to be underlined
         path_to_annotate : Optional[Dict[ObjectPath, str]] = None
@@ -377,6 +392,7 @@ class Scriptable:
                 num_context_lines=num_context_lines,
                 syntax_sugar=syntax_sugar,
                 show_object_address=show_object_address,
+                show_all_struct_info=show_all_struct_info,
                 path_to_underline=path_to_underline,
                 path_to_annotate=path_to_annotate,
                 obj_to_underline=obj_to_underline,
diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc
index 38334de357..6e7d82ee4a 100644
--- a/src/node/script_printer.cc
+++ b/src/node/script_printer.cc
@@ -112,6 +112,9 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
   if (auto v = config_dict.Get("show_object_address")) {
     n->show_object_address = Downcast<IntImm>(v)->value;
   }
+  if (auto v = config_dict.Get("show_all_struct_info")) {
+    n->show_all_struct_info = Downcast<IntImm>(v)->value;
+  }
 
   // Checking prefixes if they are valid Python identifiers.
   CHECK(IsIdentifier(n->ir_prefix)) << "Invalid `ir_prefix`: " << n->ir_prefix;
diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc
index 395b4251fb..acf0072c0f 100644
--- a/src/script/printer/relax/binding.cc
+++ b/src/script/printer/relax/binding.cc
@@ -44,7 +44,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         "", [](relax::MatchCast n, ObjectPath n_p, IRDocsifier d) -> Doc {
           using relax::StructInfo;
           using relax::MatchStructInfo;
-          Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
+          Optional<ExprDoc> ann = NullOpt;
+          if (d->cfg->show_all_struct_info) {
+            ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
+          }
           ExprDoc rhs = Relax(d, "match_cast")
                             ->Call({d->AsDoc<ExprDoc>(n->value, n_p->Attr("value")),
                                     d->AsDoc<ExprDoc>(n->struct_info, n_p->Attr("struct_info_"))});
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 58b8bf4431..989e9a63b1 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -19,6 +19,8 @@
 #ifndef TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
 #define TVM_SCRIPT_PRINTER_RELAX_UTILS_H_
 
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/type.h>
 #include <tvm/relax/utils.h>
@@ -82,10 +84,47 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
   if (!v->struct_info_.defined()) {
     return NullOpt;
   }
+  bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info;
+
   if (const auto* call = rhs.as<relax::CallNode>()) {
     static const Op& call_tir_op = Op::Get("relax.call_tir");
     static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
     if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
+      attempt_to_hide_struct_info = true;
+    }
+  }
+  if (attempt_to_hide_struct_info) {
+    Optional<relax::StructInfo> inferred_sinfo = NullOpt;
+    if (auto opt = rhs.as<relax::Call>()) {
+      auto call = opt.value();
+      if (auto opt = call->op.as<Op>()) {
+        auto op = opt.value();
+
+        static auto op_map_infer_struct_info =
+            Op::GetAttrMap<relax::FInferStructInfo>("FInferStructInfo");
+
+        auto temp_builder = relax::BlockBuilder::Create(NullOpt);
+        inferred_sinfo = op_map_infer_struct_info[op](call, temp_builder);
+      } else if (auto opt = call->op.as<relax::FuncStructInfo>()) {
+        auto temp_builder = relax::BlockBuilder::Create(NullOpt);
+        inferred_sinfo =
+            DeriveCallRetStructInfo(opt.value(), call, temp_builder, temp_builder->GetAnalyzer());
+      }
+
+    } else if (const auto* tuple = rhs.as<relax::TupleNode>()) {
+      inferred_sinfo = relax::TupleStructInfo(tuple->fields.Map(relax::GetStructInfo));
+
+    } else if (const auto* get_item = rhs.as<relax::TupleGetItemNode>()) {
+      if (auto ptr = get_item->tuple->struct_info_.as<relax::TupleStructInfoNode>();
+          ptr && get_item->index < static_cast<int>(ptr->fields.size())) {
+        inferred_sinfo = ptr->fields[get_item->index];
+      }
+
+    } else if (const auto* trivial_binding = rhs.as<relax::VarNode>()) {
+      inferred_sinfo = trivial_binding->struct_info_.as<relax::StructInfo>();
+    }
+
+    if (inferred_sinfo && StructuralEqual()(inferred_sinfo, v->struct_info_)) {
       return NullOpt;
     }
   }
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 530e45e610..a75977ff99 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -829,5 +829,54 @@ def func(x: R.Tensor((128, 128), dtype="float32")) -> R.Tensor((128, 128), dtype
     )
 
 
+def test_hide_inferable_struct_info():
+    """Redundant type annotations can be omitted
+
+    When `show_all_struct_info=False`, TVMScript type annotations that
+    provide redundant struct info can be omitted.
+    """
+
+    @R.function
+    def func(A: R.Tensor([10, 20], "float32"), B: R.Tensor(ndim=2, dtype="float32")):
+        # R.match_cast has the struct info as an argument, so it can
+        # be omitted from the variable annotation.
+        B2 = R.match_cast(B, R.Tensor([10, 20], "float32"))
+
+        # Call nodes may have inferable shapes from their arguments.
+        C = R.add(A, B2)
+
+        # Trivial bindings can be inferred to have the same struct
+        # info as the RHS.
+        D = C
+
+        # Here, the struct info cannot be omitted.  `R.add(D,B)` has
+        # struct info `R.Tensor(ndim=2)`, but the variable has a shape
+        # `R.Tensor([10,20])`.  This is compatible, so it is not an
+        # error to have this annotation, but it is not inferable from
+        # the RHS.  Therefore, it must still be printed.
+        E: R.Tensor([10, 20], "float32") = R.add(D, B)
+
+        # The return type can be inferred from function body, but is
+        # still always printed in the TVMScript.  When parsing an
+        # IRModule with functions calling each other, the return type
+        # of each callee must be available for use in the caller's
+        # shape inference.
+        return E
+
+    _assert_print(
+        func.script(show_all_struct_info=False),
+        """
+# from tvm.script import relax as R
+
+@R.function
+def func(A: R.Tensor((10, 20), dtype="float32"), B: R.Tensor(dtype="float32", ndim=2)) -> R.Tensor((10, 20), dtype="float32"):
+    B2 = R.match_cast(B, R.Tensor((10, 20), dtype="float32"))
+    C = R.add(A, B2)
+    D = C
+    E: R.Tensor((10, 20), dtype="float32") = R.add(D, B)
+    return E""",
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 5b3e68e22f..66eef5ad81 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -21,7 +21,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import tir
-from tvm.script import tir as T, ir as I
+from tvm.script import tir as T, ir as I, relax as R
 
 import numpy as np
 
@@ -3996,6 +3996,24 @@ def op_of_literal():
         yield make_ir_generator(op, arg)
 
 
+def relax_extern_func():
+    @R.function
+    def func(A: R.Tensor([10, 20], "float32")):
+        func = R.ExternFunc("dummy_func")
+
+        B: R.Tensor([10, 20], "float32") = R.call_dps_packed(
+            func, [A], out_sinfo=R.Tensor([10, 20], "float32")
+        )
+
+        C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed(
+            func, [B], out_sinfo=R.Tensor([10, 20], "float32")
+        )
+
+        return C
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -4081,6 +4099,17 @@ ir_generator = tvm.testing.parameter(
     *op_of_literal(),
 )
 
+relax_ir_generator = tvm.testing.parameter(
+    relax_extern_func,
+)
+
+show_all_relax_struct_info = tvm.testing.parameter(
+    by_dict={
+        "show_all_struct_info": True,
+        "hide_inferable_struct_info": False,
+    }
+)
+
 
 def test_roundtrip(ir_generator):
     original = ir_generator()
@@ -4088,6 +4117,17 @@ def test_roundtrip(ir_generator):
     tvm.ir.assert_structural_equal(original, after_roundtrip, True)
 
 
+def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info):
+    original = relax_ir_generator()
+    after_roundtrip = tvm.script.from_source(
+        original.script(
+            show_meta=True,
+            show_all_struct_info=show_all_relax_struct_info,
+        )
+    )
+    tvm.ir.assert_structural_equal(original, after_roundtrip, True)
+
+
 def test_return_none_no_trailing_type():
     func = return_none()
     script = func.script()

Reply via email to