This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 5e25f2bb76 Python: Add more expression classes (#5258)
5e25f2bb76 is described below
commit 5e25f2bb7638cd05b0e1e15fc3824012c7df9dc6
Author: Nick Ouellet <[email protected]>
AuthorDate: Mon Jul 25 15:37:26 2022 -0400
Python: Add more expression classes (#5258)
---
python/pyiceberg/expressions/base.py | 316 ++++++++++++++++---
python/tests/expressions/test_expressions_base.py | 361 +++++++++++++++++-----
2 files changed, 569 insertions(+), 108 deletions(-)
diff --git a/python/pyiceberg/expressions/base.py b/python/pyiceberg/expressions/base.py
index 876965ae71..4b4a487a4d 100644
--- a/python/pyiceberg/expressions/base.py
+++ b/python/pyiceberg/expressions/base.py
@@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce, singledispatch
-from typing import Generic, Tuple, TypeVar
+from typing import Generic, TypeVar
from pyiceberg.files import StructProtocol
from pyiceberg.schema import Accessor, Schema
@@ -41,7 +43,7 @@ class Literal(Generic[T], ABC):
return self._value # type: ignore
@abstractmethod
- def to(self, type_var) -> "Literal":
+ def to(self, type_var) -> Literal:
... # pragma: no cover
def __repr__(self):
@@ -73,7 +75,7 @@ class BooleanExpression(ABC):
"""Represents a boolean expression tree."""
@abstractmethod
- def __invert__(self) -> "BooleanExpression":
+ def __invert__(self) -> BooleanExpression:
"""Transform the Expression into its negated version."""
@@ -156,7 +158,7 @@ class Reference(UnboundTerm[T], BaseReference[T]):
Returns:
BoundReference: A reference bound to the specific field in the
Iceberg schema
"""
- field = schema.find_field(name_or_id=self.name,
case_sensitive=case_sensitive)
+ field = schema.find_field(name_or_id=self.name,
case_sensitive=case_sensitive) # pylint: disable=redefined-outer-name
if not field:
raise ValueError(f"Cannot find field '{self.name}' in schema:
{schema}")
@@ -169,16 +171,70 @@ class Reference(UnboundTerm[T], BaseReference[T]):
return BoundReference(field=field, accessor=accessor)
-@dataclass(frozen=True)  # type: ignore[misc]
 class BoundPredicate(Bound[T], BooleanExpression):
-    term: BoundReference[T]
-    literals: Tuple[Literal[T], ...]
+    """Base class for predicates whose term has been bound to a schema field.
+
+    Stores a bound term plus the literals passed positionally to the constructor.
+    By default exactly one literal is required; subclasses override
+    ``_validate_literals`` to accept a different number of literals.
+    """
+
+    _term: BoundTerm[T]
+    _literals: tuple[Literal[T], ...]
+
+    def __init__(self, term: BoundTerm[T], *literals: Literal[T]):
+        self._term = term
+        self._literals = literals
+        self._validate_literals()
+
+    def _validate_literals(self):
+        """Raise AttributeError unless exactly one literal was supplied."""
+        if len(self.literals) != 1:
+            raise AttributeError(f"{self.__class__.__name__} must have exactly 1 literal.")
+
+    @property
+    def term(self) -> BoundTerm[T]:
+        return self._term
+
+    @property
+    def literals(self) -> tuple[Literal[T], ...]:
+        return self._literals
+
+    def __str__(self) -> str:
+        # Unary predicates carry an empty literal tuple; render no suffix for
+        # them rather than a stray "()" after the term.
+        suffix = ", " + str(self.literals) if self.literals else ""
+        return f"{self.__class__.__name__}({str(self.term)}{suffix})"
+
+    def __repr__(self) -> str:
+        suffix = ", " + repr(self.literals) if self.literals else ""
+        return f"{self.__class__.__name__}({repr(self.term)}{suffix})"
+
+    def __eq__(self, other) -> bool:
+        # Exact type match: e.g. BoundIn and BoundNotIn with the same term and
+        # literals must not compare equal.
+        return id(self) == id(other) or (
+            type(self) is type(other) and self.term == other.term and self.literals == other.literals
+        )
+
+    def __hash__(self) -> int:
+        # Defining __eq__ alone sets __hash__ to None; the frozen dataclass this
+        # class replaces was hashable, so hash exactly the fields __eq__ compares.
+        return hash((type(self), self.term, self.literals))
-@dataclass(frozen=True)  # type: ignore[misc]
-class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression):
-    term: Reference[T]
-    literals: Tuple[Literal[T], ...]
+class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression, ABC):
+    """Base class for predicates over a term that is not yet bound to a schema.
+
+    Stores an unbound term plus the literals passed positionally to the
+    constructor. By default exactly one literal is required; subclasses
+    override ``_validate_literals`` to accept a different number of literals.
+    """
+
+    _term: UnboundTerm[T]
+    _literals: tuple[Literal[T], ...]
+
+    def __init__(self, term: UnboundTerm[T], *literals: Literal[T]):
+        self._term = term
+        self._literals = literals
+        self._validate_literals()
+
+    def _validate_literals(self):
+        """Raise AttributeError unless exactly one literal was supplied."""
+        if len(self.literals) != 1:
+            raise AttributeError(f"{self.__class__.__name__} must have exactly 1 literal.")
+
+    @property
+    def term(self) -> UnboundTerm[T]:
+        return self._term
+
+    @property
+    def literals(self) -> tuple[Literal[T], ...]:
+        return self._literals
+
+    def __str__(self) -> str:
+        # Unary predicates carry an empty literal tuple; render no suffix for
+        # them rather than a stray "()" after the term.
+        suffix = ", " + str(self.literals) if self.literals else ""
+        return f"{self.__class__.__name__}({str(self.term)}{suffix})"
+
+    def __repr__(self) -> str:
+        suffix = ", " + repr(self.literals) if self.literals else ""
+        return f"{self.__class__.__name__}({repr(self.term)}{suffix})"
+
+    def __eq__(self, other) -> bool:
+        # Exact type match: e.g. In and NotIn with the same term and literals
+        # must not compare equal.
+        return id(self) == id(other) or (
+            type(self) is type(other) and self.term == other.term and self.literals == other.literals
+        )
+
+    def __hash__(self) -> int:
+        # Defining __eq__ alone sets __hash__ to None; the frozen dataclass this
+        # class replaces was hashable, so hash exactly the fields __eq__ compares.
+        return hash((type(self), self.term, self.literals))
class And(BooleanExpression):
@@ -209,14 +265,14 @@ class And(BooleanExpression):
def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, And) and self.left
== other.left and self.right == other.right)
- def __invert__(self) -> "Or":
+ def __invert__(self) -> Or:
return Or(~self.left, ~self.right)
def __repr__(self) -> str:
return f"And({repr(self.left)}, {repr(self.right)})"
def __str__(self) -> str:
- return f"({self.left} and {self.right})"
+ return f"And({str(self.left)}, {str(self.right)})"
class Or(BooleanExpression):
@@ -247,14 +303,14 @@ class Or(BooleanExpression):
def __eq__(self, other) -> bool:
return id(self) == id(other) or (isinstance(other, Or) and self.left
== other.left and self.right == other.right)
- def __invert__(self) -> "And":
+ def __invert__(self) -> And:
return And(~self.left, ~self.right)
def __repr__(self) -> str:
return f"Or({repr(self.left)}, {repr(self.right)})"
def __str__(self) -> str:
- return f"({self.left} or {self.right})"
+ return f"Or({str(self.left)}, {str(self.right)})"
class Not(BooleanExpression):
@@ -282,49 +338,239 @@ class Not(BooleanExpression):
return f"Not({repr(self.child)})"
def __str__(self) -> str:
- return f"(not {self.child})"
+ return f"Not({str(self.child)})"
+# NOTE(review): with no fields, the frozen dataclass generates __init__, __eq__,
+# __hash__ and a __repr__ returning "AlwaysTrue()"; str() falls back to that
+# repr unless a base class (not visible here) defines __str__ -- confirm intended.
+@dataclass(frozen=True)
class AlwaysTrue(BooleanExpression, ABC, Singleton):
"""TRUE expression"""
- def __invert__(self) -> "AlwaysFalse":
+ def __invert__(self) -> AlwaysFalse:
return AlwaysFalse()
- def __repr__(self) -> str:
- return "AlwaysTrue()"
-
- def __str__(self) -> str:
- return "true"
-
+# NOTE(review): with no fields, the frozen dataclass generates __init__, __eq__,
+# __hash__ and a __repr__ returning "AlwaysFalse()"; str() falls back to that
+# repr unless a base class (not visible here) defines __str__ -- confirm intended.
+@dataclass(frozen=True)
class AlwaysFalse(BooleanExpression, ABC, Singleton):
"""FALSE expression"""
- def __invert__(self) -> "AlwaysTrue":
+ def __invert__(self) -> AlwaysTrue:
return AlwaysTrue()
- def __repr__(self) -> str:
- return "AlwaysFalse()"
- def __str__(self) -> str:
- return "false"
+class IsNull(UnboundPredicate[T]):
+    """Unbound unary predicate on a term; accepts no literals.
+
+    ``~`` yields NotNull on the same term, and binding yields BoundIsNull.
+    """
+
+    def __invert__(self) -> NotNull[T]:
+        return NotNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        # `literals` is always a tuple (collected from *literals), never None,
+        # so test for emptiness as the sibling unary predicates do.
+        if self.literals:
+            raise AttributeError("Null is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundIsNull[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundIsNull(bound_ref)
+
+
+class BoundIsNull(BoundPredicate[T]):
+    """Bound unary predicate produced by IsNull.bind; accepts no literals."""
+
+    def __invert__(self) -> BoundNotNull[T]:
+        return BoundNotNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("Null is a unary predicate and takes no Literals.")
+
+
+class NotNull(UnboundPredicate[T]):
+    """Unbound unary predicate; ``~`` yields IsNull and binding yields BoundNotNull."""
+
+    def __invert__(self) -> IsNull[T]:
+        return IsNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("NotNull is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotNull[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundNotNull(bound_ref)
+
+
+class BoundNotNull(BoundPredicate[T]):
+    """Bound unary predicate produced by NotNull.bind; accepts no literals."""
+
+    def __invert__(self) -> BoundIsNull[T]:
+        return BoundIsNull(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("NotNull is a unary predicate and takes no Literals.")
+
+
+class IsNaN(UnboundPredicate[T]):
+    """Unbound unary predicate; ``~`` yields NotNaN and binding yields BoundIsNaN."""
+
+    def __invert__(self) -> NotNaN[T]:
+        return NotNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("IsNaN is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundIsNaN[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundIsNaN(bound_ref)
+
+
+class BoundIsNaN(BoundPredicate[T]):
+    """Bound unary predicate produced by IsNaN.bind; accepts no literals."""
+
+    def __invert__(self) -> BoundNotNaN[T]:
+        return BoundNotNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("IsNaN is a unary predicate and takes no Literals.")
+
+
+class NotNaN(UnboundPredicate[T]):
+    """Unbound unary predicate; ``~`` yields IsNaN and binding yields BoundNotNaN."""
+
+    def __invert__(self) -> IsNaN[T]:
+        return IsNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("NotNaN is a unary predicate and takes no Literals.")
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotNaN[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundNotNaN(bound_ref)
+
+
+class BoundNotNaN(BoundPredicate[T]):
+    """Bound unary predicate produced by NotNaN.bind; accepts no literals."""
+
+    def __invert__(self) -> BoundIsNaN[T]:
+        return BoundIsNaN(self.term)
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if self.literals:
+            raise AttributeError("NotNaN is a unary predicate and takes no Literals.")
-@dataclass(frozen=True)
 class BoundIn(BoundPredicate[T]):
-    def __invert__(self):
-        raise TypeError("In expressions do not support negation.")
+    """Bound predicate over one or more literals; ``~`` yields BoundNotIn."""
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:
+            raise AttributeError("BoundIn must contain at least 1 literal.")
+
+    def __invert__(self) -> BoundNotIn[T]:
+        return BoundNotIn(self.term, *self.literals)
-@dataclass(frozen=True)
 class In(UnboundPredicate[T]):
-    def __invert__(self):
-        raise TypeError("In expressions do not support negation.")
+    """Unbound predicate over one or more literals; ``~`` yields NotIn, binding yields BoundIn."""
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:
+            raise AttributeError("In must contain at least 1 literal.")
+
+    def __invert__(self) -> NotIn[T]:
+        return NotIn(self.term, *self.literals)
     def bind(self, schema: Schema, case_sensitive: bool) -> BoundIn[T]:
         bound_ref = self.term.bind(schema, case_sensitive)
-        return BoundIn(bound_ref, tuple(lit.to(bound_ref.field.field_type) for lit in self.literals))  # type: ignore
+        # Convert every literal to the bound field's type; the generator is
+        # unpacked straight into the varargs, no intermediate tuple needed.
+        return BoundIn(bound_ref, *(lit.to(bound_ref.field.field_type) for lit in self.literals))  # type: ignore
+
+
+class BoundNotIn(BoundPredicate[T]):
+    """Bound predicate over one or more literals; ``~`` yields BoundIn."""
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:
+            raise AttributeError("BoundNotIn must contain at least 1 literal.")
+
+    def __invert__(self) -> BoundIn[T]:
+        return BoundIn(self.term, *self.literals)
+
+
+class NotIn(UnboundPredicate[T]):
+    """Unbound predicate over one or more literals; ``~`` yields In, binding yields BoundNotIn."""
+
+    def _validate_literals(self):  # pylint: disable=W0238
+        if not self.literals:
+            raise AttributeError("NotIn must contain at least 1 literal.")
+
+    def __invert__(self) -> In[T]:
+        return In(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotIn[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        # Convert every literal to the bound field's type; the generator is
+        # unpacked straight into the varargs, no intermediate tuple needed.
+        return BoundNotIn(bound_ref, *(lit.to(bound_ref.field.field_type) for lit in self.literals))  # type: ignore
+
+
+class BoundEq(BoundPredicate[T]):
+    """Single-literal bound predicate; ``~`` yields BoundNotEq on the same term and literal."""
+
+    def __invert__(self) -> BoundNotEq[T]:
+        return BoundNotEq(self.term, *self.literals)
+
+
+class Eq(UnboundPredicate[T]):
+    """Single-literal unbound predicate; ``~`` yields NotEq, binding yields BoundEq."""
+
+    def __invert__(self) -> NotEq[T]:
+        return NotEq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundEq[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        # The base validator guarantees exactly one literal; coerce it to the field type.
+        target_type = bound_ref.field.field_type
+        converted = self.literals[0].to(target_type)
+        return BoundEq(bound_ref, converted)  # type: ignore
+
+
+class BoundNotEq(BoundPredicate[T]):
+    """Single-literal bound predicate; ``~`` yields BoundEq on the same term and literal."""
+
+    def __invert__(self) -> BoundEq[T]:
+        return BoundEq(self.term, *self.literals)
+
+
+class NotEq(UnboundPredicate[T]):
+    """Single-literal unbound predicate; ``~`` yields Eq, binding yields BoundNotEq."""
+
+    def __invert__(self) -> Eq[T]:
+        return Eq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundNotEq[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        # The base validator guarantees exactly one literal; coerce it to the field type.
+        target_type = bound_ref.field.field_type
+        converted = self.literals[0].to(target_type)
+        return BoundNotEq(bound_ref, converted)  # type: ignore
+
+
+class BoundLt(BoundPredicate[T]):
+    """Single-literal bound predicate; ``~`` yields BoundGtEq on the same term and literal."""
+
+    def __invert__(self) -> BoundGtEq[T]:
+        return BoundGtEq(self.term, *self.literals)
+
+
+class Lt(UnboundPredicate[T]):
+    """Single-literal unbound predicate; ``~`` yields GtEq, binding yields BoundLt."""
+
+    def __invert__(self) -> GtEq[T]:
+        return GtEq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundLt[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundLt(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundGtEq(BoundPredicate[T]):
+    """Single-literal bound predicate; ``~`` yields BoundLt on the same term and literal."""
+
+    def __invert__(self) -> BoundLt[T]:
+        return BoundLt(self.term, *self.literals)
+
+
+class GtEq(UnboundPredicate[T]):
+    """Single-literal unbound predicate; ``~`` yields Lt, binding yields BoundGtEq."""
+
+    def __invert__(self) -> Lt[T]:
+        return Lt(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundGtEq[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundGtEq(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundGt(BoundPredicate[T]):
+    """Single-literal bound predicate; ``~`` yields BoundLtEq on the same term and literal."""
+
+    def __invert__(self) -> BoundLtEq[T]:
+        return BoundLtEq(self.term, *self.literals)
+
+
+class Gt(UnboundPredicate[T]):
+    """Single-literal unbound predicate; ``~`` yields LtEq, binding yields BoundGt."""
+
+    def __invert__(self) -> LtEq[T]:
+        return LtEq(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundGt[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundGt(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
+
+
+class BoundLtEq(BoundPredicate[T]):
+    """Single-literal bound predicate; ``~`` yields BoundGt on the same term and literal."""
+
+    def __invert__(self) -> BoundGt[T]:
+        return BoundGt(self.term, *self.literals)
+
+
+class LtEq(UnboundPredicate[T]):
+    """Single-literal unbound predicate; ``~`` yields Gt, binding yields BoundLtEq."""
+
+    def __invert__(self) -> Gt[T]:
+        return Gt(self.term, *self.literals)
+
+    def bind(self, schema: Schema, case_sensitive: bool) -> BoundLtEq[T]:
+        bound_ref = self.term.bind(schema, case_sensitive)
+        return BoundLtEq(bound_ref, self.literals[0].to(bound_ref.field.field_type))  # type: ignore
class BooleanExpressionVisitor(Generic[T], ABC):
diff --git a/python/tests/expressions/test_expressions_base.py b/python/tests/expressions/test_expressions_base.py
index 51630d71e8..2fca17d6cc 100644
--- a/python/tests/expressions/test_expressions_base.py
+++ b/python/tests/expressions/test_expressions_base.py
@@ -130,9 +130,9 @@ def test_reprs(op, rep):
@pytest.mark.parametrize(
"op, string",
[
- (base.And(ExpressionA(), ExpressionB()), "(testexpra and testexprb)"),
- (base.Or(ExpressionA(), ExpressionB()), "(testexpra or testexprb)"),
- (base.Not(ExpressionA()), "(not testexpra)"),
+ (base.And(ExpressionA(), ExpressionB()), "And(testexpra, testexprb)"),
+ (base.Or(ExpressionA(), ExpressionB()), "Or(testexpra, testexprb)"),
+ (base.Not(ExpressionA()), "Not(testexpra)"),
],
)
def test_strs(op, string):
@@ -143,25 +143,193 @@ def test_strs(op, string):
"a, schema, case_sensitive, success",
[
(
- base.In(base.Reference("foo"), (literal("hello"),
literal("world"))),
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
"table_schema_simple",
True,
True,
),
(
- base.In(base.Reference("not_foo"), (literal("hello"),
literal("world"))),
+ base.In(base.Reference("not_foo"), literal("hello"),
literal("world")),
"table_schema_simple",
False,
False,
),
(
- base.In(base.Reference("Bar"), (literal("hello"),
literal("world"))),
+ base.In(base.Reference("Bar"), literal(5), literal(2)),
"table_schema_simple",
False,
True,
),
(
- base.In(base.Reference("Bar"), (literal("hello"),
literal("world"))),
+ base.In(base.Reference("Bar"), literal(5), literal(2)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.NotIn(base.Reference("foo"), literal("hello"),
literal("world")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.NotIn(base.Reference("not_foo"), literal("hello"),
literal("world")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.NotIn(base.Reference("Bar"), literal(5), literal(2)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.NotIn(base.Reference("Bar"), literal(5), literal(2)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.NotEq(base.Reference("foo"), literal("hello")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.NotEq(base.Reference("not_foo"), literal("hello")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.NotEq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.NotEq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.Eq(base.Reference("foo"), literal("hello")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.Eq(base.Reference("not_foo"), literal("hello")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.Eq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.Eq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.Gt(base.Reference("foo"), literal("hello")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.Gt(base.Reference("not_foo"), literal("hello")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.Gt(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.Gt(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.Lt(base.Reference("foo"), literal("hello")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.Lt(base.Reference("not_foo"), literal("hello")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.Lt(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.Lt(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.GtEq(base.Reference("foo"), literal("hello")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.GtEq(base.Reference("not_foo"), literal("hello")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.GtEq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.GtEq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ True,
+ False,
+ ),
+ (
+ base.LtEq(base.Reference("foo"), literal("hello")),
+ "table_schema_simple",
+ True,
+ True,
+ ),
+ (
+ base.LtEq(base.Reference("not_foo"), literal("hello")),
+ "table_schema_simple",
+ False,
+ False,
+ ),
+ (
+ base.LtEq(base.Reference("Bar"), literal(5)),
+ "table_schema_simple",
+ False,
+ True,
+ ),
+ (
+ base.LtEq(base.Reference("Bar"), literal(5)),
"table_schema_simple",
True,
False,
@@ -194,14 +362,14 @@ def test_bind(a, schema, case_sensitive, success, request):
(ExpressionA(), ExpressionA(), ExpressionB()),
(ExpressionB(), ExpressionB(), ExpressionA()),
(
- base.In(base.Reference("foo"), (literal("hello"),
literal("world"))),
- base.In(base.Reference("foo"), (literal("hello"),
literal("world"))),
- base.In(base.Reference("not_foo"), (literal("hello"),
literal("world"))),
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
+ base.In(base.Reference("not_foo"), literal("hello"),
literal("world")),
),
(
- base.In(base.Reference("foo"), (literal("hello"),
literal("world"))),
- base.In(base.Reference("foo"), (literal("hello"),
literal("world"))),
- base.In(base.Reference("foo"), (literal("goodbye"),
literal("world"))),
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
+ base.In(base.Reference("foo"), literal("goodbye"),
literal("world")),
),
],
)
@@ -210,21 +378,39 @@ def test_eq(exp, testexpra, testexprb):
@pytest.mark.parametrize(
- "lhs, rhs, raises",
+ "lhs, rhs",
[
- (base.And(ExpressionA(), ExpressionB()), base.Or(ExpressionB(),
ExpressionA()), False),
- (base.Or(ExpressionA(), ExpressionB()), base.And(ExpressionB(),
ExpressionA()), False),
- (base.Not(ExpressionA()), ExpressionA(), False),
- (base.In(base.Reference("foo"), (literal("hello"), literal("world"))),
None, True),
- (ExpressionA(), ExpressionB(), False),
+ (
+ base.And(ExpressionA(), ExpressionB()),
+ base.Or(ExpressionB(), ExpressionA()),
+ ),
+ (
+ base.Or(ExpressionA(), ExpressionB()),
+ base.And(ExpressionB(), ExpressionA()),
+ ),
+ (
+ base.Not(ExpressionA()),
+ ExpressionA(),
+ ),
+ (
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
+ base.NotIn(base.Reference("foo"), literal("hello"),
literal("world")),
+ ),
+ (
+ base.NotIn(base.Reference("foo"), literal("hello"),
literal("world")),
+ base.In(base.Reference("foo"), literal("hello"), literal("world")),
+ ),
+ (base.Gt(base.Reference("foo"), literal(5)),
base.LtEq(base.Reference("foo"), literal(5))),
+ (base.Lt(base.Reference("foo"), literal(5)),
base.GtEq(base.Reference("foo"), literal(5))),
+ (base.Eq(base.Reference("foo"), literal(5)),
base.NotEq(base.Reference("foo"), literal(5))),
+ (
+ ExpressionA(),
+ ExpressionB(),
+ ),
],
)
-def test_negate(lhs, rhs, raises):
- if not raises:
- assert ~lhs == rhs
- else:
- with pytest.raises(TypeError):
- ~lhs # pylint: disable=W0104
+def test_negate(lhs, rhs):
+ """Applying ~ to lhs must produce an expression equal to rhs (every case, In/NotIn included, now negates)."""
+ assert ~lhs == rhs
@pytest.mark.parametrize(
@@ -400,55 +586,65 @@ def
test_always_false_or_always_true_expression_binding(table_schema_simple):
[
(
base.And(
- base.In(base.Reference("foo"), (literal("foo"),
literal("bar"))),
- base.In(base.Reference("bar"), (literal(1), literal(2),
literal(3))),
+ base.In(base.Reference("foo"), literal("foo"), literal("bar")),
+ base.In(base.Reference("bar"), literal(1), literal(2),
literal(3)),
),
base.And(
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("foo"), StringLiteral("bar")),
+ StringLiteral("foo"),
+ StringLiteral("bar"),
),
base.BoundIn[int](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=2, name="bar",
field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
- literals=(LongLiteral(1), LongLiteral(2), LongLiteral(3)),
+ LongLiteral(1),
+ LongLiteral(2),
+ LongLiteral(3),
),
),
),
(
base.And(
- base.In(base.Reference("foo"), (literal("bar"),
literal("baz"))),
- base.In(base.Reference("bar"), (literal(1),)),
- base.In(base.Reference("foo"), (literal("baz"),)),
+ base.In(base.Reference("foo"), literal("bar"), literal("baz")),
+ base.In(
+ base.Reference("bar"),
+ literal(1),
+ ),
+ base.In(
+ base.Reference("foo"),
+ literal("baz"),
+ ),
),
base.And(
base.And(
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("bar"), StringLiteral("baz")),
+ StringLiteral("bar"),
+ StringLiteral("baz"),
),
base.BoundIn[int](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=2, name="bar",
field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
- literals=(LongLiteral(1),),
+ LongLiteral(1),
),
),
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("baz"),),
+ StringLiteral("baz"),
),
),
),
@@ -465,55 +661,65 @@ def test_and_expression_binding(unbound_and_expression, expected_bound_expressio
[
(
base.Or(
- base.In(base.Reference("foo"), (literal("foo"),
literal("bar"))),
- base.In(base.Reference("bar"), (literal(1), literal(2),
literal(3))),
+ base.In(base.Reference("foo"), literal("foo"), literal("bar")),
+ base.In(base.Reference("bar"), literal(1), literal(2),
literal(3)),
),
base.Or(
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("foo"), StringLiteral("bar")),
+ StringLiteral("foo"),
+ StringLiteral("bar"),
),
base.BoundIn[int](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=2, name="bar",
field_type=IntegerType(), required=True),
accessor=Accessor(position=1, inner=None),
),
- literals=(LongLiteral(1), LongLiteral(2), LongLiteral(3)),
+ LongLiteral(1),
+ LongLiteral(2),
+ LongLiteral(3),
),
),
),
(
base.Or(
- base.In(base.Reference("foo"), (literal("bar"),
literal("baz"))),
- base.In(base.Reference("foo"), (literal("bar"),)),
- base.In(base.Reference("foo"), (literal("baz"),)),
+ base.In(base.Reference("foo"), literal("bar"), literal("baz")),
+ base.In(
+ base.Reference("foo"),
+ literal("bar"),
+ ),
+ base.In(
+ base.Reference("foo"),
+ literal("baz"),
+ ),
),
base.Or(
base.Or(
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("bar"), StringLiteral("baz")),
+ StringLiteral("bar"),
+ StringLiteral("baz"),
),
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("bar"),),
+ StringLiteral("bar"),
),
),
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("baz"),),
+ StringLiteral("baz"),
),
),
),
@@ -550,33 +756,38 @@ def test_or_expression_binding(unbound_or_expression, expected_bound_expression,
"unbound_in_expression,expected_bound_expression",
[
(
- base.In(base.Reference("foo"), (literal("foo"), literal("bar"))),
+ base.In(base.Reference("foo"), literal("foo"), literal("bar")),
base.BoundIn[str](
- term=base.BoundReference[str](
+ base.BoundReference[str](
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("foo"), StringLiteral("bar")),
+ StringLiteral("foo"),
+ StringLiteral("bar"),
),
),
(
- base.In(base.Reference("foo"), (literal("bar"), literal("baz"))),
+ base.In(base.Reference("foo"), literal("bar"), literal("baz")),
base.BoundIn[str](
- term=base.BoundReference[str](
+ base.BoundReference[str](
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("bar"), StringLiteral("baz")),
+ StringLiteral("bar"),
+ StringLiteral("baz"),
),
),
(
- base.In(base.Reference("foo"), (literal("bar"),)),
+ base.In(
+ base.Reference("foo"),
+ literal("bar"),
+ ),
base.BoundIn[str](
- term=base.BoundReference[str](
+ base.BoundReference[str](
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("bar"),),
+ StringLiteral("bar"),
),
),
],
@@ -591,39 +802,43 @@ def test_in_expression_binding(unbound_in_expression, expected_bound_expression,
"unbound_not_expression,expected_bound_expression",
[
(
- base.Not(base.In(base.Reference("foo"), (literal("foo"),
literal("bar")))),
+ base.Not(base.In(base.Reference("foo"), literal("foo"),
literal("bar"))),
base.Not(
base.BoundIn[str](
- term=base.BoundReference[str](
+ base.BoundReference[str](
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("foo"), StringLiteral("bar")),
+ StringLiteral("foo"),
+ StringLiteral("bar"),
)
),
),
(
base.Not(
base.Or(
- base.In(base.Reference("foo"), (literal("foo"),
literal("bar"))),
- base.In(base.Reference("foo"), (literal("foo"),
literal("bar"), literal("baz"))),
+ base.In(base.Reference("foo"), literal("foo"),
literal("bar")),
+ base.In(base.Reference("foo"), literal("foo"),
literal("bar"), literal("baz")),
)
),
base.Not(
base.Or(
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("foo"), StringLiteral("bar")),
+ StringLiteral("foo"),
+ StringLiteral("bar"),
),
base.BoundIn[str](
- term=base.BoundReference(
+ base.BoundReference(
field=NestedField(field_id=1, name="foo",
field_type=StringType(), required=False),
accessor=Accessor(position=0, inner=None),
),
- literals=(StringLiteral("foo"), StringLiteral("bar"),
StringLiteral("baz")),
+ StringLiteral("foo"),
+ StringLiteral("bar"),
+ StringLiteral("baz"),
),
),
),