This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1527bfba04 [TVMScript][Bugfix] Tuple on the RHS of AssignDoc (#14452)
1527bfba04 is described below

commit 1527bfba0453d0f375c8c200fa3ee668680e2e56
Author: Junru Shao <[email protected]>
AuthorDate: Sun Apr 2 16:59:19 2023 -0700

    [TVMScript][Bugfix] Tuple on the RHS of AssignDoc (#14452)
    
    This PR fixes a bug in TVMScript printer handling tuples on the RHS of
    an assignment, i.e.:
    
    ```python
    ... = (a, b, c)
          ^^^^^^^^^
    ```
    
    The existing syntactic sugar removes the parentheses surrounding `(a, b, c)`,
    which makes it slightly more readable:
    
    ```python
    ... = a, b, c
    ```
    
    However, it overlooks the cases where the tuple is empty or has only
    one element, e.g.:
    
    ```python
    ... = (a, ) # Case 1: The tuple has only one element
    ... = ()    # Case 2: The tuple is empty
    ```
    
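    In both cases, removing the parentheses produces incorrect output: a
    one-element tuple `(a, )` is printed as just `a`, so it is no longer a
    tuple, and an empty tuple is printed as nothing, leaving the invalid
    statement `... = `.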
    
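    This patch fixes the bug by only dropping the parentheses when the tuple
    has more than one element; otherwise the tuple is printed as-is.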
---
 .../printer/doc_printer/python_doc_printer.cc      |  6 +++++-
 tests/python/relax/test_tvmscript_parser.py        | 22 ++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc
index e726cd42a2..54194e7e2a 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -549,7 +549,11 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
   if (doc->rhs) {
     output_ << " = ";
     if (const auto* tuple_doc = doc->rhs.as<TupleDocNode>()) {
-      PrintJoinedDocs(tuple_doc->elements, ", ");
+      if (tuple_doc->elements.size() > 1) {
+        PrintJoinedDocs(tuple_doc->elements, ", ");
+      } else {
+        PrintDoc(doc->rhs.value());
+      }
     } else {
       PrintDoc(doc->rhs.value());
     }
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index c814c7abe8..0e0905ffbc 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1325,5 +1325,27 @@ def test_context_aware_parsing():
     _check(Module)
 
 
+def test_unit_tuple_on_rhs_of_assign():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(input: R.Tensor((5, 5))) -> R.Tuple(R.Tensor((5, 5))):
+            gv = (input,)
+            return gv
+
+    _check(Module)
+
+
+def test_empty_tuple_on_rhs_of_assign():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(input: R.Tensor((5, 5))) -> R.Tuple():
+            gv = ()
+            return gv
+
+    _check(Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to