This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 09f38ac91c [TVMScript][Fix] Print Multi-line String as Metadata (#13965)
09f38ac91c is described below

commit 09f38ac91c28640469c8ec4c6b8fd086cc438e7f
Author: Junru Shao <[email protected]>
AuthorDate: Sun Feb 12 06:42:01 2023 -0800

    [TVMScript][Fix] Print Multi-line String as Metadata (#13965)
    
    Multi-line strings are poor candidates for inline printing by default:
    they are often LLVM snippets, CUDA source code, or other content that is
    hard to read and easily garbles the surrounding TVMScript. Therefore,
    this PR prints them as metadata by default.
---
 python/tvm/script/parser/core/entry.py            |  2 ++
 python/tvm/script/parser/ir/parser.py             | 11 +++++++++++
 src/script/printer/ir/misc.cc                     |  3 +++
 src/script/printer/tir/expr.cc                    |  6 +++++-
 src/script/printer/utils.h                        | 14 +++++++++++---
 tests/python/unittest/test_tvmscript_roundtrip.py |  2 +-
 6 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/python/tvm/script/parser/core/entry.py 
b/python/tvm/script/parser/core/entry.py
index bf6a118672..9e6c100c95 100644
--- a/python/tvm/script/parser/core/entry.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -41,10 +41,12 @@ def parse(program: Union[doc.AST, Any, str], extra_vars: 
Dict[str, Any] = None)
         The parsed TVMScript program.
     """
     if extra_vars is None:
+        import tvm  # pylint: disable=import-outside-toplevel
         from tvm.script.parser import ir  # pylint: 
disable=import-outside-toplevel
         from tvm.script.parser import tir  # pylint: 
disable=import-outside-toplevel
 
         extra_vars = {
+            "tvm": tvm,
             "I": ir,
             "ir": ir,
             "T": tir,
diff --git a/python/tvm/script/parser/ir/parser.py 
b/python/tvm/script/parser/ir/parser.py
index 9532e7e32c..e0268412d2 100644
--- a/python/tvm/script/parser/ir/parser.py
+++ b/python/tvm/script/parser/ir/parser.py
@@ -64,3 +64,14 @@ def _visit_expr(_self: Parser, _node: doc.Expr) -> None:
     node : doc.ClassDef
         The doc AST expression node.
     """
+
+
+@dispatch.register(token="default", type_name="Assign")
+def visit_assign(self: Parser, node: doc.Assign) -> None:
+    if len(node.targets) != 1:
+        self.report_error(node, "Consequential assignments like 'a = b = c' 
are not supported.")
+    lhs = node.targets[0]
+    rhs = self.eval_expr(node.value)
+    self.eval_assign(
+        target=lhs, source=rhs, bind_value=lambda _a, _b, _c, value: value, 
allow_shadowing=True
+    )
diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc
index cb78dc3ff5..ef68b89b5b 100644
--- a/src/script/printer/ir/misc.cc
+++ b/src/script/printer/ir/misc.cc
@@ -24,6 +24,9 @@ namespace printer {
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<String>("", [](String s, ObjectPath p, IRDocsifier d) -> Doc 
{
+      if (HasMultipleLines(s)) {
+        return d->AddMetadata(s);
+      }
       return LiteralDoc::Str(s, p);
     });
 
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index cc37f46e60..a5d5d492ff 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -79,7 +79,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::StringImm>("", [](tir::StringImm s, ObjectPath p, 
IRDocsifier d) -> Doc {
-      return d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
+      if (HasMultipleLines(s->value)) {
+        return d->AddMetadata(s);
+      } else {
+        return d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
+      }
     });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index ade19b3452..90300518b7 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -27,6 +27,8 @@
 #include <utility>
 #include <vector>
 
+#include "../../support/str_escape.h"
+
 namespace tvm {
 namespace script {
 namespace printer {
@@ -76,9 +78,10 @@ inline std::string Docsify(const ObjectRef& obj, const 
IRDocsifier& d, const Fra
   std::ostringstream os;
   if (!d->metadata.empty()) {
     if (d->cfg->show_meta) {
-      os << "metadata = tvm.ir.load_json("
-         << SaveJSON(Map<String, ObjectRef>(d->metadata.begin(), 
d->metadata.end())) << ")"
-         << "\n";
+      os << "metadata = tvm.ir.load_json(\""
+         << support::StrEscape(
+                SaveJSON(Map<String, ObjectRef>(d->metadata.begin(), 
d->metadata.end())))
+         << "\")\n";
     } else {
       f->stmts.push_back(
           CommentDoc("Metadata omitted. Use show_meta=True in script() method 
to show it."));
@@ -130,6 +133,11 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& 
doc) {
   return doc;
 }
 
+/*! \brief Check if a string has multiple lines. */
+inline bool HasMultipleLines(const std::string& str) {
+  return str.find_first_of('\n') != std::string::npos;
+}
+
 inline Optional<String> GetBindingName(const IRDocsifier& d) {
   return d->cfg->binding_names.empty() ? Optional<String>(NullOpt) : 
d->cfg->binding_names.back();
 }
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 05a3270d15..1ec8f49b4b 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3630,7 +3630,7 @@ ir_generator = tvm.testing.parameter(
 
 def test_roundtrip(ir_generator):
     original = ir_generator()
-    after_roundtrip = tvm.script.from_source(original.script())
+    after_roundtrip = tvm.script.from_source(original.script(show_meta=True))
     tvm.ir.assert_structural_equal(original, after_roundtrip, True)
 
 

Reply via email to