This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 834e998618 [TVMScript] Python Expression Precedence (#12148)
834e998618 is described below

commit 834e998618addb141e5a8b69f918ce5594e752cd
Author: Lite Ye <[email protected]>
AuthorDate: Mon Aug 1 02:45:44 2022 -0400

    [TVMScript] Python Expression Precedence (#12148)
    
    This PR:
    
    - Handles expression (operator) precedence during Python code printing
      (`(* 1 (+ 2 3))` prints as `1 * (2 + 3)`)
    - Addresses remaining feedback from previous PR #12112
    - Reformats Python import with isort
    
    Tracking issue: #11912
---
 include/tvm/script/printer/doc.h                   |   4 +-
 python/tvm/script/printer/doc.py                   |   4 +-
 src/script/printer/doc.cc                          |   4 +-
 src/script/printer/python_doc_printer.cc           | 176 +++++++++++-
 .../python/unittest/test_tvmscript_printer_doc.py  |  47 +--
 .../test_tvmscript_printer_python_doc_printer.py   | 315 ++++++++++++++++++++-
 6 files changed, 508 insertions(+), 42 deletions(-)

diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index e3dd83743e..408c703d54 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -1067,7 +1067,7 @@ class FunctionDocNode : public StmtDocNode {
   /*! \brief Decorators of function. */
   Array<ExprDoc> decorators;
   /*! \brief The return type of function. */
-  ExprDoc return_type{nullptr};
+  Optional<ExprDoc> return_type{NullOpt};
   /*! \brief The body of function. */
   Array<StmtDoc> body;
 
@@ -1100,7 +1100,7 @@ class FunctionDoc : public StmtDoc {
    * \param body The body of function.
    */
   explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> 
decorators,
-                       ExprDoc return_type, Array<StmtDoc> body);
+                       Optional<ExprDoc> return_type, Array<StmtDoc> body);
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, 
FunctionDocNode);
 };
 
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 747ffc42f1..0a5fde8975 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -439,7 +439,7 @@ class FunctionDoc(StmtDoc):
     name: IdDoc
     args: Sequence[AssignDoc]
     decorators: Sequence[ExprDoc]
-    return_type: ExprDoc
+    return_type: Optional[ExprDoc]
     body: Sequence[StmtDoc]
 
     def __init__(
@@ -447,7 +447,7 @@ class FunctionDoc(StmtDoc):
         name: IdDoc,
         args: List[AssignDoc],
         decorators: List[ExprDoc],
-        return_type: ExprDoc,
+        return_type: Optional[ExprDoc],
         body: List[StmtDoc],
     ):
         self.__init_handle_by_constructor__(
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index bfff0cfad4..2334d1fad5 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -198,7 +198,7 @@ ReturnDoc::ReturnDoc(ExprDoc value) {
 }
 
 FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> 
decorators,
-                         ExprDoc return_type, Array<StmtDoc> body) {
+                         Optional<ExprDoc> return_type, Array<StmtDoc> body) {
   ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>();
   n->name = name;
   n->args = args;
@@ -345,7 +345,7 @@ 
TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value)
 TVM_REGISTER_NODE_TYPE(FunctionDocNode);
 TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
     .set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> 
decorators,
-                       ExprDoc return_type, Array<StmtDoc> body) {
+                       Optional<ExprDoc> return_type, Array<StmtDoc> body) {
       return FunctionDoc(name, args, decorators, return_type, body);
     });
 
diff --git a/src/script/printer/python_doc_printer.cc 
b/src/script/printer/python_doc_printer.cc
index f44577ff80..536c57abd9 100644
--- a/src/script/printer/python_doc_printer.cc
+++ b/src/script/printer/python_doc_printer.cc
@@ -31,6 +31,111 @@ namespace tvm {
 namespace script {
 namespace printer {
 
+/*!
+ * \brief Operator precedence
+ *
+ * This is based on
+ * https://docs.python.org/3/reference/expressions.html#operator-precedence
+ */
+enum class ExprPrecedence : int32_t {
+  /*! \brief Unknown precedence */
+  kUnkown = 0,
+  /*! \brief Lambda Expression */
+  kLambda = 1,
+  /*! \brief Conditional Expression */
+  kIfThenElse = 2,
+  /*! \brief Boolean OR */
+  kBooleanOr = 3,
+  /*! \brief Boolean AND */
+  kBooleanAnd = 4,
+  /*! \brief Boolean NOT */
+  kBooleanNot = 5,
+  /*! \brief Comparisons */
+  kComparison = 6,
+  /*! \brief Bitwise OR */
+  kBitwiseOr = 7,
+  /*! \brief Bitwise XOR */
+  kBitwiseXor = 8,
+  /*! \brief Bitwise AND */
+  kBitwiseAnd = 9,
+  /*! \brief Shift Operators */
+  kShift = 10,
+  /*! \brief Addition and subtraction */
+  kAdd = 11,
+  /*! \brief Multiplication, division, floor division, remainder */
+  kMult = 12,
+  /*! \brief Positive, negative and bitwise NOT */
+  kUnary = 13,
+  /*! \brief Exponentiation */
+  kExp = 14,
+  /*! \brief Index access, attribute access, call and atom expression */
+  kIdentity = 15,
+};
+
+ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
+  // Key is the value of OperationDocNode::Kind
+  static const std::vector<ExprPrecedence> op_kind_precedence = []() {
+    using OpKind = OperationDocNode::Kind;
+    std::map<OpKind, ExprPrecedence> raw_table = {
+        {OpKind::kUSub, ExprPrecedence::kUnary},
+        {OpKind::kInvert, ExprPrecedence::kUnary},
+        {OpKind::kAdd, ExprPrecedence::kAdd},
+        {OpKind::kSub, ExprPrecedence::kAdd},
+        {OpKind::kMult, ExprPrecedence::kMult},
+        {OpKind::kDiv, ExprPrecedence::kMult},
+        {OpKind::kFloorDiv, ExprPrecedence::kMult},
+        {OpKind::kMod, ExprPrecedence::kMult},
+        {OpKind::kPow, ExprPrecedence::kExp},
+        {OpKind::kLShift, ExprPrecedence::kShift},
+        {OpKind::kRShift, ExprPrecedence::kShift},
+        {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd},
+        {OpKind::kBitOr, ExprPrecedence::kBitwiseOr},
+        {OpKind::kBitXor, ExprPrecedence::kBitwiseXor},
+        {OpKind::kLt, ExprPrecedence::kComparison},
+        {OpKind::kLtE, ExprPrecedence::kComparison},
+        {OpKind::kEq, ExprPrecedence::kComparison},
+        {OpKind::kNotEq, ExprPrecedence::kComparison},
+        {OpKind::kGt, ExprPrecedence::kComparison},
+        {OpKind::kGtE, ExprPrecedence::kComparison},
+        {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse},
+    };
+    int n = static_cast<int>(OpKind::kSpecialEnd);
+    std::vector<ExprPrecedence> table(n + 1, ExprPrecedence::kUnkown);
+    for (const auto& kv : raw_table) {
+      table[static_cast<int>(kv.first)] = kv.second;
+    }
+    return table;
+  }();
+
+  // Key is the type index of Doc
+  static const std::unordered_map<uint32_t, ExprPrecedence> 
doc_type_precedence = {
+      {LiteralDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {IdDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {AttrAccessDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {IndexDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {CallDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {LambdaDocNode::RuntimeTypeIndex(), ExprPrecedence::kLambda},
+      {TupleDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {ListDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+      {DictDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+  };
+
+  if (const auto* op_doc = doc.as<OperationDocNode>()) {
+    size_t kind = static_cast<int>(op_doc->kind);
+    ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid 
operation: " << kind;
+    ExprPrecedence precedence = op_kind_precedence[kind];
+    ICHECK(precedence != ExprPrecedence::kUnkown)
+        << "Precedence for operator " << static_cast<int>(op_doc->kind) << " 
is unknown";
+    return precedence;
+  }
+  auto it = doc_type_precedence.find(doc->type_index());
+  if (it != doc_type_precedence.end()) {
+    return it->second;
+  }
+  ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is 
unknown";
+  throw;
+}
+
 class PythonDocPrinter : public DocPrinter {
  public:
   explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) 
{}
@@ -98,6 +203,42 @@ class PythonDocPrinter : public DocPrinter {
     }
   }
 
+  /*!
+   * \brief Print expression and add parentheses if needed.
+   */
+  void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
+                      bool parenthesis_for_same_precedence = false) {
+    ExprPrecedence doc_precedence = GetExprPrecedence(doc);
+    if (doc_precedence < parent_precedence ||
+        (parenthesis_for_same_precedence && doc_precedence == 
parent_precedence)) {
+      output_ << "(";
+      PrintDoc(doc);
+      output_ << ")";
+    } else {
+      PrintDoc(doc);
+    }
+  }
+
+  /*!
+   * \brief Print expression and add parentheses if doc has lower precedence than parent.
+   */
+  void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
+                      bool parenthesis_for_same_precedence = false) {
+    ExprPrecedence parent_precedence = GetExprPrecedence(parent);
+    return PrintChildExpr(doc, parent_precedence, 
parenthesis_for_same_precedence);
+  }
+
+  /*!
+   * \brief Print expression and add parentheses if doc doesn't have higher precedence than parent.
+   *
+   * This function should be used to print a child expression that needs to be wrapped
+   * by parentheses even if it has the same precedence as its parent, e.g., the `b` in `a + b`
+   * and the `b` and `c` in `a if b else c`.
+   */
+  void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) 
{
+    PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true);
+  }
+
   void MaybePrintCommentInline(const StmtDoc& stmt) {
     if (stmt->comment.defined()) {
       const std::string& comment = stmt->comment.value();
@@ -161,12 +302,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& 
doc) {
 void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; 
}
 
 void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
-  PrintDoc(doc->value);
+  PrintChildExpr(doc->value, doc);
   output_ << "." << doc->name;
 }
 
 void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
-  PrintDoc(doc->value);
+  PrintChildExpr(doc->value, doc);
   if (doc->indices.size() == 0) {
     output_ << "[()]";
   } else {
@@ -226,21 +367,30 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& 
doc) {
     // Unary Operators
     ICHECK_EQ(doc->operands.size(), 1);
     output_ << OperatorToString(doc->kind);
-    PrintDoc(doc->operands[0]);
+    PrintChildExpr(doc->operands[0], doc);
+  } else if (doc->kind == OpKind::kPow) {
+    // Power operator is different from other binary operators.
+    // It's right-associative and binds less tightly than a unary operator on its right.
+    // https://docs.python.org/3/reference/expressions.html#the-power-operator
+    // https://docs.python.org/3/reference/expressions.html#operator-precedence
+    ICHECK_EQ(doc->operands.size(), 2);
+    PrintChildExprConservatively(doc->operands[0], doc);
+    output_ << " ** ";
+    PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
   } else if (doc->kind < OpKind::kBinaryEnd) {
     // Binary Operator
     ICHECK_EQ(doc->operands.size(), 2);
-    PrintDoc(doc->operands[0]);
+    PrintChildExpr(doc->operands[0], doc);
     output_ << " " << OperatorToString(doc->kind) << " ";
-    PrintDoc(doc->operands[1]);
+    PrintChildExprConservatively(doc->operands[1], doc);
   } else if (doc->kind == OpKind::kIfThenElse) {
     ICHECK_EQ(doc->operands.size(), 3)
         << "ValueError: IfThenElse requires 3 operands, but got " << 
doc->operands.size();
-    PrintDoc(doc->operands[1]);
+    PrintChildExpr(doc->operands[1], doc);
     output_ << " if ";
-    PrintDoc(doc->operands[0]);
+    PrintChildExprConservatively(doc->operands[0], doc);
     output_ << " else ";
-    PrintDoc(doc->operands[2]);
+    PrintChildExprConservatively(doc->operands[2], doc);
   } else {
     LOG(FATAL) << "Unknown OperationDocNode::Kind " << 
static_cast<int>(doc->kind);
     throw;
@@ -248,7 +398,7 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& 
doc) {
 }
 
 void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
-  PrintDoc(doc->callee);
+  PrintChildExpr(doc->callee, doc);
 
   output_ << "(";
 
@@ -285,7 +435,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
   output_ << "lambda ";
   PrintJoinedDocs(doc->args, ", ");
   output_ << ": ";
-  PrintDoc(doc->body);
+  PrintChildExpr(doc->body, doc);
 }
 
 void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
@@ -444,8 +594,10 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& 
doc) {
   PrintJoinedDocs(doc->args, ", ");
   output_ << ")";
 
-  output_ << " -> ";
-  PrintDoc(doc->return_type);
+  if (doc->return_type.defined()) {
+    output_ << " -> ";
+    PrintDoc(doc->return_type.value());
+  }
 
   output_ << ":";
 
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py 
b/tests/python/unittest/test_tvmscript_printer_doc.py
index 040a829010..f27bc71b66 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -21,30 +21,31 @@ Doc objects, then access and modify their attributes 
correctly.
 
 import pytest
 
+import tvm
 from tvm.script.printer.doc import (
-    LiteralDoc,
-    IdDoc,
+    AssertDoc,
+    AssignDoc,
     AttrAccessDoc,
-    IndexDoc,
     CallDoc,
-    OperationKind,
-    OperationDoc,
+    ClassDoc,
+    DictDoc,
+    ExprStmtDoc,
+    ForDoc,
+    FunctionDoc,
+    IdDoc,
+    IfDoc,
+    IndexDoc,
     LambdaDoc,
-    TupleDoc,
     ListDoc,
-    DictDoc,
+    LiteralDoc,
+    OperationDoc,
+    OperationKind,
+    ReturnDoc,
+    ScopeDoc,
     SliceDoc,
     StmtBlockDoc,
-    AssignDoc,
-    IfDoc,
+    TupleDoc,
     WhileDoc,
-    ForDoc,
-    ScopeDoc,
-    ExprStmtDoc,
-    AssertDoc,
-    ReturnDoc,
-    FunctionDoc,
-    ClassDoc,
 )
 
 
@@ -450,6 +451,13 @@ def test_return_doc():
         [IdDoc("test"), IdDoc("test2")],
     ],
 )
[email protected](
+    "return_type",
+    [
+        None,
+        LiteralDoc(None),
+    ],
+)
 @pytest.mark.parametrize(
     "body",
     [
@@ -458,9 +466,8 @@ def test_return_doc():
         [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
     ],
 )
-def test_function_doc(args, decorators, body):
+def test_function_doc(args, decorators, return_type, body):
     name = IdDoc("name")
-    return_type = LiteralDoc(None)
 
     doc = FunctionDoc(name, args, decorators, return_type, body)
 
@@ -504,3 +511,7 @@ def test_stmt_doc_comment():
     comment = "test comment"
     doc.comment = comment
     assert doc.comment == comment
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py 
b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index 523f62d8b5..e0905cc145 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -14,9 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
 import itertools
 
+import pytest
+
+import tvm
 from tvm.script.printer.doc import (
     AssertDoc,
     AssignDoc,
@@ -701,29 +703,32 @@ def test_print_return_doc(value, expected):
 
 
 @pytest.mark.parametrize(
-    "args, decorators, body, expected",
+    "args, decorators, return_type, body, expected",
     [
         (
             [],
             [],
+            None,
             [],
             """
-            def func() -> None:
+            def func():
                 pass
             """,
         ),
         (
             [AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int"))],
             [],
+            IdDoc("int"),
             [],
             """
-            def func(x: int) -> None:
+            def func(x: int) -> int:
                 pass
             """,
         ),
         (
             [AssignDoc(IdDoc("x"), rhs=LiteralDoc(1), 
annotation=IdDoc("int"))],
             [],
+            LiteralDoc(None),
             [],
             """
             def func(x: int = 1) -> None:
@@ -733,6 +738,7 @@ def test_print_return_doc(value, expected):
         (
             [],
             [IdDoc("wrap")],
+            LiteralDoc(None),
             [],
             """
             @wrap
@@ -743,6 +749,7 @@ def test_print_return_doc(value, expected):
         (
             [],
             [IdDoc("wrap_outter"), IdDoc("wrap_inner")],
+            LiteralDoc(None),
             [],
             """
             @wrap_outter
@@ -757,6 +764,7 @@ def test_print_return_doc(value, expected):
                 AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), 
annotation=IdDoc("int")),
             ],
             [IdDoc("wrap")],
+            LiteralDoc(None),
             [],
             """
             @wrap
@@ -770,6 +778,7 @@ def test_print_return_doc(value, expected):
                 AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), 
annotation=IdDoc("int")),
             ],
             [IdDoc("wrap")],
+            LiteralDoc(None),
             [
                 AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, 
[IdDoc("x"), LiteralDoc(1)])),
                 AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, 
[IdDoc("y"), LiteralDoc(1)])),
@@ -784,8 +793,8 @@ def test_print_return_doc(value, expected):
     ],
     ids=itertools.count(),
 )
-def test_print_function_doc(args, decorators, body, expected):
-    doc = FunctionDoc(IdDoc("func"), args, decorators, LiteralDoc(None), body)
+def test_print_function_doc(args, decorators, body, return_type, expected):
+    doc = FunctionDoc(IdDoc("func"), args, decorators, return_type, body)
     assert to_python_script(doc) == format_script(expected)  # test
 
 
@@ -1038,3 +1047,297 @@ def test_print_invalid_multiline_doc_comment(doc):
     with pytest.raises(ValueError) as e:
         to_python_script(doc)
     assert "cannot have newline" in str(e.value)
+
+
+def generate_expr_precedence_test_cases():
+    x = IdDoc("x")
+    y = IdDoc("y")
+    z = IdDoc("z")
+
+    def negative(a):
+        return OperationDoc(OperationKind.USub, [a])
+
+    def invert(a):
+        return OperationDoc(OperationKind.Invert, [a])
+
+    def add(a, b):
+        return OperationDoc(OperationKind.Add, [a, b])
+
+    def sub(a, b):
+        return OperationDoc(OperationKind.Sub, [a, b])
+
+    def mult(a, b):
+        return OperationDoc(OperationKind.Mult, [a, b])
+
+    def div(a, b):
+        return OperationDoc(OperationKind.Div, [a, b])
+
+    def mod(a, b):
+        return OperationDoc(OperationKind.Mod, [a, b])
+
+    def pow(a, b):
+        return OperationDoc(OperationKind.Pow, [a, b])
+
+    def lshift(a, b):
+        return OperationDoc(OperationKind.LShift, [a, b])
+
+    def bit_and(a, b):
+        return OperationDoc(OperationKind.BitAnd, [a, b])
+
+    def bit_or(a, b):
+        return OperationDoc(OperationKind.BitOr, [a, b])
+
+    def bit_xor(a, b):
+        return OperationDoc(OperationKind.BitXor, [a, b])
+
+    def lt(a, b):
+        return OperationDoc(OperationKind.Lt, [a, b])
+
+    def eq(a, b):
+        return OperationDoc(OperationKind.Eq, [a, b])
+
+    def not_eq(a, b):
+        return OperationDoc(OperationKind.NotEq, [a, b])
+
+    def if_then_else(a, b, c):
+        return OperationDoc(OperationKind.IfThenElse, [a, b, c])
+
+    test_cases = {
+        "attr-call-index": [
+            (
+                add(x, y).attr("test"),
+                "(x + y).test",
+            ),
+            (
+                add(x, y.attr("test")),
+                "x + y.test",
+            ),
+            (
+                x[z].call(y),
+                "x[z](y)",
+            ),
+            (
+                x.call(y)[z],
+                "x(y)[z]",
+            ),
+            (
+                x.call(y).call(z),
+                "x(y)(z)",
+            ),
+            (
+                x.call(y).attr("test"),
+                "x(y).test",
+            ),
+            (
+                x.attr("test").call(y),
+                "x.test(y)",
+            ),
+            (
+                x.attr("test").attr("test2"),
+                "x.test.test2",
+            ),
+            (
+                LambdaDoc([x], x).call(y),
+                "(lambda x: x)(y)",
+            ),
+            (
+                add(x, y)[z][add(z, z)].attr("name"),
+                "(x + y)[z][z + z].name",
+            ),
+        ],
+        "power": [
+            (
+                pow(pow(x, y), z),
+                "(x ** y) ** z",
+            ),
+            (
+                pow(x, pow(y, z)),
+                "x ** y ** z",
+            ),
+            (
+                pow(negative(x), negative(y)),
+                "(-x) ** -y",
+            ),
+            (
+                pow(add(x, y), add(y, z)),
+                "(x + y) ** (y + z)",
+            ),
+        ],
+        "unary": [
+            (
+                invert(negative(y)),
+                "~-y",
+            ),
+            (
+                negative(y).attr("test"),
+                "(-y).test",
+            ),
+            (
+                negative(y.attr("test")),
+                "-y.test",
+            ),
+            (
+                mult(negative(x), negative(y)),
+                "-x * -y",
+            ),
+            (
+                negative(add(invert(x), negative(y))),
+                "-(~x + -y)",
+            ),
+        ],
+        "add-mult": [
+            (
+                mult(x, mult(y, z)),
+                "x * (y * z)",
+            ),
+            (
+                mult(mult(x, y), z),
+                "x * y * z",
+            ),
+            (
+                mult(x, add(y, z)),
+                "x * (y + z)",
+            ),
+            (
+                mult(add(y, z), x),
+                "(y + z) * x",
+            ),
+            (
+                add(x, mod(y, z)),
+                "x + y % z",
+            ),
+            (
+                add(mult(y, z), x),
+                "y * z + x",
+            ),
+            (
+                add(add(x, y), add(y, z)),
+                "x + y + (y + z)",
+            ),
+            (
+                div(add(x, y), add(y, z)),
+                "(x + y) / (y + z)",
+            ),
+        ],
+        "shift": [
+            (
+                div(x, lshift(y, z)),
+                "x / (y << z)",
+            ),
+            (
+                mult(lshift(y, z), x),
+                "(y << z) * x",
+            ),
+            (
+                lshift(x, mult(y, z)),
+                "x << y * z",
+            ),
+            (
+                lshift(mult(x, y), z),
+                "x * y << z",
+            ),
+            (
+                lshift(mult(x, y), z),
+                "x * y << z",
+            ),
+            (
+                lshift(lshift(x, y), z),
+                "x << y << z",
+            ),
+            (
+                lshift(x, lshift(y, z)),
+                "x << (y << z)",
+            ),
+        ],
+        "bitwise": [
+            (
+                add(bit_or(x, y), bit_or(y, z)),
+                "(x | y) + (y | z)",
+            ),
+            (
+                bit_and(bit_or(x, y), bit_or(y, z)),
+                "(x | y) & (y | z)",
+            ),
+            (
+                bit_or(bit_and(x, y), bit_and(y, z)),
+                "x & y | y & z",
+            ),
+            (
+                bit_and(bit_xor(x, bit_or(y, z)), z),
+                "(x ^ (y | z)) & z",
+            ),
+        ],
+        "comparison": [
+            (
+                not_eq(add(x, y), z),
+                "x + y != z",
+            ),
+            (
+                eq(pow(x, y), z),
+                "x ** y == z",
+            ),
+            (
+                lt(x, div(y, z)),
+                "x < y / z",
+            ),
+            (
+                lt(x, if_then_else(y, y, y)),
+                "x < (y if y else y)",
+            ),
+        ],
+        "if-then-else": [
+            (
+                if_then_else(x, if_then_else(y, y, y), z),
+                "y if y else y if x else z",
+            ),
+            (
+                if_then_else(if_then_else(x, x, x), y, z),
+                "y if (x if x else x) else z",
+            ),
+            (
+                if_then_else(x, y, if_then_else(z, z, z)),
+                "y if x else (z if z else z)",
+            ),
+            (
+                if_then_else(lt(x, x), add(y, y), mult(z, z)),
+                "y + y if x < x else z * z",
+            ),
+            (
+                if_then_else(LambdaDoc([x], x), LambdaDoc([y], y), 
LambdaDoc([z], z)),
+                "(lambda y: y) if (lambda x: x) else (lambda z: z)",
+            ),
+        ],
+        "lambda": [
+            (
+                LambdaDoc([x, y], add(z, z)),
+                "lambda x, y: z + z",
+            ),
+            (
+                add(LambdaDoc([x, y], z), z),
+                "(lambda x, y: z) + z",
+            ),
+            (
+                LambdaDoc([x, y], add(z, z)).call(x, y),
+                "(lambda x, y: z + z)(x, y)",
+            ),
+            (
+                LambdaDoc([x], LambdaDoc([y], z)),
+                "lambda x: lambda y: z",
+            ),
+        ],
+    }
+
+    return [
+        pytest.param(*args, id=f"{group_name}-{i}")
+        for group_name, cases in test_cases.items()
+        for i, args in enumerate(cases)
+    ]
+
+
[email protected]("doc, expected", 
generate_expr_precedence_test_cases())
+def test_expr_precedence(doc, expected):
+    assert to_python_script(doc) == format_script(expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to