This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1527bfba04 [TVMScript][Bugfix] Tuple on the RHS of AssignDoc (#14452)
1527bfba04 is described below
commit 1527bfba0453d0f375c8c200fa3ee668680e2e56
Author: Junru Shao <[email protected]>
AuthorDate: Sun Apr 2 16:59:19 2023 -0700
[TVMScript][Bugfix] Tuple on the RHS of AssignDoc (#14452)
This PR fixes a bug in the TVMScript printer's handling of tuples on the
RHS of an assignment, i.e.:
```python
... = (a, b, c)
^^^^^^^^^
```
The existing sugar removes the brackets surrounding `(a, b, c)`,
which makes it slightly more readable as:
```python
... = a, b, c
```
However, it overlooks the possibility that the tuple could be empty or
have only one element, i.e.:
```python
... = (a, )  # Case 1: The tuple has only one element
... = ()  # Case 2: The tuple is empty
```
In both cases, removing the brackets may lead to a wrong outcome.
This patch fixes this bug.
---
.../printer/doc_printer/python_doc_printer.cc | 6 +++++-
tests/python/relax/test_tvmscript_parser.py | 22 ++++++++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc
b/src/script/printer/doc_printer/python_doc_printer.cc
index e726cd42a2..54194e7e2a 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -549,7 +549,11 @@ void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc)
{
if (doc->rhs) {
output_ << " = ";
if (const auto* tuple_doc = doc->rhs.as<TupleDocNode>()) {
- PrintJoinedDocs(tuple_doc->elements, ", ");
+ if (tuple_doc->elements.size() > 1) {
+ PrintJoinedDocs(tuple_doc->elements, ", ");
+ } else {
+ PrintDoc(doc->rhs.value());
+ }
} else {
PrintDoc(doc->rhs.value());
}
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index c814c7abe8..0e0905ffbc 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1325,5 +1325,27 @@ def test_context_aware_parsing():
_check(Module)
+def test_unit_tuple_on_rhs_of_assign():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(input: R.Tensor((5, 5))) -> R.Tuple(R.Tensor((5, 5))):
+ gv = (input,)
+ return gv
+
+ _check(Module)
+
+
+def test_empty_tuple_on_rhs_of_assign():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(input: R.Tensor((5, 5))) -> R.Tuple():
+ gv = ()
+ return gv
+
+ _check(Module)
+
+
if __name__ == "__main__":
tvm.testing.main()