This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e25f2bb76 Python: Add more expression classes (#5258)
5e25f2bb76 is described below

commit 5e25f2bb7638cd05b0e1e15fc3824012c7df9dc6
Author: Nick Ouellet <[email protected]>
AuthorDate: Mon Jul 25 15:37:26 2022 -0400

    Python: Add more expression classes (#5258)
---
 python/pyiceberg/expressions/base.py              | 316 ++++++++++++++++---
 python/tests/expressions/test_expressions_base.py | 361 +++++++++++++++++-----
 2 files changed, 569 insertions(+), 108 deletions(-)

diff --git a/python/pyiceberg/expressions/base.py 
b/python/pyiceberg/expressions/base.py
index 876965ae71..4b4a487a4d 100644
--- a/python/pyiceberg/expressions/base.py
+++ b/python/pyiceberg/expressions/base.py
@@ -14,10 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import reduce, singledispatch
-from typing import Generic, Tuple, TypeVar
+from typing import Generic, TypeVar
 
 from pyiceberg.files import StructProtocol
 from pyiceberg.schema import Accessor, Schema
@@ -41,7 +43,7 @@ class Literal(Generic[T], ABC):
         return self._value  # type: ignore
 
     @abstractmethod
-    def to(self, type_var) -> "Literal":
+    def to(self, type_var) -> Literal:
         ...  # pragma: no cover
 
     def __repr__(self):
@@ -73,7 +75,7 @@ class BooleanExpression(ABC):
     """Represents a boolean expression tree."""
 
     @abstractmethod
-    def __invert__(self) -> "BooleanExpression":
+    def __invert__(self) -> BooleanExpression:
         """Transform the Expression into its negated version."""
 
 
@@ -156,7 +158,7 @@ class Reference(UnboundTerm[T], BaseReference[T]):
         Returns:
             BoundReference: A reference bound to the specific field in the 
Iceberg schema
         """
-        field = schema.find_field(name_or_id=self.name, 
case_sensitive=case_sensitive)
+        field = schema.find_field(name_or_id=self.name, 
case_sensitive=case_sensitive)  # pylint: disable=redefined-outer-name
 
         if not field:
             raise ValueError(f"Cannot find field '{self.name}' in schema: 
{schema}")
@@ -169,16 +171,92 @@ class Reference(UnboundTerm[T], BaseReference[T]):
         return BoundReference(field=field, accessor=accessor)


-@dataclass(frozen=True)  # type: ignore[misc]
 class BoundPredicate(Bound[T], BooleanExpression):
-    term: BoundReference[T]
-    literals: Tuple[Literal[T], ...]
+    """A predicate over a bound term and a tuple of literals.
+
+    The default validation requires exactly one literal; subclasses override
+    _validate_literals for unary (no literal) or multi-literal (one or more) predicates.
+    """
+
+    _term: BoundTerm[T]
+    _literals: tuple[Literal[T], ...]
+
+    def __init__(self, term: BoundTerm[T], *literals: Literal[T]):
+        self._term = term
+        self._literals = literals
+        self._validate_literals()
+
+    def _validate_literals(self):
+        if len(self.literals) != 1:
+            raise AttributeError(f"{self.__class__.__name__} must have exactly 1 literal.")
+
+    @property
+    def term(self) -> BoundTerm[T]:
+        return self._term
+
+    @property
+    def literals(self) -> tuple[Literal[T], ...]:
+        return self._literals
+
+    def __str__(self) -> str:
+        # Append ", <literals>" only when there are literals, so unary predicates render as Name(term).
+        return f"{self.__class__.__name__}({str(self.term)}{', ' + str(self.literals) if self.literals else ''})"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({repr(self.term)}{', ' + repr(self.literals) if self.literals else ''})"
+
+    def __eq__(self, other) -> bool:
+        return self is other or (
+            type(self) is type(other) and self.term == other.term and self.literals == other.literals
+        )
+
+    def __hash__(self) -> int:
+        # Defining __eq__ implicitly sets __hash__ to None; keep predicates hashable, consistent with __eq__.
+        return hash((type(self), self.term, self.literals))


-@dataclass(frozen=True)  # type: ignore[misc]
-class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression):
-    term: Reference[T]
-    literals: Tuple[Literal[T], ...]
+class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression, ABC):
+    """A predicate over an unbound term and a tuple of literals; subclasses implement bind().
+
+    The default validation requires exactly one literal; subclasses override
+    _validate_literals for unary (no literal) or multi-literal (one or more) predicates.
+    """
+
+    _term: UnboundTerm[T]
+    _literals: tuple[Literal[T], ...]
+
+    def __init__(self, term: UnboundTerm[T], *literals: Literal[T]):
+        self._term = term
+        self._literals = literals
+        self._validate_literals()
+
+    def _validate_literals(self):
+        if len(self.literals) != 1:
+            raise AttributeError(f"{self.__class__.__name__} must have exactly 1 literal.")
+
+    @property
+    def term(self) -> UnboundTerm[T]:
+        return self._term
+
+    @property
+    def literals(self) -> tuple[Literal[T], ...]:
+        return self._literals
+
+    def __str__(self) -> str:
+        # Append ", <literals>" only when there are literals, so unary predicates render as Name(term).
+        return f"{self.__class__.__name__}({str(self.term)}{', ' + str(self.literals) if self.literals else ''})"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({repr(self.term)}{', ' + repr(self.literals) if self.literals else ''})"
+
+    def __eq__(self, other) -> bool:
+        return self is other or (
+            type(self) is type(other) and self.term == other.term and self.literals == other.literals
+        )
+
+    def __hash__(self) -> int:
+        # Defining __eq__ implicitly sets __hash__ to None; keep predicates hashable, consistent with __eq__.
+        return hash((type(self), self.term, self.literals))


 class And(BooleanExpression):
@@ -209,14 +265,14 @@ class And(BooleanExpression):
     def __eq__(self, other) -> bool:
         return id(self) == id(other) or (isinstance(other, And) and self.left 
== other.left and self.right == other.right)
 
-    def __invert__(self) -> "Or":
+    def __invert__(self) -> Or:
         return Or(~self.left, ~self.right)
 
     def __repr__(self) -> str:
         return f"And({repr(self.left)}, {repr(self.right)})"
 
     def __str__(self) -> str:
-        return f"({self.left} and {self.right})"
+        return f"And({str(self.left)}, {str(self.right)})"
 
 
 class Or(BooleanExpression):
@@ -247,14 +303,14 @@ class Or(BooleanExpression):
     def __eq__(self, other) -> bool:
         return id(self) == id(other) or (isinstance(other, Or) and self.left 
== other.left and self.right == other.right)
 
-    def __invert__(self) -> "And":
+    def __invert__(self) -> And:
         return And(~self.left, ~self.right)
 
     def __repr__(self) -> str:
         return f"Or({repr(self.left)}, {repr(self.right)})"
 
     def __str__(self) -> str:
-        return f"({self.left} or {self.right})"
+        return f"Or({str(self.left)}, {str(self.right)})"
 
 
 class Not(BooleanExpression):
@@ -282,49 +338,239 @@ class Not(BooleanExpression):
         return f"Not({repr(self.child)})"
 
     def __str__(self) -> str:
-        return f"(not {self.child})"
+        return f"Not({str(self.child)})"
 
 
+@dataclass(frozen=True)  # generated __repr__/__eq__/__hash__ replace the hand-written __repr__/__str__ removed below
 class AlwaysTrue(BooleanExpression, ABC, Singleton):
     """TRUE expression"""

-    def __invert__(self) -> "AlwaysFalse":
+    def __invert__(self) -> AlwaysFalse:  # negating TRUE yields FALSE
         return AlwaysFalse()

-    def __repr__(self) -> str:
-        return "AlwaysTrue()"
-
-    def __str__(self) -> str:
-        return "true"
-
 
+@dataclass(frozen=True)  # generated __repr__/__eq__/__hash__ replace the hand-written __repr__/__str__ removed below
 class AlwaysFalse(BooleanExpression, ABC, Singleton):
     """FALSE expression"""

-    def __invert__(self) -> "AlwaysTrue":
+    def __invert__(self) -> AlwaysTrue:  # negating FALSE yields TRUE
         return AlwaysTrue()

-    def __repr__(self) -> str:
-        return "AlwaysFalse()"

-    def __str__(self) -> str:
-        return "false"
+class IsNull(UnboundPredicate[T]):  # unary predicate (no literals); binds to BoundIsNull, negates to NotNull
+    def __invert__(self) -> NotNull[T]:
+        return NotNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # literals is always a tuple (possibly empty), never None
+            raise AttributeError("IsNull is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundIsNull[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundIsNull(bound_ref)
+
+
+class BoundIsNull(BoundPredicate[T]):  # bound unary predicate produced by IsNull.bind; negates to BoundNotNull
+    def __invert__(self) -> BoundNotNull[T]:
+        return BoundNotNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("IsNull is a unary predicate and takes no Literals.")
+
+
+class NotNull(UnboundPredicate[T]):  # unary predicate (no literals); binds to BoundNotNull, negates to IsNull
+    def __invert__(self) -> IsNull[T]:
+        return IsNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("NotNull is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotNull[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundNotNull(bound_ref)
+
+
+class BoundNotNull(BoundPredicate[T]):  # bound unary predicate produced by NotNull.bind; negates to BoundIsNull
+    def __invert__(self) -> BoundIsNull[T]:
+        return BoundIsNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("NotNull is a unary predicate and takes no Literals.")
+
+
+class IsNaN(UnboundPredicate[T]):  # unary predicate (no literals); binds to BoundIsNaN, negates to NotNaN
+    def __invert__(self) -> NotNaN[T]:
+        return NotNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("IsNaN is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundIsNaN[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundIsNaN(bound_ref)
+
+
+class BoundIsNaN(BoundPredicate[T]):  # bound unary predicate produced by IsNaN.bind; negates to BoundNotNaN
+    def __invert__(self) -> BoundNotNaN[T]:
+        return BoundNotNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("IsNaN is a unary predicate and takes no Literals.")
+
+
+class NotNaN(UnboundPredicate[T]):  # unary predicate (no literals); binds to BoundNotNaN, negates to IsNaN
+    def __invert__(self) -> IsNaN[T]:
+        return IsNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("NotNaN is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotNaN[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundNotNaN(bound_ref)
+
+
+class BoundNotNaN(BoundPredicate[T]):  # bound unary predicate produced by NotNaN.bind; negates to BoundIsNaN
+    def __invert__(self) -> BoundIsNaN[T]:
+        return BoundIsNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:  # unary: any literal at all is rejected
+            raise AttributeError("NotNaN is a unary predicate and takes no Literals.")
 
 
-@dataclass(frozen=True)
 class BoundIn(BoundPredicate[T]):
-    def __invert__(self):
-        raise TypeError("In expressions do not support negation.")
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:  # overrides the exactly-one default: one or more literals are accepted
+            raise AttributeError("BoundIn must contain at least 1 literal.")
+
+    def __invert__(self) -> BoundNotIn[T]:  # negation keeps term and literals, flipping to BoundNotIn
+        return BoundNotIn(self.term, *self.literals)
 
 
-@dataclass(frozen=True)
 class In(UnboundPredicate[T]):
-    def __invert__(self):
-        raise TypeError("In expressions do not support negation.")
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:  # overrides the exactly-one default: one or more literals are accepted
+            raise AttributeError("In must contain at least 1 literal.")
+
+    def __invert__(self) -> NotIn[T]:  # negation keeps term and literals, flipping to NotIn
+        return NotIn(self.term, *self.literals)

     def bind(self, schema: Schema, case_sensitive: bool) -> BoundIn[T]:
         bound_ref = self.term.bind(schema, case_sensitive)
-        return BoundIn(bound_ref, tuple(lit.to(bound_ref.field.field_type) for 
lit in self.literals))  # type: ignore
+        return BoundIn(bound_ref, *(lit.to(bound_ref.field.field_type) for lit in self.literals))  # type: ignore
+
+
+class BoundNotIn(BoundPredicate[T]):  # bound form produced by NotIn.bind
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:  # overrides the exactly-one default: one or more literals are accepted
+            raise AttributeError("BoundNotIn must contain at least 1 literal.")
+
+    def __invert__(self) -> BoundIn[T]:  # negation keeps term and literals, flipping to BoundIn
+        return BoundIn(self.term, *self.literals)
+
+
+class NotIn(UnboundPredicate[T]):  # binds to BoundNotIn, negates to In
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:  # overrides the exactly-one default: one or more literals are accepted
+            raise AttributeError("NotIn must contain at least 1 literal.")
+
+    def __invert__(self) -> In[T]:  # negation keeps term and literals, flipping to In
+        return In(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotIn[T]:  # converts every literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundNotIn(bound_ref, *(lit.to(bound_ref.field.field_type) for lit in self.literals))  # type: ignore
+
+
+class BoundEq(BoundPredicate[T]):  # single-literal bound predicate produced by Eq.bind; negates to BoundNotEq
+    def __invert__(self) -> BoundNotEq[T]:
+        return BoundNotEq(self.term, *self.literals)
+
+
+class Eq(UnboundPredicate[T]):  # single-literal predicate; binds to BoundEq, negates to NotEq
+    def __invert__(self) -> NotEq[T]:
+        return NotEq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundEq[T]:  # converts the single literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundEq(bound_ref, 
self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundNotEq(BoundPredicate[T]):  # single-literal bound predicate produced by NotEq.bind; negates to BoundEq
+    def __invert__(self) -> BoundEq[T]:
+        return BoundEq(self.term, *self.literals)
+
+
+class NotEq(UnboundPredicate[T]):  # single-literal predicate; binds to BoundNotEq, negates to Eq
+    def __invert__(self) -> Eq[T]:
+        return Eq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotEq[T]:  # converts the single literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundNotEq(bound_ref, 
self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundLt(BoundPredicate[T]):  # single-literal bound predicate produced by Lt.bind; negates to BoundGtEq
+    def __invert__(self) -> BoundGtEq[T]:
+        return BoundGtEq(self.term, *self.literals)
+
+
+class Lt(UnboundPredicate[T]):  # single-literal predicate; binds to BoundLt, negates to GtEq
+    def __invert__(self) -> GtEq[T]:
+        return GtEq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundLt[T]:  # converts the single literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundLt(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundGtEq(BoundPredicate[T]):  # single-literal bound predicate produced by GtEq.bind; negates to BoundLt
+    def __invert__(self) -> BoundLt[T]:
+        return BoundLt(self.term, *self.literals)
+
+
+class GtEq(UnboundPredicate[T]):  # single-literal predicate; binds to BoundGtEq, negates to Lt
+    def __invert__(self) -> Lt[T]:
+        return Lt(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundGtEq[T]:  # converts the single literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundGtEq(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundGt(BoundPredicate[T]):  # single-literal bound predicate produced by Gt.bind; negates to BoundLtEq
+    def __invert__(self) -> BoundLtEq[T]:
+        return BoundLtEq(self.term, *self.literals)
+
+
+class Gt(UnboundPredicate[T]):  # single-literal predicate; binds to BoundGt, negates to LtEq
+    def __invert__(self) -> LtEq[T]:
+        return LtEq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundGt[T]:  # converts the single literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundGt(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundLtEq(BoundPredicate[T]):  # single-literal bound predicate produced by LtEq.bind; negates to BoundGt
+    def __invert__(self) -> BoundGt[T]:
+        return BoundGt(self.term, *self.literals)
+
+
+class LtEq(UnboundPredicate[T]):  # single-literal predicate; binds to BoundLtEq, negates to Gt
+    def __invert__(self) -> Gt[T]:
+        return Gt(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundLtEq[T]:  # converts the single literal to the field's type
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundLtEq(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
 
 
 class BooleanExpressionVisitor(Generic[T], ABC):
diff --git a/python/tests/expressions/test_expressions_base.py 
b/python/tests/expressions/test_expressions_base.py
index 51630d71e8..2fca17d6cc 100644
--- a/python/tests/expressions/test_expressions_base.py
+++ b/python/tests/expressions/test_expressions_base.py
@@ -130,9 +130,9 @@ def test_reprs(op, rep):
 @pytest.mark.parametrize(
     "op, string",
     [
-        (base.And(ExpressionA(), ExpressionB()), "(testexpra and testexprb)"),
-        (base.Or(ExpressionA(), ExpressionB()), "(testexpra or testexprb)"),
-        (base.Not(ExpressionA()), "(not testexpra)"),
+        (base.And(ExpressionA(), ExpressionB()), "And(testexpra, testexprb)"),
+        (base.Or(ExpressionA(), ExpressionB()), "Or(testexpra, testexprb)"),
+        (base.Not(ExpressionA()), "Not(testexpra)"),
     ],
 )
 def test_strs(op, string):
@@ -143,25 +143,193 @@ def test_strs(op, string):
     "a,  schema, case_sensitive, success",
     [
         (
-            base.In(base.Reference("foo"), (literal("hello"), 
literal("world"))),
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
             "table_schema_simple",
             True,
             True,
         ),
         (
-            base.In(base.Reference("not_foo"), (literal("hello"), 
literal("world"))),
+            base.In(base.Reference("not_foo"), literal("hello"), 
literal("world")),
             "table_schema_simple",
             False,
             False,
         ),
         (
-            base.In(base.Reference("Bar"), (literal("hello"), 
literal("world"))),
+            base.In(base.Reference("Bar"), literal(5), literal(2)),
             "table_schema_simple",
             False,
             True,
         ),
         (
-            base.In(base.Reference("Bar"), (literal("hello"), 
literal("world"))),
+            base.In(base.Reference("Bar"), literal(5), literal(2)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.NotIn(base.Reference("foo"), literal("hello"), 
literal("world")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.NotIn(base.Reference("not_foo"), literal("hello"), 
literal("world")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.NotIn(base.Reference("Bar"), literal(5), literal(2)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.NotIn(base.Reference("Bar"), literal(5), literal(2)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.NotEq(base.Reference("foo"), literal("hello")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.NotEq(base.Reference("not_foo"), literal("hello")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.NotEq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.NotEq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.Eq(base.Reference("foo"), literal("hello")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.Eq(base.Reference("not_foo"), literal("hello")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.Eq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.Eq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.Gt(base.Reference("foo"), literal("hello")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.Gt(base.Reference("not_foo"), literal("hello")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.Gt(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.Gt(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.Lt(base.Reference("foo"), literal("hello")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.Lt(base.Reference("not_foo"), literal("hello")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.Lt(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.Lt(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.GtEq(base.Reference("foo"), literal("hello")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.GtEq(base.Reference("not_foo"), literal("hello")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.GtEq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.GtEq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            True,
+            False,
+        ),
+        (
+            base.LtEq(base.Reference("foo"), literal("hello")),
+            "table_schema_simple",
+            True,
+            True,
+        ),
+        (
+            base.LtEq(base.Reference("not_foo"), literal("hello")),
+            "table_schema_simple",
+            False,
+            False,
+        ),
+        (
+            base.LtEq(base.Reference("Bar"), literal(5)),
+            "table_schema_simple",
+            False,
+            True,
+        ),
+        (
+            base.LtEq(base.Reference("Bar"), literal(5)),
             "table_schema_simple",
             True,
             False,
@@ -194,14 +362,14 @@ def test_bind(a, schema, case_sensitive, success, 
request):
         (ExpressionA(), ExpressionA(), ExpressionB()),
         (ExpressionB(), ExpressionB(), ExpressionA()),
         (
-            base.In(base.Reference("foo"), (literal("hello"), 
literal("world"))),
-            base.In(base.Reference("foo"), (literal("hello"), 
literal("world"))),
-            base.In(base.Reference("not_foo"), (literal("hello"), 
literal("world"))),
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
+            base.In(base.Reference("not_foo"), literal("hello"), 
literal("world")),
         ),
         (
-            base.In(base.Reference("foo"), (literal("hello"), 
literal("world"))),
-            base.In(base.Reference("foo"), (literal("hello"), 
literal("world"))),
-            base.In(base.Reference("foo"), (literal("goodbye"), 
literal("world"))),
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
+            base.In(base.Reference("foo"), literal("goodbye"), 
literal("world")),
         ),
     ],
 )
@@ -210,21 +378,45 @@ def test_eq(exp, testexpra, testexprb):


 @pytest.mark.parametrize(
-    "lhs, rhs, raises",
+    "lhs, rhs",
     [
-        (base.And(ExpressionA(), ExpressionB()), base.Or(ExpressionB(), 
ExpressionA()), False),
-        (base.Or(ExpressionA(), ExpressionB()), base.And(ExpressionB(), 
ExpressionA()), False),
-        (base.Not(ExpressionA()), ExpressionA(), False),
-        (base.In(base.Reference("foo"), (literal("hello"), literal("world"))), 
None, True),
-        (ExpressionA(), ExpressionB(), False),
+        (
+            base.And(ExpressionA(), ExpressionB()),
+            base.Or(ExpressionB(), ExpressionA()),
+        ),
+        (
+            base.Or(ExpressionA(), ExpressionB()),
+            base.And(ExpressionB(), ExpressionA()),
+        ),
+        (
+            base.Not(ExpressionA()),
+            ExpressionA(),
+        ),
+        (
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
+            base.NotIn(base.Reference("foo"), literal("hello"), 
literal("world")),
+        ),
+        (
+            base.NotIn(base.Reference("foo"), literal("hello"), 
literal("world")),
+            base.In(base.Reference("foo"), literal("hello"), literal("world")),
+        ),
+        (base.Gt(base.Reference("foo"), literal(5)), 
base.LtEq(base.Reference("foo"), literal(5))),
+        (base.Lt(base.Reference("foo"), literal(5)), 
base.GtEq(base.Reference("foo"), literal(5))),
+        (base.Eq(base.Reference("foo"), literal(5)), 
base.NotEq(base.Reference("foo"), literal(5))),
+        (base.GtEq(base.Reference("foo"), literal(5)), base.Lt(base.Reference("foo"), literal(5))),
+        (base.LtEq(base.Reference("foo"), literal(5)), base.Gt(base.Reference("foo"), literal(5))),
+        (base.NotEq(base.Reference("foo"), literal(5)), base.Eq(base.Reference("foo"), literal(5))),
+        (base.IsNaN(base.Reference("foo")), base.NotNaN(base.Reference("foo"))),
+        (base.NotNaN(base.Reference("foo")), base.IsNaN(base.Reference("foo"))),
+        (
+            ExpressionA(),
+            ExpressionB(),
+        ),
     ],
 )
-def test_negate(lhs, rhs, raises):
-    if not raises:
-        assert ~lhs == rhs
-    else:
-        with pytest.raises(TypeError):
-            ~lhs  # pylint: disable=W0104
+def test_negate(lhs, rhs):
+    """Negating lhs with ~ must produce the complementary expression rhs."""
+    assert ~lhs == rhs


 @pytest.mark.parametrize(
@@ -400,55 +586,65 @@ def 
test_always_false_or_always_true_expression_binding(table_schema_simple):
     [
         (
             base.And(
-                base.In(base.Reference("foo"), (literal("foo"), 
literal("bar"))),
-                base.In(base.Reference("bar"), (literal(1), literal(2), 
literal(3))),
+                base.In(base.Reference("foo"), literal("foo"), literal("bar")),
+                base.In(base.Reference("bar"), literal(1), literal(2), 
literal(3)),
             ),
             base.And(
                 base.BoundIn[str](
-                    term=base.BoundReference(
+                    base.BoundReference(
                         field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                         accessor=Accessor(position=0, inner=None),
                     ),
-                    literals=(StringLiteral("foo"), StringLiteral("bar")),
+                    StringLiteral("foo"),
+                    StringLiteral("bar"),
                 ),
                 base.BoundIn[int](
-                    term=base.BoundReference(
+                    base.BoundReference(
                         field=NestedField(field_id=2, name="bar", 
field_type=IntegerType(), required=True),
                         accessor=Accessor(position=1, inner=None),
                     ),
-                    literals=(LongLiteral(1), LongLiteral(2), LongLiteral(3)),
+                    LongLiteral(1),
+                    LongLiteral(2),
+                    LongLiteral(3),
                 ),
             ),
         ),
         (
             base.And(
-                base.In(base.Reference("foo"), (literal("bar"), 
literal("baz"))),
-                base.In(base.Reference("bar"), (literal(1),)),
-                base.In(base.Reference("foo"), (literal("baz"),)),
+                base.In(base.Reference("foo"), literal("bar"), literal("baz")),
+                base.In(
+                    base.Reference("bar"),
+                    literal(1),
+                ),
+                base.In(
+                    base.Reference("foo"),
+                    literal("baz"),
+                ),
             ),
             base.And(
                 base.And(
                     base.BoundIn[str](
-                        term=base.BoundReference(
+                        base.BoundReference(
                             field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                             accessor=Accessor(position=0, inner=None),
                         ),
-                        literals=(StringLiteral("bar"), StringLiteral("baz")),
+                        StringLiteral("bar"),
+                        StringLiteral("baz"),
                     ),
                     base.BoundIn[int](
-                        term=base.BoundReference(
+                        base.BoundReference(
                             field=NestedField(field_id=2, name="bar", 
field_type=IntegerType(), required=True),
                             accessor=Accessor(position=1, inner=None),
                         ),
-                        literals=(LongLiteral(1),),
+                        LongLiteral(1),
                     ),
                 ),
                 base.BoundIn[str](
-                    term=base.BoundReference(
+                    base.BoundReference(
                         field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                         accessor=Accessor(position=0, inner=None),
                     ),
-                    literals=(StringLiteral("baz"),),
+                    StringLiteral("baz"),
                 ),
             ),
         ),
@@ -465,55 +661,65 @@ def test_and_expression_binding(unbound_and_expression, 
expected_bound_expressio
     [
         (
             base.Or(
-                base.In(base.Reference("foo"), (literal("foo"), 
literal("bar"))),
-                base.In(base.Reference("bar"), (literal(1), literal(2), 
literal(3))),
+                base.In(base.Reference("foo"), literal("foo"), literal("bar")),
+                base.In(base.Reference("bar"), literal(1), literal(2), 
literal(3)),
             ),
             base.Or(
                 base.BoundIn[str](
-                    term=base.BoundReference(
+                    base.BoundReference(
                         field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                         accessor=Accessor(position=0, inner=None),
                     ),
-                    literals=(StringLiteral("foo"), StringLiteral("bar")),
+                    StringLiteral("foo"),
+                    StringLiteral("bar"),
                 ),
                 base.BoundIn[int](
-                    term=base.BoundReference(
+                    base.BoundReference(
                         field=NestedField(field_id=2, name="bar", 
field_type=IntegerType(), required=True),
                         accessor=Accessor(position=1, inner=None),
                     ),
-                    literals=(LongLiteral(1), LongLiteral(2), LongLiteral(3)),
+                    LongLiteral(1),
+                    LongLiteral(2),
+                    LongLiteral(3),
                 ),
             ),
         ),
         (
             base.Or(
-                base.In(base.Reference("foo"), (literal("bar"), 
literal("baz"))),
-                base.In(base.Reference("foo"), (literal("bar"),)),
-                base.In(base.Reference("foo"), (literal("baz"),)),
+                base.In(base.Reference("foo"), literal("bar"), literal("baz")),
+                base.In(
+                    base.Reference("foo"),
+                    literal("bar"),
+                ),
+                base.In(
+                    base.Reference("foo"),
+                    literal("baz"),
+                ),
             ),
             base.Or(
                 base.Or(
                     base.BoundIn[str](
-                        term=base.BoundReference(
+                        base.BoundReference(
                             field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                             accessor=Accessor(position=0, inner=None),
                         ),
-                        literals=(StringLiteral("bar"), StringLiteral("baz")),
+                        StringLiteral("bar"),
+                        StringLiteral("baz"),
                     ),
                     base.BoundIn[str](
-                        term=base.BoundReference(
+                        base.BoundReference(
                             field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                             accessor=Accessor(position=0, inner=None),
                         ),
-                        literals=(StringLiteral("bar"),),
+                        StringLiteral("bar"),
                     ),
                 ),
                 base.BoundIn[str](
-                    term=base.BoundReference(
+                    base.BoundReference(
                         field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                         accessor=Accessor(position=0, inner=None),
                     ),
-                    literals=(StringLiteral("baz"),),
+                    StringLiteral("baz"),
                 ),
             ),
         ),
@@ -550,33 +756,38 @@ def test_or_expression_binding(unbound_or_expression, 
expected_bound_expression,
     "unbound_in_expression,expected_bound_expression",
     [
         (
-            base.In(base.Reference("foo"), (literal("foo"), literal("bar"))),
+            base.In(base.Reference("foo"), literal("foo"), literal("bar")),
             base.BoundIn[str](
-                term=base.BoundReference[str](
+                base.BoundReference[str](
                     field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                     accessor=Accessor(position=0, inner=None),
                 ),
-                literals=(StringLiteral("foo"), StringLiteral("bar")),
+                StringLiteral("foo"),
+                StringLiteral("bar"),
             ),
         ),
         (
-            base.In(base.Reference("foo"), (literal("bar"), literal("baz"))),
+            base.In(base.Reference("foo"), literal("bar"), literal("baz")),
             base.BoundIn[str](
-                term=base.BoundReference[str](
+                base.BoundReference[str](
                     field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                     accessor=Accessor(position=0, inner=None),
                 ),
-                literals=(StringLiteral("bar"), StringLiteral("baz")),
+                StringLiteral("bar"),
+                StringLiteral("baz"),
             ),
         ),
         (
-            base.In(base.Reference("foo"), (literal("bar"),)),
+            base.In(
+                base.Reference("foo"),
+                literal("bar"),
+            ),
             base.BoundIn[str](
-                term=base.BoundReference[str](
+                base.BoundReference[str](
                     field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                     accessor=Accessor(position=0, inner=None),
                 ),
-                literals=(StringLiteral("bar"),),
+                StringLiteral("bar"),
             ),
         ),
     ],
@@ -591,39 +802,43 @@ def test_in_expression_binding(unbound_in_expression, 
expected_bound_expression,
     "unbound_not_expression,expected_bound_expression",
     [
         (
-            base.Not(base.In(base.Reference("foo"), (literal("foo"), 
literal("bar")))),
+            base.Not(base.In(base.Reference("foo"), literal("foo"), 
literal("bar"))),
             base.Not(
                 base.BoundIn[str](
-                    term=base.BoundReference[str](
+                    base.BoundReference[str](
                         field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                         accessor=Accessor(position=0, inner=None),
                     ),
-                    literals=(StringLiteral("foo"), StringLiteral("bar")),
+                    StringLiteral("foo"),
+                    StringLiteral("bar"),
                 )
             ),
         ),
         (
             base.Not(
                 base.Or(
-                    base.In(base.Reference("foo"), (literal("foo"), 
literal("bar"))),
-                    base.In(base.Reference("foo"), (literal("foo"), 
literal("bar"), literal("baz"))),
+                    base.In(base.Reference("foo"), literal("foo"), 
literal("bar")),
+                    base.In(base.Reference("foo"), literal("foo"), 
literal("bar"), literal("baz")),
                 )
             ),
             base.Not(
                 base.Or(
                     base.BoundIn[str](
-                        term=base.BoundReference(
+                        base.BoundReference(
                             field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                             accessor=Accessor(position=0, inner=None),
                         ),
-                        literals=(StringLiteral("foo"), StringLiteral("bar")),
+                        StringLiteral("foo"),
+                        StringLiteral("bar"),
                     ),
                     base.BoundIn[str](
-                        term=base.BoundReference(
+                        base.BoundReference(
                             field=NestedField(field_id=1, name="foo", 
field_type=StringType(), required=False),
                             accessor=Accessor(position=0, inner=None),
                         ),
-                        literals=(StringLiteral("foo"), StringLiteral("bar"), 
StringLiteral("baz")),
+                        StringLiteral("foo"),
+                        StringLiteral("bar"),
+                        StringLiteral("baz"),
                     ),
                 ),
             ),

Reply via email to