This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new bfdd9b2ad7 Python: Implement project in Transform implementations 
(#6128)
bfdd9b2ad7 is described below

commit bfdd9b2ad72e9d54fcec6cc2fe4a8466b35d2809
Author: Fokko Driesprong <[email protected]>
AuthorDate: Tue Nov 22 00:37:27 2022 +0100

    Python: Implement project in Transform implementations (#6128)
---
 python/pyiceberg/expressions/__init__.py  |  68 ++++++
 python/pyiceberg/expressions/literals.py  |  29 +++
 python/pyiceberg/transforms.py            | 214 ++++++++++++++--
 python/tests/expressions/test_literals.py |  18 ++
 python/tests/test_transforms.py           | 389 +++++++++++++++++++++++++++++-
 5 files changed, 701 insertions(+), 17 deletions(-)

diff --git a/python/pyiceberg/expressions/__init__.py 
b/python/pyiceberg/expressions/__init__.py
index f39e15006b..5d25d33915 100644
--- a/python/pyiceberg/expressions/__init__.py
+++ b/python/pyiceberg/expressions/__init__.py
@@ -309,6 +309,11 @@ class BoundPredicate(Generic[L], Bound, BooleanExpression, 
ABC):
             return self.term == other.term
         return False
 
+    @property
+    @abstractmethod
+    def as_unbound(self) -> Type[UnboundPredicate[Any]]:
+        ...
+
 
 class UnboundPredicate(Generic[L], Unbound[BooleanExpression], 
BooleanExpression, ABC):
     term: UnboundTerm[Any]
@@ -347,6 +352,11 @@ class BoundUnaryPredicate(BoundPredicate[L], ABC):
     def __repr__(self) -> str:
         return f"{str(self.__class__.__name__)}(term={repr(self.term)})"
 
+    @property
+    @abstractmethod
+    def as_unbound(self) -> Type[UnaryPredicate]:
+        ...
+
 
 class BoundIsNull(BoundUnaryPredicate[L]):
     def __new__(cls, term: BoundTerm[L]):  # pylint: disable=W0221
@@ -357,6 +367,10 @@ class BoundIsNull(BoundUnaryPredicate[L]):
     def __invert__(self) -> BoundNotNull[L]:
         return BoundNotNull(self.term)
 
+    @property
+    def as_unbound(self) -> Type[IsNull]:
+        return IsNull
+
 
 class BoundNotNull(BoundUnaryPredicate[L]):
     def __new__(cls, term: BoundTerm[L]):  # pylint: disable=W0221
@@ -367,6 +381,10 @@ class BoundNotNull(BoundUnaryPredicate[L]):
     def __invert__(self) -> BoundIsNull[L]:
         return BoundIsNull(self.term)
 
+    @property
+    def as_unbound(self) -> Type[NotNull]:
+        return NotNull
+
 
 class IsNull(UnaryPredicate):
     def __invert__(self) -> NotNull:
@@ -396,6 +414,10 @@ class BoundIsNaN(BoundUnaryPredicate[L]):
     def __invert__(self) -> BoundNotNaN[L]:
         return BoundNotNaN(self.term)
 
+    @property
+    def as_unbound(self) -> Type[IsNaN]:
+        return IsNaN
+
 
 class BoundNotNaN(BoundUnaryPredicate[L]):
     def __new__(cls, term: BoundTerm[L]):  # pylint: disable=W0221
@@ -407,6 +429,10 @@ class BoundNotNaN(BoundUnaryPredicate[L]):
     def __invert__(self) -> BoundIsNaN[L]:
         return BoundIsNaN(self.term)
 
+    @property
+    def as_unbound(self) -> Type[NotNaN]:
+        return NotNaN
+
 
 class IsNaN(UnaryPredicate):
     def __invert__(self) -> NotNaN:
@@ -477,6 +503,11 @@ class BoundSetPredicate(BoundPredicate[L], ABC):
     def __eq__(self, other: Any) -> bool:
         return self.term == other.term and self.literals == other.literals if 
isinstance(other, BoundSetPredicate) else False
 
+    @property
+    @abstractmethod
+    def as_unbound(self) -> Type[SetPredicate[L]]:
+        ...
+
 
 class BoundIn(BoundSetPredicate[L]):
     def __new__(cls, term: BoundTerm[L], literals: Set[Literal[L]]):  # 
pylint: disable=W0221
@@ -494,6 +525,10 @@ class BoundIn(BoundSetPredicate[L]):
     def __eq__(self, other: Any) -> bool:
         return self.term == other.term and self.literals == other.literals if 
isinstance(other, BoundIn) else False
 
+    @property
+    def as_unbound(self) -> Type[In[L]]:
+        return In
+
 
 class BoundNotIn(BoundSetPredicate[L]):
     def __new__(  # pylint: disable=W0221
@@ -512,6 +547,10 @@ class BoundNotIn(BoundSetPredicate[L]):
     def __invert__(self) -> BoundIn[L]:
         return BoundIn(self.term, self.literals)
 
+    @property
+    def as_unbound(self) -> Type[NotIn[L]]:
+        return NotIn
+
 
 class In(SetPredicate[L]):
     def __new__(
@@ -601,36 +640,65 @@ class BoundLiteralPredicate(BoundPredicate[L], ABC):
     def __repr__(self) -> str:
         return f"{str(self.__class__.__name__)}(term={repr(self.term)}, 
literal={repr(self.literal)})"
 
+    @property
+    @abstractmethod
+    def as_unbound(self) -> Type[LiteralPredicate[L]]:
+        ...
+
 
 class BoundEqualTo(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundNotEqualTo[L]:
         return BoundNotEqualTo[L](self.term, self.literal)
 
+    @property
+    def as_unbound(self) -> Type[EqualTo[L]]:
+        return EqualTo
+
 
 class BoundNotEqualTo(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundEqualTo[L]:
         return BoundEqualTo[L](self.term, self.literal)
 
+    @property
+    def as_unbound(self) -> Type[NotEqualTo[L]]:
+        return NotEqualTo
+
 
 class BoundGreaterThanOrEqual(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundLessThan[L]:
         return BoundLessThan[L](self.term, self.literal)
 
+    @property
+    def as_unbound(self) -> Type[GreaterThanOrEqual[L]]:
+        return GreaterThanOrEqual[L]
+
 
 class BoundGreaterThan(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundLessThanOrEqual[L]:
         return BoundLessThanOrEqual(self.term, self.literal)
 
+    @property
+    def as_unbound(self) -> Type[GreaterThan[L]]:
+        return GreaterThan[L]
+
 
 class BoundLessThan(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundGreaterThanOrEqual[L]:
         return BoundGreaterThanOrEqual[L](self.term, self.literal)
 
+    @property
+    def as_unbound(self) -> Type[LessThan[L]]:
+        return LessThan[L]
+
 
 class BoundLessThanOrEqual(BoundLiteralPredicate[L]):
     def __invert__(self) -> BoundGreaterThan[L]:
         return BoundGreaterThan[L](self.term, self.literal)
 
+    @property
+    def as_unbound(self) -> Type[LessThanOrEqual[L]]:
+        return LessThanOrEqual[L]
+
 
 class EqualTo(LiteralPredicate[L]):
     def __invert__(self) -> NotEqualTo[L]:
diff --git a/python/pyiceberg/expressions/literals.py 
b/python/pyiceberg/expressions/literals.py
index 44ab9a15e6..c59c6bcf8d 100644
--- a/python/pyiceberg/expressions/literals.py
+++ b/python/pyiceberg/expressions/literals.py
@@ -53,6 +53,7 @@ from pyiceberg.utils.datetime import (
     timestamp_to_micros,
     timestamptz_to_micros,
 )
+from pyiceberg.utils.decimal import decimal_to_unscaled, unscaled_to_decimal
 from pyiceberg.utils.singleton import Singleton
 
 
@@ -210,6 +211,12 @@ class LongLiteral(Literal[int]):
     def to(self, type_var: IcebergType) -> Literal:  # type: ignore
         raise TypeError(f"Cannot convert LongLiteral into {type_var}")
 
+    def increment(self) -> Literal[int]:
+        return LongLiteral(self.value + 1)
+
+    def decrement(self) -> Literal[int]:
+        return LongLiteral(self.value - 1)
+
     @to.register(LongType)
     def _(self, _: LongType) -> Literal[int]:
         return self
@@ -319,6 +326,12 @@ class DateLiteral(Literal[int]):
     def __init__(self, value: int):
         super().__init__(value, int)
 
+    def increment(self) -> Literal[int]:
+        return DateLiteral(self.value + 1)
+
+    def decrement(self) -> Literal[int]:
+        return DateLiteral(self.value - 1)
+
     @singledispatchmethod
     def to(self, type_var: IcebergType) -> Literal:  # type: ignore
         raise TypeError(f"Cannot convert DateLiteral into {type_var}")
@@ -345,6 +358,12 @@ class TimestampLiteral(Literal[int]):
     def __init__(self, value: int):
         super().__init__(value, int)
 
+    def increment(self) -> Literal[int]:
+        return TimestampLiteral(self.value + 1)
+
+    def decrement(self) -> Literal[int]:
+        return TimestampLiteral(self.value - 1)
+
     @singledispatchmethod
     def to(self, type_var: IcebergType) -> Literal:  # type: ignore
         raise TypeError(f"Cannot convert TimestampLiteral into {type_var}")
@@ -362,6 +381,16 @@ class DecimalLiteral(Literal[Decimal]):
     def __init__(self, value: Decimal):
         super().__init__(value, Decimal)
 
+    def increment(self) -> Literal[Decimal]:
+        original_scale = abs(self.value.as_tuple().exponent)
+        unscaled = decimal_to_unscaled(self.value)
+        return DecimalLiteral(unscaled_to_decimal(unscaled + 1, 
original_scale))
+
+    def decrement(self) -> Literal[Decimal]:
+        original_scale = abs(self.value.as_tuple().exponent)
+        unscaled = decimal_to_unscaled(self.value)
+        return DecimalLiteral(unscaled_to_decimal(unscaled - 1, 
original_scale))
+
     @singledispatchmethod
     def to(self, type_var: IcebergType) -> Literal:  # type: ignore
         raise TypeError(f"Cannot convert DecimalLiteral into {type_var}")
diff --git a/python/pyiceberg/transforms.py b/python/pyiceberg/transforms.py
index 14d76fd8cb..79f901913c 100644
--- a/python/pyiceberg/transforms.py
+++ b/python/pyiceberg/transforms.py
@@ -20,18 +20,41 @@ import struct
 from abc import ABC, abstractmethod
 from enum import IntEnum
 from functools import singledispatch
-from typing import (
-    Any,
-    Callable,
-    Generic,
-    Literal,
-    Optional,
-    TypeVar,
-)
+from typing import Any, Callable, Generic
+from typing import Literal as LiteralType
+from typing import Optional, TypeVar
 
 import mmh3
 from pydantic import Field, PositiveInt, PrivateAttr
 
+from pyiceberg.expressions import (
+    BoundEqualTo,
+    BoundGreaterThan,
+    BoundGreaterThanOrEqual,
+    BoundIn,
+    BoundLessThan,
+    BoundLessThanOrEqual,
+    BoundLiteralPredicate,
+    BoundNotIn,
+    BoundPredicate,
+    BoundSetPredicate,
+    BoundTerm,
+    BoundUnaryPredicate,
+    EqualTo,
+    GreaterThanOrEqual,
+    LessThanOrEqual,
+    Reference,
+    UnboundPredicate,
+)
+from pyiceberg.expressions.literals import (
+    DateLiteral,
+    DecimalLiteral,
+    Literal,
+    LongLiteral,
+    TimestampLiteral,
+    literal,
+)
+from pyiceberg.typedef import L
 from pyiceberg.types import (
     BinaryType,
     DateType,
@@ -68,6 +91,11 @@ BUCKET_PARSER = ParseNumberFromBrackets(BUCKET)
 TRUNCATE_PARSER = ParseNumberFromBrackets(TRUNCATE)
 
 
+def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
+    """Small helper to unwrap the value from the literal, and wrap it again"""
+    return literal(func(lit.value))
+
+
 class Transform(IcebergBaseModel, ABC, Generic[S, T]):
     """Transform base class for concrete transforms.
 
@@ -121,6 +149,10 @@ class Transform(IcebergBaseModel, ABC, Generic[S, T]):
     def result_type(self, source: IcebergType) -> IcebergType:
         ...
 
+    @abstractmethod
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        ...
+
     @property
     def preserves_order(self) -> bool:
         return False
@@ -173,6 +205,23 @@ class BucketTransform(Transform[S, int]):
     def result_type(self, source: IcebergType) -> IcebergType:
         return IntegerType()
 
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        transformer = self.transform(pred.term.ref().field.field_type)
+
+        if isinstance(pred.term, BoundTransform):
+            return _project_transform_predicate(self, name, pred)
+        elif isinstance(pred, BoundUnaryPredicate):
+            return pred.as_unbound(Reference(name))
+        elif isinstance(pred, BoundEqualTo):
+            return pred.as_unbound(Reference(name), 
_transform_literal(transformer, pred.literal))
+        elif isinstance(pred, BoundIn):  # NotIn can't be projected
+            return pred.as_unbound(Reference(name), 
{_transform_literal(transformer, literal) for literal in pred.literals})
+        else:
+            # - Comparison predicates can't be projected, notEq can't be 
projected
+            # - Small ranges can be projected:
+            #   For example, (x > 0) and (x < 3) can be turned into in({1, 2}) 
and projected.
+            return None
+
     def can_transform(self, source: IcebergType) -> bool:
         return type(source) in {
             IntegerType,
@@ -246,9 +295,26 @@ class TimeTransform(Transform[S, int], Singleton):
     def satisfies_order_of(self, other: Transform[S, T]) -> bool:
         return self.granularity <= other.granularity if hasattr(other, 
"granularity") else False
 
-    def result_type(self, source: IcebergType) -> IcebergType:
+    def result_type(self, source: IcebergType) -> IntegerType:
         return IntegerType()
 
+    @abstractmethod
+    def transform(self, source: IcebergType) -> Callable[[Optional[Any]], 
Optional[int]]:
+        ...
+
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        transformer = self.transform(pred.term.ref().field.field_type)
+        if isinstance(pred.term, BoundTransform):
+            return _project_transform_predicate(self, name, pred)
+        elif isinstance(pred, BoundUnaryPredicate):
+            return pred.as_unbound(Reference(name))
+        elif isinstance(pred, BoundLiteralPredicate):
+            return _truncate_number(name, pred, transformer)
+        elif isinstance(pred, BoundIn):  # NotIn can't be projected
+            return _set_apply_transform(name, pred, transformer)
+        else:
+            return None
+
     @property
     def dedup_name(self) -> str:
         return "time"
@@ -267,7 +333,7 @@ class YearTransform(TimeTransform[S]):
         47
     """
 
-    __root__: Literal["year"] = Field(default="year")
+    __root__: LiteralType["year"] = Field(default="year")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
         source_type = type(source)
@@ -313,7 +379,7 @@ class MonthTransform(TimeTransform[S]):
         575
     """
 
-    __root__: Literal["month"] = Field(default="month")
+    __root__: LiteralType["month"] = Field(default="month")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
         source_type = type(source)
@@ -359,7 +425,7 @@ class DayTransform(TimeTransform[S]):
         17501
     """
 
-    __root__: Literal["day"] = Field(default="day")
+    __root__: LiteralType["day"] = Field(default="day")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
         source_type = type(source)
@@ -408,7 +474,7 @@ class HourTransform(TimeTransform[S]):
         420042
     """
 
-    __root__: Literal["hour"] = Field(default="hour")
+    __root__: LiteralType["hour"] = Field(default="hour")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[int]]:
         if type(source) in {TimestampType, TimestamptzType}:
@@ -452,7 +518,7 @@ class IdentityTransform(Transform[S, S]):
         'hello-world'
     """
 
-    __root__: Literal["identity"] = Field(default="identity")
+    __root__: LiteralType["identity"] = Field(default="identity")  # noqa: F821
 
     def transform(self, source: IcebergType) -> Callable[[Optional[S]], 
Optional[S]]:
         return lambda v: v
@@ -463,6 +529,18 @@ class IdentityTransform(Transform[S, S]):
     def result_type(self, source: IcebergType) -> IcebergType:
         return source
 
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        if isinstance(pred.term, BoundTransform):
+            return _project_transform_predicate(self, name, pred)
+        elif isinstance(pred, BoundUnaryPredicate):
+            return pred.as_unbound(Reference(name))
+        elif isinstance(pred, BoundEqualTo):
+            return pred.as_unbound(Reference(name), pred.literal)
+        elif isinstance(pred, (BoundIn, BoundNotIn)):
+            return pred.as_unbound(Reference(name), pred.literals)
+        else:
+            raise ValueError(f"Could not project: {self}")
+
     @property
     def preserves_order(self) -> bool:
         return True
@@ -511,6 +589,29 @@ class TruncateTransform(Transform[S, S]):
     def source_type(self) -> IcebergType:
         return self._source_type
 
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        field_type = pred.term.ref().field.field_type
+
+        if isinstance(pred.term, BoundTransform):
+            return _project_transform_predicate(self, name, pred)
+
+        # Implement startswith and notstartswith for string (and probably 
binary)
+        # https://github.com/apache/iceberg/issues/6112
+
+        if isinstance(pred, BoundUnaryPredicate):
+            return pred.as_unbound(Reference(name))
+        elif isinstance(field_type, (IntegerType, LongType, DecimalType)):
+            if isinstance(pred, BoundLiteralPredicate):
+                return _truncate_number(name, pred, self.transform(field_type))
+            elif isinstance(pred, BoundIn):
+                return _set_apply_transform(name, pred, 
self.transform(field_type))
+        elif isinstance(field_type, (BinaryType, StringType)):
+            if isinstance(pred, BoundLiteralPredicate):
+                return _truncate_array(name, pred, self.transform(field_type))
+            elif isinstance(pred, BoundIn):
+                return _set_apply_transform(name, pred, 
self.transform(field_type))
+        return None
+
     @property
     def width(self) -> int:
         return self._width
@@ -610,7 +711,7 @@ class UnknownTransform(Transform[S, T]):
       AttributeError: If the apply method is called.
     """
 
-    __root__: Literal["unknown"] = Field(default="unknown")
+    __root__: LiteralType["unknown"] = Field(default="unknown")  # noqa: F821
     _transform: str = PrivateAttr()
 
     def __init__(self, transform: str, **data: Any):
@@ -623,9 +724,12 @@ class UnknownTransform(Transform[S, T]):
     def can_transform(self, source: IcebergType) -> bool:
         return False
 
-    def result_type(self, source: IcebergType) -> IcebergType:
+    def result_type(self, source: IcebergType) -> StringType:
         return StringType()
 
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        return None
+
     def __repr__(self) -> str:
         return f"UnknownTransform(transform={repr(self._transform)})"
 
@@ -644,8 +748,86 @@ class VoidTransform(Transform[S, None], Singleton):
     def result_type(self, source: IcebergType) -> IcebergType:
         return source
 
+    def project(self, name: str, pred: BoundPredicate[L]) -> 
Optional[UnboundPredicate[Any]]:
+        return None
+
     def to_human_string(self, _: IcebergType, value: Optional[S]) -> str:
         return "null"
 
     def __repr__(self) -> str:
         return "VoidTransform()"
+
+
+def _truncate_number(
+    name: str, pred: BoundLiteralPredicate[L], transform: 
Callable[[Optional[L]], Optional[L]]
+) -> Optional[UnboundPredicate[Any]]:
+    boundary = pred.literal
+
+    if not isinstance(boundary, (LongLiteral, DecimalLiteral, DateLiteral, 
TimestampLiteral)):
+        raise ValueError(f"Expected a numeric literal, got: {type(boundary)}")
+
+    if isinstance(pred, BoundLessThan):
+        return LessThanOrEqual(Reference(name), _transform_literal(transform, 
boundary.decrement()))  # type: ignore
+    elif isinstance(pred, BoundLessThanOrEqual):
+        return LessThanOrEqual(Reference(name), _transform_literal(transform, 
boundary))
+    elif isinstance(pred, BoundGreaterThan):
+        return GreaterThanOrEqual(Reference(name), 
_transform_literal(transform, boundary.increment()))  # type: ignore
+    elif isinstance(pred, BoundGreaterThanOrEqual):
+        return GreaterThanOrEqual(Reference(name), 
_transform_literal(transform, boundary))
+    elif isinstance(pred, BoundEqualTo):
+        return EqualTo(Reference(name), _transform_literal(transform, 
boundary))
+    else:
+        return None
+
+
+def _truncate_array(
+    name: str, pred: BoundLiteralPredicate[L], transform: 
Callable[[Optional[L]], Optional[L]]
+) -> Optional[UnboundPredicate[Any]]:
+    boundary = pred.literal
+
+    if type(pred) in {BoundLessThan, BoundLessThanOrEqual}:
+        return LessThanOrEqual(Reference(name), _transform_literal(transform, 
boundary))
+    elif type(pred) in {BoundGreaterThan, BoundGreaterThanOrEqual}:
+        return GreaterThanOrEqual(Reference(name), 
_transform_literal(transform, boundary))
+    if isinstance(pred, BoundEqualTo):
+        return EqualTo(Reference(name), _transform_literal(transform, 
boundary))
+    else:
+        return None
+
+
+def _project_transform_predicate(
+    transform: Transform[Any, Any], partition_name: str, pred: 
BoundPredicate[L]
+) -> Optional[UnboundPredicate[Any]]:
+    term = pred.term
+    if isinstance(term, BoundTransform) and transform == term.transform:
+        return _remove_transform(partition_name, pred)
+    return None
+
+
+def _remove_transform(partition_name: str, pred: BoundPredicate[L]):
+    if isinstance(pred, BoundUnaryPredicate):
+        return pred.as_unbound(Reference(partition_name))
+    elif isinstance(pred, BoundLiteralPredicate):
+        return pred.as_unbound(Reference(partition_name), pred.literal)
+    elif isinstance(pred, (BoundIn, BoundNotIn)):
+        return pred.as_unbound(Reference(partition_name), pred.literals)
+    else:
+        raise ValueError(f"Cannot replace transform in unknown predicate: 
{pred}")
+
+
+def _set_apply_transform(name: str, pred: BoundSetPredicate[L], transform: 
Callable[[L], L]) -> UnboundPredicate[Any]:
+    literals = pred.literals
+    if isinstance(pred, BoundSetPredicate):
+        return pred.as_unbound(Reference(name), {_transform_literal(transform, 
literal) for literal in literals})
+    else:
+        raise ValueError(f"Unknown BoundSetPredicate: {pred}")
+
+
+class BoundTransform(BoundTerm[L]):
+    """A transform expression"""
+
+    transform: Transform[L, Any]
+
+    def __init__(self, term: BoundTerm[L], transform: Transform[L, Any]):
+        self.term: BoundTerm[L] = term
+        self.transform = transform
diff --git a/python/tests/expressions/test_literals.py 
b/python/tests/expressions/test_literals.py
index d203195930..efcacc4574 100644
--- a/python/tests/expressions/test_literals.py
+++ b/python/tests/expressions/test_literals.py
@@ -805,6 +805,24 @@ def test_string_to_decimal_type_invalid_value():
     assert "Could not convert 18.15 into a decimal(10, 0), scales differ 0 <> 
2" in str(e.value)
 
 
+def test_decimal_literal_increment():
+    dec = DecimalLiteral(Decimal("10.123"))
+    # Twice to check that we don't mutate the value
+    assert dec.increment() == DecimalLiteral(Decimal("10.124"))
+    assert dec.increment() == DecimalLiteral(Decimal("10.124"))
+    # To check that the scale is still the same
+    assert dec.increment().value.as_tuple() == Decimal("10.124").as_tuple()
+
+
+def test_decimal_literal_dencrement():
+    dec = DecimalLiteral(Decimal("10.123"))
+    # Twice to check that we don't mutate the value
+    assert dec.decrement() == DecimalLiteral(Decimal("10.122"))
+    assert dec.decrement() == DecimalLiteral(Decimal("10.122"))
+    # To check that the scale is still the same
+    assert dec.decrement().value.as_tuple() == Decimal("10.122").as_tuple()
+
+
 #   __  __      ___
 #  |  \/  |_  _| _ \_  _
 #  | |\/| | || |  _/ || |
diff --git a/python/tests/test_transforms.py b/python/tests/test_transforms.py
index f8bc42e2dc..9395db13c2 100644
--- a/python/tests/test_transforms.py
+++ b/python/tests/test_transforms.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=eval-used,protected-access
+# pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
 from typing import Any
@@ -24,6 +24,31 @@ import mmh3 as mmh3
 import pytest
 
 from pyiceberg import transforms
+from pyiceberg.expressions import (
+    BoundEqualTo,
+    BoundGreaterThan,
+    BoundGreaterThanOrEqual,
+    BoundIn,
+    BoundLessThan,
+    BoundLessThanOrEqual,
+    BoundNotIn,
+    BoundNotNull,
+    BoundReference,
+    EqualTo,
+    GreaterThanOrEqual,
+    In,
+    LessThanOrEqual,
+    NotIn,
+    NotNull,
+    Reference,
+)
+from pyiceberg.expressions.literals import (
+    DateLiteral,
+    DecimalLiteral,
+    TimestampLiteral,
+    literal,
+)
+from pyiceberg.schema import Accessor
 from pyiceberg.transforms import (
     BucketTransform,
     DayTransform,
@@ -46,6 +71,7 @@ from pyiceberg.types import (
     FloatType,
     IntegerType,
     LongType,
+    NestedField,
     StringType,
     TimestampType,
     TimestamptzType,
@@ -507,3 +533,364 @@ def test_datetime_transform_str(transform, transform_str):
 )
 def test_datetime_transform_repr(transform, transform_repr):
     assert repr(transform) == transform_repr
+
+
+@pytest.fixture
+def bound_reference_str() -> BoundReference[str]:
+    return BoundReference(field=NestedField(1, "field", StringType(), 
required=False), accessor=Accessor(position=0, inner=None))
+
+
+@pytest.fixture
+def bound_reference_date() -> BoundReference[int]:
+    return BoundReference(field=NestedField(1, "field", DateType(), 
required=False), accessor=Accessor(position=0, inner=None))
+
+
+@pytest.fixture
+def bound_reference_timestamp() -> BoundReference[int]:
+    return BoundReference(
+        field=NestedField(1, "field", TimestampType(), required=False), 
accessor=Accessor(position=0, inner=None)
+    )
+
+
+@pytest.fixture
+def bound_reference_decimal() -> BoundReference[Decimal]:
+    return BoundReference(
+        field=NestedField(1, "field", DecimalType(8, 2), required=False), 
accessor=Accessor(position=0, inner=None)
+    )
+
+
+@pytest.fixture
+def bound_reference_long() -> BoundReference[int]:
+    return BoundReference(
+        field=NestedField(1, "field", DecimalType(8, 2), required=False), 
accessor=Accessor(position=0, inner=None)
+    )
+
+
+def test_projection_bucket_unary(bound_reference_str: BoundReference[str]) -> 
None:
+    assert BucketTransform(2).project("name", 
BoundNotNull(term=bound_reference_str)) == NotNull(term=Reference(name="name"))
+
+
+def test_projection_bucket_literal(bound_reference_str: BoundReference[str]) 
-> None:
+    assert BucketTransform(2).project("name", 
BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo(
+        term="name", literal=1
+    )
+
+
+def test_projection_bucket_set_same_bucket(bound_reference_str: 
BoundReference[str]) -> None:
+    assert BucketTransform(2).project(
+        "name", BoundIn(term=bound_reference_str, literals={literal("hello"), 
literal("world")})
+    ) == EqualTo(term="name", literal=1)
+
+
+def test_projection_bucket_set_in(bound_reference_str: BoundReference[str]) -> 
None:
+    assert BucketTransform(3).project(
+        "name", BoundIn(term=bound_reference_str, literals={literal("hello"), 
literal("world")})
+    ) == In(term="name", literals={1, 2})
+
+
+def test_projection_bucket_set_not_in(bound_reference_str: 
BoundReference[str]) -> None:
+    assert (
+        BucketTransform(3).project("name", 
BoundNotIn(term=bound_reference_str, literals={literal("hello"), 
literal("world")}))
+        is None
+    )
+
+
+def test_projection_year_unary(bound_reference_date: BoundReference[int]) -> 
None:
+    assert YearTransform().project("name", 
BoundNotNull(term=bound_reference_date)) == NotNull(term="name")
+
+
+def test_projection_year_literal(bound_reference_date: BoundReference[int]) -> 
None:
+    assert YearTransform().project("name", 
BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo(
+        term="name", literal=5
+    )
+
+
+def test_projection_year_set_same_year(bound_reference_date: 
BoundReference[int]) -> None:
+    assert YearTransform().project(
+        "name", BoundIn(term=bound_reference_date, 
literals={DateLiteral(1925), DateLiteral(1926)})
+    ) == EqualTo(term="name", literal=5)
+
+
+def test_projection_year_set_in(bound_reference_date: BoundReference[int]) -> 
None:
+    assert YearTransform().project(
+        "name", BoundIn(term=bound_reference_date, 
literals={DateLiteral(1925), DateLiteral(2925)})
+    ) == In(term="name", literals={8, 5})
+
+
+def test_projection_year_set_not_in(bound_reference_date: BoundReference[int]) 
-> None:
+    assert (
+        YearTransform().project("name", BoundNotIn(term=bound_reference_date, 
literals={DateLiteral(1925), DateLiteral(2925)}))
+        is None
+    )
+
+
+def test_projection_month_unary(bound_reference_date: BoundReference[int]) -> 
None:
+    assert MonthTransform().project("name", 
BoundNotNull(term=bound_reference_date)) == NotNull(term="name")
+
+
+def test_projection_month_literal(bound_reference_date: BoundReference[int]) 
-> None:
+    assert MonthTransform().project("name", 
BoundEqualTo(term=bound_reference_date, literal=DateLiteral(1925))) == EqualTo(
+        term="name", literal=63
+    )
+
+
+def test_projection_month_set_same_month(bound_reference_date: 
BoundReference[int]) -> None:
+    assert MonthTransform().project(
+        "name", BoundIn(term=bound_reference_date, 
literals={DateLiteral(1925), DateLiteral(1926)})
+    ) == EqualTo(term="name", literal=63)
+
+
+def test_projection_month_set_in(bound_reference_date: BoundReference[int]) -> 
None:
+    assert MonthTransform().project(
+        "name", BoundIn(term=bound_reference_date, 
literals={DateLiteral(1925), DateLiteral(2925)})
+    ) == In(term="name", literals={96, 63})
+
+
+def test_projection_day_month_not_in(bound_reference_date: 
BoundReference[int]) -> None:
+    assert (
+        MonthTransform().project("name", BoundNotIn(term=bound_reference_date, 
literals={DateLiteral(1925), DateLiteral(2925)}))
+        is None
+    )
+
+
+def test_projection_day_unary(bound_reference_timestamp) -> None:
+    assert DayTransform().project("name", 
BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name")
+
+
+def test_projection_day_literal(bound_reference_timestamp) -> None:
+    assert DayTransform().project(
+        "name", BoundEqualTo(term=bound_reference_timestamp, 
literal=TimestampLiteral(1667696874000))
+    ) == EqualTo(term="name", literal=19)
+
+
+def test_projection_day_set_same_day(bound_reference_timestamp) -> None:
+    assert DayTransform().project(
+        "name",
+        BoundIn(term=bound_reference_timestamp, 
literals={TimestampLiteral(1667696874001), TimestampLiteral(1667696874000)}),
+    ) == EqualTo(term="name", literal=19)
+
+
+def test_projection_day_set_in(bound_reference_timestamp) -> None:
+    assert DayTransform().project(
+        "name",
+        BoundIn(term=bound_reference_timestamp, 
literals={TimestampLiteral(1667696874001), TimestampLiteral(1567696874000)}),
+    ) == In(term="name", literals={18, 19})
+
+
+def test_projection_day_set_not_in(bound_reference_timestamp) -> None:
+    assert (
+        DayTransform().project(
+            "name",
+            BoundNotIn(term=bound_reference_timestamp, 
literals={TimestampLiteral(1567696874), TimestampLiteral(1667696874)}),
+        )
+        is None
+    )
+
+
+def test_projection_day_human(bound_reference_date: BoundReference[int]) -> 
None:
+    date_literal = DateLiteral(17532)
+    assert DayTransform().project("dt", 
BoundEqualTo(term=bound_reference_date, literal=date_literal)) == EqualTo(
+        term="dt", literal=17532
+    )  # == 2018, 1, 1
+
+    assert DayTransform().project("dt", 
BoundLessThanOrEqual(term=bound_reference_date, literal=date_literal)) == 
LessThanOrEqual(
+        term="dt", literal=17532
+    )  # <= 2018, 1, 1
+
+    assert DayTransform().project("dt", 
BoundLessThan(term=bound_reference_date, literal=date_literal)) == 
LessThanOrEqual(
+        term="dt", literal=17531
+    )  # <= 2017, 12, 31
+
+    assert DayTransform().project(
+        "dt", BoundGreaterThanOrEqual(term=bound_reference_date, 
literal=date_literal)
+    ) == GreaterThanOrEqual(
+        term="dt", literal=17532
+    )  # >= 2018, 1, 1
+
+    assert DayTransform().project("dt", 
BoundGreaterThan(term=bound_reference_date, literal=date_literal)) == 
GreaterThanOrEqual(
+        term="dt", literal=17533
+    )  # >= 2018, 1, 2
+
+
+def test_projection_hour_unary(bound_reference_timestamp) -> None:
+    assert HourTransform().project("name", 
BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name")
+
+
+TIMESTAMP_EXAMPLE = 1667696874000000  # microseconds since epoch: Sun Nov 06 2022 01:07:54 UTC
+HOUR_IN_MICROSECONDS = 60 * 60 * 1000 * 1000  # one hour, in the same unit as TIMESTAMP_EXAMPLE
+
+
+def test_projection_hour_literal(bound_reference_timestamp) -> None:
+    assert HourTransform().project(
+        "name", BoundEqualTo(term=bound_reference_timestamp, 
literal=TimestampLiteral(TIMESTAMP_EXAMPLE))
+    ) == EqualTo(term="name", literal=463249)
+
+
+def test_projection_hour_set_same_hour(bound_reference_timestamp) -> None:
+    assert HourTransform().project(
+        "name",
+        BoundIn(
+            term=bound_reference_timestamp,
+            literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 1), 
TimestampLiteral(TIMESTAMP_EXAMPLE)},
+        ),
+    ) == EqualTo(term="name", literal=463249)
+
+
+def test_projection_hour_set_in(bound_reference_timestamp) -> None:
+    assert HourTransform().project(
+        "name",
+        BoundIn(
+            term=bound_reference_timestamp,
+            literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+        ),
+    ) == In(term="name", literals={463249, 463250})
+
+
+def test_projection_hour_set_not_in(bound_reference_timestamp) -> None:
+    assert (
+        HourTransform().project(
+            "name",
+            BoundNotIn(
+                term=bound_reference_timestamp,
+                literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+            ),
+        )
+        is None
+    )
+
+
+def test_projection_identity_unary(bound_reference_timestamp) -> None:
+    assert IdentityTransform().project("name", 
BoundNotNull(term=bound_reference_timestamp)) == NotNull(term="name")
+
+
+def test_projection_identity_literal(bound_reference_timestamp) -> None:
+    assert IdentityTransform().project(
+        "name", BoundEqualTo(term=bound_reference_timestamp, 
literal=TimestampLiteral(TIMESTAMP_EXAMPLE))
+    ) == EqualTo(
+        term="name", literal=TimestampLiteral(TIMESTAMP_EXAMPLE)  # type: 
ignore
+    )
+
+
+def test_projection_identity_set_in(bound_reference_timestamp) -> None:
+    assert IdentityTransform().project(
+        "name",
+        BoundIn(
+            term=bound_reference_timestamp,
+            literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+        ),
+    ) == In(
+        term="name", literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)}  # type: ignore
+    )
+
+
+def test_projection_identity_set_not_in(bound_reference_timestamp) -> None:
+    assert IdentityTransform().project(
+        "name",
+        BoundNotIn(
+            term=bound_reference_timestamp,
+            literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)},
+        ),
+    ) == NotIn(
+        term="name", literals={TimestampLiteral(TIMESTAMP_EXAMPLE + 
HOUR_IN_MICROSECONDS), TimestampLiteral(TIMESTAMP_EXAMPLE)}  # type: ignore
+    )
+
+
+def test_projection_truncate_string_unary(bound_reference_str: 
BoundReference[str]) -> None:
+    assert TruncateTransform(2).project("name", 
BoundNotNull(term=bound_reference_str)) == NotNull(term="name")
+
+
+def test_projection_truncate_string_literal_eq(bound_reference_str: 
BoundReference[str]) -> None:
+    assert TruncateTransform(2).project("name", 
BoundEqualTo(term=bound_reference_str, literal=literal("data"))) == EqualTo(
+        term="name", literal=literal("da")
+    )
+
+
+def test_projection_truncate_string_literal_gt(bound_reference_str: 
BoundReference[str]) -> None:
+    assert TruncateTransform(2).project("name", 
BoundGreaterThan(term=bound_reference_str, literal=literal("data"))) == EqualTo(
+        term="name", literal=literal("da")
+    )
+
+
+def test_projection_truncate_string_literal_gte(bound_reference_str: 
BoundReference[str]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundGreaterThanOrEqual(term=bound_reference_str, 
literal=literal("data"))
+    ) == EqualTo(term="name", literal=literal("da"))
+
+
+def test_projection_truncate_string_set_same_result(bound_reference_str: 
BoundReference[str]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundIn(term=bound_reference_str, literals={literal("hello"), 
literal("helloworld")})
+    ) == EqualTo(term="name", literal=literal("he"))
+
+
+def test_projection_truncate_string_set_in(bound_reference_str: 
BoundReference[str]) -> None:
+    assert TruncateTransform(3).project(
+        "name", BoundIn(term=bound_reference_str, literals={literal("hello"), 
literal("world")})
+    ) == In(term="name", literals={literal("hel"), literal("wor")})
+
+
+def test_projection_truncate_string_set_not_in(bound_reference_str: 
BoundReference[str]) -> None:
+    assert (
+        TruncateTransform(3).project("name", 
BoundNotIn(term=bound_reference_str, literals={literal("hello"), 
literal("world")}))
+        is None
+    )
+
+
+def test_projection_truncate_decimal_literal_eq(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundEqualTo(term=bound_reference_decimal, 
literal=DecimalLiteral(Decimal(19.25)))
+    ) == EqualTo(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_decimal_literal_gt(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundGreaterThan(term=bound_reference_decimal, 
literal=DecimalLiteral(Decimal(19.25)))
+    ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26"))
+
+
+def test_projection_truncate_decimal_literal_gte(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundGreaterThanOrEqual(term=bound_reference_decimal, 
literal=DecimalLiteral(Decimal(19.25)))
+    ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_decimal_in(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundIn(term=bound_reference_decimal, 
literals={literal(Decimal(19.25)), literal(Decimal(18.15))})
+    ) == In(
+        term="name",
+        literals={
+            Decimal("19.24"),
+            Decimal("18.14999999999999857891452847979962825775146484374"),
+        },
+    )
+
+
+def test_projection_truncate_long_literal_eq(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundEqualTo(term=bound_reference_decimal, 
literal=DecimalLiteral(Decimal(19.25)))
+    ) == EqualTo(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_long_literal_gt(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundGreaterThan(term=bound_reference_decimal, 
literal=DecimalLiteral(Decimal(19.25)))
+    ) == GreaterThanOrEqual(term="name", literal=Decimal("19.26"))
+
+
+def test_projection_truncate_long_literal_gte(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundGreaterThanOrEqual(term=bound_reference_decimal, 
literal=DecimalLiteral(Decimal(19.25)))
+    ) == GreaterThanOrEqual(term="name", literal=Decimal("19.24"))
+
+
+def test_projection_truncate_long_in(bound_reference_decimal: 
BoundReference[Decimal]) -> None:
+    assert TruncateTransform(2).project(
+        "name", BoundIn(term=bound_reference_decimal, 
literals={DecimalLiteral(Decimal(19.25)), DecimalLiteral(Decimal(18.15))})
+    ) == In(
+        term="name",
+        literals={
+            Decimal("19.24"),
+            Decimal("18.14999999999999857891452847979962825775146484374"),
+        },
+    )


Reply via email to