This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 834e998618 [TVMScript] Python Expression Precedence (#12148)
834e998618 is described below
commit 834e998618addb141e5a8b69f918ce5594e752cd
Author: Lite Ye <[email protected]>
AuthorDate: Mon Aug 1 02:45:44 2022 -0400
[TVMScript] Python Expression Precedence (#12148)
This PR:
- Handles expression (operator) precedence during Python code printing
(`(* 1 (+ 2 3))` prints as `1 * (2 + 3)`)
- Addresses remaining feedback from previous PR #12112
- Reformats Python import with isort
Tracking issue: #11912
---
include/tvm/script/printer/doc.h | 4 +-
python/tvm/script/printer/doc.py | 4 +-
src/script/printer/doc.cc | 4 +-
src/script/printer/python_doc_printer.cc | 176 +++++++++++-
.../python/unittest/test_tvmscript_printer_doc.py | 47 +--
.../test_tvmscript_printer_python_doc_printer.py | 315 ++++++++++++++++++++-
6 files changed, 508 insertions(+), 42 deletions(-)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index e3dd83743e..408c703d54 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -1067,7 +1067,7 @@ class FunctionDocNode : public StmtDocNode {
/*! \brief Decorators of function. */
Array<ExprDoc> decorators;
/*! \brief The return type of function. */
- ExprDoc return_type{nullptr};
+ Optional<ExprDoc> return_type{NullOpt};
/*! \brief The body of function. */
Array<StmtDoc> body;
@@ -1100,7 +1100,7 @@ class FunctionDoc : public StmtDoc {
* \param body The body of function.
*/
explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc>
decorators,
- ExprDoc return_type, Array<StmtDoc> body);
+ Optional<ExprDoc> return_type, Array<StmtDoc> body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc,
FunctionDocNode);
};
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index 747ffc42f1..0a5fde8975 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -439,7 +439,7 @@ class FunctionDoc(StmtDoc):
name: IdDoc
args: Sequence[AssignDoc]
decorators: Sequence[ExprDoc]
- return_type: ExprDoc
+ return_type: Optional[ExprDoc]
body: Sequence[StmtDoc]
def __init__(
@@ -447,7 +447,7 @@ class FunctionDoc(StmtDoc):
name: IdDoc,
args: List[AssignDoc],
decorators: List[ExprDoc],
- return_type: ExprDoc,
+ return_type: Optional[ExprDoc],
body: List[StmtDoc],
):
self.__init_handle_by_constructor__(
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index bfff0cfad4..2334d1fad5 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -198,7 +198,7 @@ ReturnDoc::ReturnDoc(ExprDoc value) {
}
FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc>
decorators,
- ExprDoc return_type, Array<StmtDoc> body) {
+ Optional<ExprDoc> return_type, Array<StmtDoc> body) {
ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>();
n->name = name;
n->args = args;
@@ -345,7 +345,7 @@
TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value)
TVM_REGISTER_NODE_TYPE(FunctionDocNode);
TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
.set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc>
decorators,
- ExprDoc return_type, Array<StmtDoc> body) {
+ Optional<ExprDoc> return_type, Array<StmtDoc> body) {
return FunctionDoc(name, args, decorators, return_type, body);
});
diff --git a/src/script/printer/python_doc_printer.cc
b/src/script/printer/python_doc_printer.cc
index f44577ff80..536c57abd9 100644
--- a/src/script/printer/python_doc_printer.cc
+++ b/src/script/printer/python_doc_printer.cc
@@ -31,6 +31,111 @@ namespace tvm {
namespace script {
namespace printer {
+/*!
+ * \brief Operator precedence
+ *
+ * This is based on
+ * https://docs.python.org/3/reference/expressions.html#operator-precedence
+ */
+enum class ExprPrecedence : int32_t {
+ /*! \brief Unknown precedence */
+ kUnkown = 0,
+ /*! \brief Lambda Expression */
+ kLambda = 1,
+ /*! \brief Conditional Expression */
+ kIfThenElse = 2,
+ /*! \brief Boolean OR */
+ kBooleanOr = 3,
+ /*! \brief Boolean AND */
+ kBooleanAnd = 4,
+ /*! \brief Boolean NOT */
+ kBooleanNot = 5,
+ /*! \brief Comparisons */
+ kComparison = 6,
+ /*! \brief Bitwise OR */
+ kBitwiseOr = 7,
+ /*! \brief Bitwise XOR */
+ kBitwiseXor = 8,
+ /*! \brief Bitwise AND */
+ kBitwiseAnd = 9,
+ /*! \brief Shift Operators */
+ kShift = 10,
+ /*! \brief Addition and subtraction */
+ kAdd = 11,
+ /*! \brief Multiplication, division, floor division, remainder */
+ kMult = 12,
+ /*! \brief Positive negative and bitwise NOT */
+ kUnary = 13,
+ /*! \brief Exponentiation */
+ kExp = 14,
+ /*! \brief Index access, attribute access, call and atom expression */
+ kIdentity = 15,
+};
+
+ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
+ // Key is the value of OperationDocNode::Kind
+ static const std::vector<ExprPrecedence> op_kind_precedence = []() {
+ using OpKind = OperationDocNode::Kind;
+ std::map<OpKind, ExprPrecedence> raw_table = {
+ {OpKind::kUSub, ExprPrecedence::kUnary},
+ {OpKind::kInvert, ExprPrecedence::kUnary},
+ {OpKind::kAdd, ExprPrecedence::kAdd},
+ {OpKind::kSub, ExprPrecedence::kAdd},
+ {OpKind::kMult, ExprPrecedence::kMult},
+ {OpKind::kDiv, ExprPrecedence::kMult},
+ {OpKind::kFloorDiv, ExprPrecedence::kMult},
+ {OpKind::kMod, ExprPrecedence::kMult},
+ {OpKind::kPow, ExprPrecedence::kExp},
+ {OpKind::kLShift, ExprPrecedence::kShift},
+ {OpKind::kRShift, ExprPrecedence::kShift},
+ {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd},
+ {OpKind::kBitOr, ExprPrecedence::kBitwiseOr},
+ {OpKind::kBitXor, ExprPrecedence::kBitwiseXor},
+ {OpKind::kLt, ExprPrecedence::kComparison},
+ {OpKind::kLtE, ExprPrecedence::kComparison},
+ {OpKind::kEq, ExprPrecedence::kComparison},
+ {OpKind::kNotEq, ExprPrecedence::kComparison},
+ {OpKind::kGt, ExprPrecedence::kComparison},
+ {OpKind::kGtE, ExprPrecedence::kComparison},
+ {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse},
+ };
+ int n = static_cast<int>(OpKind::kSpecialEnd);
+ std::vector<ExprPrecedence> table(n + 1, ExprPrecedence::kUnkown);
+ for (const auto& kv : raw_table) {
+ table[static_cast<int>(kv.first)] = kv.second;
+ }
+ return table;
+ }();
+
+ // Key is the type index of Doc
+ static const std::unordered_map<uint32_t, ExprPrecedence>
doc_type_precedence = {
+ {LiteralDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {IdDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {AttrAccessDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {IndexDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {CallDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {LambdaDocNode::RuntimeTypeIndex(), ExprPrecedence::kLambda},
+ {TupleDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {ListDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ {DictDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
+ };
+
+ if (const auto* op_doc = doc.as<OperationDocNode>()) {
+ size_t kind = static_cast<int>(op_doc->kind);
+ ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid
operation: " << kind;
+ ExprPrecedence precedence = op_kind_precedence[kind];
+ ICHECK(precedence != ExprPrecedence::kUnkown)
+ << "Precedence for operator " << static_cast<int>(op_doc->kind) << "
is unknown";
+ return precedence;
+ }
+ auto it = doc_type_precedence.find(doc->type_index());
+ if (it != doc_type_precedence.end()) {
+ return it->second;
+ }
+ ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is
unknown";
+ throw;
+}
+
class PythonDocPrinter : public DocPrinter {
public:
explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces)
{}
@@ -98,6 +203,42 @@ class PythonDocPrinter : public DocPrinter {
}
}
+ /*!
+ * \brief Print expression and add parenthesis if needed.
+ */
+ void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
+ bool parenthesis_for_same_precedence = false) {
+ ExprPrecedence doc_precedence = GetExprPrecedence(doc);
+ if (doc_precedence < parent_precedence ||
+ (parenthesis_for_same_precedence && doc_precedence ==
parent_precedence)) {
+ output_ << "(";
+ PrintDoc(doc);
+ output_ << ")";
+ } else {
+ PrintDoc(doc);
+ }
+ }
+
+ /*!
+ * \brief Print expression and add parenthesis if doc has lower precedence
than parent.
+ */
+ void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
+ bool parenthesis_for_same_precedence = false) {
+ ExprPrecedence parent_precedence = GetExprPrecedence(parent);
+ return PrintChildExpr(doc, parent_precedence,
parenthesis_for_same_precedence);
+ }
+
+ /*!
+ * \brief Print expression and add parenthesis if doc doesn't have higher
precedence than parent.
+ *
+ * This function should be used to print an child expression that needs to
be wrapped
+ * by parenthesis even if it has the same precedence as its parent, e.g.,
the `b` in `a + b`
+ * and the `b` and `c` in `a if b else c`.
+ */
+ void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent)
{
+ PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true);
+ }
+
void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
@@ -161,12 +302,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc&
doc) {
void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name;
}
void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
- PrintDoc(doc->value);
+ PrintChildExpr(doc->value, doc);
output_ << "." << doc->name;
}
void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
- PrintDoc(doc->value);
+ PrintChildExpr(doc->value, doc);
if (doc->indices.size() == 0) {
output_ << "[()]";
} else {
@@ -226,21 +367,30 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc&
doc) {
// Unary Operators
ICHECK_EQ(doc->operands.size(), 1);
output_ << OperatorToString(doc->kind);
- PrintDoc(doc->operands[0]);
+ PrintChildExpr(doc->operands[0], doc);
+ } else if (doc->kind == OpKind::kPow) {
+ // Power operator is different than other binary operators
+ // It's right-associative and binds less tightly than unary operator on
its right.
+ // https://docs.python.org/3/reference/expressions.html#the-power-operator
+ // https://docs.python.org/3/reference/expressions.html#operator-precedence
+ ICHECK_EQ(doc->operands.size(), 2);
+ PrintChildExprConservatively(doc->operands[0], doc);
+ output_ << " ** ";
+ PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
} else if (doc->kind < OpKind::kBinaryEnd) {
// Binary Operator
ICHECK_EQ(doc->operands.size(), 2);
- PrintDoc(doc->operands[0]);
+ PrintChildExpr(doc->operands[0], doc);
output_ << " " << OperatorToString(doc->kind) << " ";
- PrintDoc(doc->operands[1]);
+ PrintChildExprConservatively(doc->operands[1], doc);
} else if (doc->kind == OpKind::kIfThenElse) {
ICHECK_EQ(doc->operands.size(), 3)
<< "ValueError: IfThenElse requires 3 operands, but got " <<
doc->operands.size();
- PrintDoc(doc->operands[1]);
+ PrintChildExpr(doc->operands[1], doc);
output_ << " if ";
- PrintDoc(doc->operands[0]);
+ PrintChildExprConservatively(doc->operands[0], doc);
output_ << " else ";
- PrintDoc(doc->operands[2]);
+ PrintChildExprConservatively(doc->operands[2], doc);
} else {
LOG(FATAL) << "Unknown OperationDocNode::Kind " <<
static_cast<int>(doc->kind);
throw;
@@ -248,7 +398,7 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc&
doc) {
}
void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
- PrintDoc(doc->callee);
+ PrintChildExpr(doc->callee, doc);
output_ << "(";
@@ -285,7 +435,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
output_ << "lambda ";
PrintJoinedDocs(doc->args, ", ");
output_ << ": ";
- PrintDoc(doc->body);
+ PrintChildExpr(doc->body, doc);
}
void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
@@ -444,8 +594,10 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc&
doc) {
PrintJoinedDocs(doc->args, ", ");
output_ << ")";
- output_ << " -> ";
- PrintDoc(doc->return_type);
+ if (doc->return_type.defined()) {
+ output_ << " -> ";
+ PrintDoc(doc->return_type.value());
+ }
output_ << ":";
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py
b/tests/python/unittest/test_tvmscript_printer_doc.py
index 040a829010..f27bc71b66 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -21,30 +21,31 @@ Doc objects, then access and modify their attributes
correctly.
import pytest
+import tvm
from tvm.script.printer.doc import (
- LiteralDoc,
- IdDoc,
+ AssertDoc,
+ AssignDoc,
AttrAccessDoc,
- IndexDoc,
CallDoc,
- OperationKind,
- OperationDoc,
+ ClassDoc,
+ DictDoc,
+ ExprStmtDoc,
+ ForDoc,
+ FunctionDoc,
+ IdDoc,
+ IfDoc,
+ IndexDoc,
LambdaDoc,
- TupleDoc,
ListDoc,
- DictDoc,
+ LiteralDoc,
+ OperationDoc,
+ OperationKind,
+ ReturnDoc,
+ ScopeDoc,
SliceDoc,
StmtBlockDoc,
- AssignDoc,
- IfDoc,
+ TupleDoc,
WhileDoc,
- ForDoc,
- ScopeDoc,
- ExprStmtDoc,
- AssertDoc,
- ReturnDoc,
- FunctionDoc,
- ClassDoc,
)
@@ -450,6 +451,13 @@ def test_return_doc():
[IdDoc("test"), IdDoc("test2")],
],
)
[email protected](
+ "return_type",
+ [
+ None,
+ LiteralDoc(None),
+ ],
+)
@pytest.mark.parametrize(
"body",
[
@@ -458,9 +466,8 @@ def test_return_doc():
[ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
],
)
-def test_function_doc(args, decorators, body):
+def test_function_doc(args, decorators, return_type, body):
name = IdDoc("name")
- return_type = LiteralDoc(None)
doc = FunctionDoc(name, args, decorators, return_type, body)
@@ -504,3 +511,7 @@ def test_stmt_doc_comment():
comment = "test comment"
doc.comment = comment
assert doc.comment == comment
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index 523f62d8b5..e0905cc145 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pytest
import itertools
+import pytest
+
+import tvm
from tvm.script.printer.doc import (
AssertDoc,
AssignDoc,
@@ -701,29 +703,32 @@ def test_print_return_doc(value, expected):
@pytest.mark.parametrize(
- "args, decorators, body, expected",
+ "args, decorators, return_type, body, expected",
[
(
[],
[],
+ None,
[],
"""
- def func() -> None:
+ def func():
pass
""",
),
(
[AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int"))],
[],
+ IdDoc("int"),
[],
"""
- def func(x: int) -> None:
+ def func(x: int) -> int:
pass
""",
),
(
[AssignDoc(IdDoc("x"), rhs=LiteralDoc(1),
annotation=IdDoc("int"))],
[],
+ LiteralDoc(None),
[],
"""
def func(x: int = 1) -> None:
@@ -733,6 +738,7 @@ def test_print_return_doc(value, expected):
(
[],
[IdDoc("wrap")],
+ LiteralDoc(None),
[],
"""
@wrap
@@ -743,6 +749,7 @@ def test_print_return_doc(value, expected):
(
[],
[IdDoc("wrap_outter"), IdDoc("wrap_inner")],
+ LiteralDoc(None),
[],
"""
@wrap_outter
@@ -757,6 +764,7 @@ def test_print_return_doc(value, expected):
AssignDoc(IdDoc("y"), rhs=LiteralDoc(1),
annotation=IdDoc("int")),
],
[IdDoc("wrap")],
+ LiteralDoc(None),
[],
"""
@wrap
@@ -770,6 +778,7 @@ def test_print_return_doc(value, expected):
AssignDoc(IdDoc("y"), rhs=LiteralDoc(1),
annotation=IdDoc("int")),
],
[IdDoc("wrap")],
+ LiteralDoc(None),
[
AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add,
[IdDoc("x"), LiteralDoc(1)])),
AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub,
[IdDoc("y"), LiteralDoc(1)])),
@@ -784,8 +793,8 @@ def test_print_return_doc(value, expected):
],
ids=itertools.count(),
)
-def test_print_function_doc(args, decorators, body, expected):
- doc = FunctionDoc(IdDoc("func"), args, decorators, LiteralDoc(None), body)
+def test_print_function_doc(args, decorators, body, return_type, expected):
+ doc = FunctionDoc(IdDoc("func"), args, decorators, return_type, body)
assert to_python_script(doc) == format_script(expected) # test
@@ -1038,3 +1047,297 @@ def test_print_invalid_multiline_doc_comment(doc):
with pytest.raises(ValueError) as e:
to_python_script(doc)
assert "cannot have newline" in str(e.value)
+
+
+def generate_expr_precedence_test_cases():
+ x = IdDoc("x")
+ y = IdDoc("y")
+ z = IdDoc("z")
+
+ def negative(a):
+ return OperationDoc(OperationKind.USub, [a])
+
+ def invert(a):
+ return OperationDoc(OperationKind.Invert, [a])
+
+ def add(a, b):
+ return OperationDoc(OperationKind.Add, [a, b])
+
+ def sub(a, b):
+ return OperationDoc(OperationKind.Sub, [a, b])
+
+ def mult(a, b):
+ return OperationDoc(OperationKind.Mult, [a, b])
+
+ def div(a, b):
+ return OperationDoc(OperationKind.Div, [a, b])
+
+ def mod(a, b):
+ return OperationDoc(OperationKind.Mod, [a, b])
+
+ def pow(a, b):
+ return OperationDoc(OperationKind.Pow, [a, b])
+
+ def lshift(a, b):
+ return OperationDoc(OperationKind.LShift, [a, b])
+
+ def bit_and(a, b):
+ return OperationDoc(OperationKind.BitAnd, [a, b])
+
+ def bit_or(a, b):
+ return OperationDoc(OperationKind.BitOr, [a, b])
+
+ def bit_xor(a, b):
+ return OperationDoc(OperationKind.BitXor, [a, b])
+
+ def lt(a, b):
+ return OperationDoc(OperationKind.Lt, [a, b])
+
+ def eq(a, b):
+ return OperationDoc(OperationKind.Eq, [a, b])
+
+ def not_eq(a, b):
+ return OperationDoc(OperationKind.NotEq, [a, b])
+
+ def if_then_else(a, b, c):
+ return OperationDoc(OperationKind.IfThenElse, [a, b, c])
+
+ test_cases = {
+ "attr-call-index": [
+ (
+ add(x, y).attr("test"),
+ "(x + y).test",
+ ),
+ (
+ add(x, y.attr("test")),
+ "x + y.test",
+ ),
+ (
+ x[z].call(y),
+ "x[z](y)",
+ ),
+ (
+ x.call(y)[z],
+ "x(y)[z]",
+ ),
+ (
+ x.call(y).call(z),
+ "x(y)(z)",
+ ),
+ (
+ x.call(y).attr("test"),
+ "x(y).test",
+ ),
+ (
+ x.attr("test").call(y),
+ "x.test(y)",
+ ),
+ (
+ x.attr("test").attr("test2"),
+ "x.test.test2",
+ ),
+ (
+ LambdaDoc([x], x).call(y),
+ "(lambda x: x)(y)",
+ ),
+ (
+ add(x, y)[z][add(z, z)].attr("name"),
+ "(x + y)[z][z + z].name",
+ ),
+ ],
+ "power": [
+ (
+ pow(pow(x, y), z),
+ "(x ** y) ** z",
+ ),
+ (
+ pow(x, pow(y, z)),
+ "x ** y ** z",
+ ),
+ (
+ pow(negative(x), negative(y)),
+ "(-x) ** -y",
+ ),
+ (
+ pow(add(x, y), add(y, z)),
+ "(x + y) ** (y + z)",
+ ),
+ ],
+ "unary": [
+ (
+ invert(negative(y)),
+ "~-y",
+ ),
+ (
+ negative(y).attr("test"),
+ "(-y).test",
+ ),
+ (
+ negative(y.attr("test")),
+ "-y.test",
+ ),
+ (
+ mult(negative(x), negative(y)),
+ "-x * -y",
+ ),
+ (
+ negative(add(invert(x), negative(y))),
+ "-(~x + -y)",
+ ),
+ ],
+ "add-mult": [
+ (
+ mult(x, mult(y, z)),
+ "x * (y * z)",
+ ),
+ (
+ mult(mult(x, y), z),
+ "x * y * z",
+ ),
+ (
+ mult(x, add(y, z)),
+ "x * (y + z)",
+ ),
+ (
+ mult(add(y, z), x),
+ "(y + z) * x",
+ ),
+ (
+ add(x, mod(y, z)),
+ "x + y % z",
+ ),
+ (
+ add(mult(y, z), x),
+ "y * z + x",
+ ),
+ (
+ add(add(x, y), add(y, z)),
+ "x + y + (y + z)",
+ ),
+ (
+ div(add(x, y), add(y, z)),
+ "(x + y) / (y + z)",
+ ),
+ ],
+ "shift": [
+ (
+ div(x, lshift(y, z)),
+ "x / (y << z)",
+ ),
+ (
+ mult(lshift(y, z), x),
+ "(y << z) * x",
+ ),
+ (
+ lshift(x, mult(y, z)),
+ "x << y * z",
+ ),
+ (
+ lshift(mult(x, y), z),
+ "x * y << z",
+ ),
+ (
+ lshift(mult(x, y), z),
+ "x * y << z",
+ ),
+ (
+ lshift(lshift(x, y), z),
+ "x << y << z",
+ ),
+ (
+ lshift(x, lshift(y, z)),
+ "x << (y << z)",
+ ),
+ ],
+ "bitwise": [
+ (
+ add(bit_or(x, y), bit_or(y, z)),
+ "(x | y) + (y | z)",
+ ),
+ (
+ bit_and(bit_or(x, y), bit_or(y, z)),
+ "(x | y) & (y | z)",
+ ),
+ (
+ bit_or(bit_and(x, y), bit_and(y, z)),
+ "x & y | y & z",
+ ),
+ (
+ bit_and(bit_xor(x, bit_or(y, z)), z),
+ "(x ^ (y | z)) & z",
+ ),
+ ],
+ "comparison": [
+ (
+ not_eq(add(x, y), z),
+ "x + y != z",
+ ),
+ (
+ eq(pow(x, y), z),
+ "x ** y == z",
+ ),
+ (
+ lt(x, div(y, z)),
+ "x < y / z",
+ ),
+ (
+ lt(x, if_then_else(y, y, y)),
+ "x < (y if y else y)",
+ ),
+ ],
+ "if-then-else": [
+ (
+ if_then_else(x, if_then_else(y, y, y), z),
+ "y if y else y if x else z",
+ ),
+ (
+ if_then_else(if_then_else(x, x, x), y, z),
+ "y if (x if x else x) else z",
+ ),
+ (
+ if_then_else(x, y, if_then_else(z, z, z)),
+ "y if x else (z if z else z)",
+ ),
+ (
+ if_then_else(lt(x, x), add(y, y), mult(z, z)),
+ "y + y if x < x else z * z",
+ ),
+ (
+ if_then_else(LambdaDoc([x], x), LambdaDoc([y], y),
LambdaDoc([z], z)),
+ "(lambda y: y) if (lambda x: x) else (lambda z: z)",
+ ),
+ ],
+ "lambda": [
+ (
+ LambdaDoc([x, y], add(z, z)),
+ "lambda x, y: z + z",
+ ),
+ (
+ add(LambdaDoc([x, y], z), z),
+ "(lambda x, y: z) + z",
+ ),
+ (
+ LambdaDoc([x, y], add(z, z)).call(x, y),
+ "(lambda x, y: z + z)(x, y)",
+ ),
+ (
+ LambdaDoc([x], LambdaDoc([y], z)),
+ "lambda x: lambda y: z",
+ ),
+ ],
+ }
+
+ return [
+ pytest.param(*args, id=f"{group_name}-{i}")
+ for group_name, cases in test_cases.items()
+ for i, args in enumerate(cases)
+ ]
+
+
[email protected]("doc, expected",
generate_expr_precedence_test_cases())
+def test_expr_precedence(doc, expected):
+ assert to_python_script(doc) == format_script(expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()