This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 65a03d26  Support Appends with TimeTransform Partitions (#784)
65a03d26 is described below

commit 65a03d2667ac073778b03d99d6580149a2abb326
Author: Sung Yun <[email protected]>
AuthorDate: Fri May 31 16:11:35 2024 -0400

     Support Appends with TimeTransform Partitions (#784)
    
    * checkpoint
    
    * checkpoint2
    
    * todo: sort with pyarrow_transform vals
    
    * checkpoint
    
    * checkpoint
    
    * fix
    
    * tests
    
    * more tests
    
    * adopt review feedback
    
    * comment
    
    * checkpoint
    
    * checkpoint2
    
    * todo: sort with pyarrow_transform vals
    
    * checkpoint
    
    * checkpoint
    
    * fix
    
    * tests
    
    * more tests
    
    * adopt review feedback
    
    * comment
    
    * rebase
---
 pyiceberg/partitioning.py                          |   2 +-
 pyiceberg/table/__init__.py                        |  67 +++----
 pyiceberg/transforms.py                            |  99 +++++++++-
 tests/conftest.py                                  |  43 +++++
 .../test_writes/test_partitioned_writes.py         | 201 +++++++++++++++++++--
 tests/test_transforms.py                           |  34 +++-
 6 files changed, 392 insertions(+), 54 deletions(-)

diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index 481207db..da52d5df 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -387,7 +387,7 @@ class PartitionKey:
         for raw_partition_field_value in self.raw_partition_field_values:
             partition_fields = 
self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
             if len(partition_fields) != 1:
-                raise ValueError("partition_fields must contain exactly one 
field.")
+                raise ValueError(f"Cannot have redundant partitions: 
{partition_fields}")
             partition_field = partition_fields[0]
             iceberg_typed_key_values[partition_field.name] = 
partition_record_value(
                 partition_field=partition_field,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index aa108de0..f160ab24 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -392,10 +392,11 @@ class Transaction:
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        supported_transforms = {IdentityTransform}
-        if not all(type(field.transform) in supported_transforms for field in 
self.table_metadata.spec().fields):
+        if unsupported_partitions := [
+            field for field in self.table_metadata.spec().fields if not 
field.transform.supports_pyarrow_transform
+        ]:
             raise ValueError(
-                f"All transforms are not supported, expected: 
{supported_transforms}, but get: {[str(field) for field in 
self.table_metadata.spec().fields if field.transform not in 
supported_transforms]}."
+                f"Not all partition types are supported for writes. Following 
partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
 
         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
@@ -3643,33 +3644,6 @@ class TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_partition_sort_order(partition_columns: list[str], reverse: bool = 
False) -> dict[str, Any]:
-    order = "ascending" if not reverse else "descending"
-    null_placement = "at_start" if reverse else "at_end"
-    return {"sort_keys": [(column_name, order) for column_name in 
partition_columns], "null_placement": null_placement}
-
-
-def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: 
list[str]) -> pa.Table:
-    """Given a table, sort it by current partition scheme."""
-    # only works for identity for now
-    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
-    sorted_arrow_table = 
arrow_table.sort_by(sorting=sort_options["sort_keys"], 
null_placement=sort_options["null_placement"])
-    return sorted_arrow_table
-
-
-def get_partition_columns(
-    spec: PartitionSpec,
-    schema: Schema,
-) -> list[str]:
-    partition_cols = []
-    for partition_field in spec.fields:
-        column_name = schema.find_column_name(partition_field.source_id)
-        if not column_name:
-            raise ValueError(f"{partition_field=} could not be found in 
{schema}.")
-        partition_cols.append(column_name)
-    return partition_cols
-
-
 def _get_table_partitions(
     arrow_table: pa.Table,
     partition_spec: PartitionSpec,
@@ -3724,13 +3698,30 @@ def _determine_partitions(spec: PartitionSpec, schema: 
Schema, arrow_table: pa.T
     """
     import pyarrow as pa
 
-    partition_columns = get_partition_columns(spec=spec, schema=schema)
-    arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
-
-    reversing_sort_order_options = 
_get_partition_sort_order(partition_columns, reverse=True)
-    reversed_indices = pa.compute.sort_indices(arrow_table, 
**reversing_sort_order_options).to_pylist()
-
-    slice_instructions: list[dict[str, Any]] = []
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for 
partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): 
partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in 
partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in 
partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
     last = len(reversed_indices)
     reversed_indices_size = len(reversed_indices)
     ptr = 0
@@ -3741,6 +3732,6 @@ def _determine_partitions(spec: PartitionSpec, schema: 
Schema, arrow_table: pa.T
         last = reversed_indices[ptr]
         ptr = ptr + group_size
 
-    table_partitions: list[TablePartition] = 
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
+    table_partitions: List[TablePartition] = 
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
 
     return table_partitions
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 6dcae59e..38cc6221 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -20,7 +20,7 @@ import struct
 from abc import ABC, abstractmethod
 from enum import IntEnum
 from functools import singledispatch
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
 from typing import Literal as LiteralType
 from uuid import UUID
 
@@ -82,6 +82,9 @@ from pyiceberg.utils.decimal import decimal_to_bytes, 
truncate_decimal
 from pyiceberg.utils.parsing import ParseNumberFromBrackets
 from pyiceberg.utils.singleton import Singleton
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 S = TypeVar("S")
 T = TypeVar("T")
 
@@ -175,6 +178,13 @@ class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
             return self.root == other.root
         return False
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return False
+
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]": ...
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -290,6 +300,9 @@ class BucketTransform(Transform[S, int]):
         """Return the string representation of the BucketTransform class."""
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        raise NotImplementedError()
+
 
 class TimeResolution(IntEnum):
     YEAR = 6
@@ -349,6 +362,10 @@ class TimeTransform(Transform[S, int], Generic[S], 
Singleton):
     def preserves_order(self) -> bool:
         return True
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class YearTransform(TimeTransform[S]):
     """Transforms a datetime value into a year value.
@@ -391,6 +408,21 @@ class YearTransform(TimeTransform[S]):
         """Return the string representation of the YearTransform class."""
         return "YearTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply year transform for type: {source}")
+
+        return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not 
None else None
+
 
 class MonthTransform(TimeTransform[S]):
     """Transforms a datetime value into a month value.
@@ -433,6 +465,27 @@ class MonthTransform(TimeTransform[S]):
         """Return the string representation of the MonthTransform class."""
         return "MonthTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply month transform for type: 
{source}")
+
+        def month_func(v: pa.Array) -> pa.Array:
+            return pc.add(
+                pc.multiply(pc.years_between(pa.scalar(epoch), v), 
pa.scalar(12)),
+                pc.add(pc.month(v), pa.scalar(-1)),
+            )
+
+        return lambda v: month_func(v) if v is not None else None
+
 
 class DayTransform(TimeTransform[S]):
     """Transforms a datetime value into a day value.
@@ -478,6 +531,21 @@ class DayTransform(TimeTransform[S]):
         """Return the string representation of the DayTransform class."""
         return "DayTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply day transform for type: {source}")
+
+        return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None 
else None
+
 
 class HourTransform(TimeTransform[S]):
     """Transforms a datetime value into a hour value.
@@ -515,6 +583,19 @@ class HourTransform(TimeTransform[S]):
         """Return the string representation of the HourTransform class."""
         return "HourTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply hour transform for type: {source}")
+
+        return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not 
None else None
+
 
 def _base64encode(buffer: bytes) -> str:
     """Convert bytes to base64 string."""
@@ -585,6 +666,13 @@ class IdentityTransform(Transform[S, S]):
         """Return the string representation of the IdentityTransform class."""
         return "IdentityTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        return lambda v: v
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class TruncateTransform(Transform[S, S]):
     """A transform for truncating a value to a specified width.
@@ -725,6 +813,9 @@ class TruncateTransform(Transform[S, S]):
         """Return the string representation of the TruncateTransform class."""
         return f"TruncateTransform(width={self._width})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        raise NotImplementedError()
+
 
 @singledispatch
 def _human_string(value: Any, _type: IcebergType) -> str:
@@ -807,6 +898,9 @@ class UnknownTransform(Transform[S, T]):
         """Return the string representation of the UnknownTransform class."""
         return f"UnknownTransform(transform={repr(self._transform)})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        raise NotImplementedError()
+
 
 class VoidTransform(Transform[S, None], Singleton):
     """A transform that always returns None."""
@@ -835,6 +929,9 @@ class VoidTransform(Transform[S, None], Singleton):
         """Return the string representation of the VoidTransform class."""
         return "VoidTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], 
pa.Array]":
+        raise NotImplementedError()
+
 
 def _truncate_number(
     name: str, pred: BoundLiteralPredicate[L], transform: 
Callable[[Optional[L]], Optional[L]]
diff --git a/tests/conftest.py b/tests/conftest.py
index 01915b7d..d3f23689 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") 
-> "pa.Table":
     import pyarrow as pa
 
     return pa.Table.from_pylist([{}, {}], schema=pa_schema)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps() -> "pa.Table":
+    """Pyarrow table with only date, timestamp and timestamptz values."""
+    import pyarrow as pa
+
+    return pa.Table.from_pydict(
+        {
+            "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), 
date(2024, 2, 1), date(2024, 2, 1), None],
+            "timestamp": [
+                datetime(2023, 12, 31, 0, 0, 0),
+                datetime(2024, 1, 1, 0, 0, 0),
+                datetime(2024, 1, 31, 0, 0, 0),
+                datetime(2024, 2, 1, 0, 0, 0),
+                datetime(2024, 2, 1, 6, 0, 0),
+                None,
+            ],
+            "timestamptz": [
+                datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
+                None,
+            ],
+        },
+        schema=pa.schema([
+            ("date", pa.date32()),
+            ("timestamp", pa.timestamp(unit="us")),
+            ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
+        ]),
+    )
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps_schema() -> Schema:
+    """Pyarrow table Schema with only date, timestamp and timestamptz 
values."""
+    return Schema(
+        NestedField(field_id=1, name="date", field_type=DateType(), 
required=False),
+        NestedField(field_id=2, name="timestamp", field_type=TimestampType(), 
required=False),
+        NestedField(field_id=3, name="timestamptz", 
field_type=TimestamptzType(), required=False),
+    )
diff --git a/tests/integration/test_writes/test_partitioned_writes.py 
b/tests/integration/test_writes/test_partitioned_writes.py
index 5cb03e59..76d559ca 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -16,6 +16,10 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 
+
+from datetime import date
+from typing import Any, Set
+
 import pyarrow as pa
 import pytest
 from pyspark.sql import SparkSession
@@ -23,12 +27,14 @@ from pyspark.sql import SparkSession
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
 from pyiceberg.transforms import (
     BucketTransform,
     DayTransform,
     HourTransform,
     IdentityTransform,
     MonthTransform,
+    Transform,
     TruncateTransform,
     YearTransform,
 )
@@ -351,18 +357,6 @@ def test_invalid_arguments(spark: SparkSession, 
session_catalog: Catalog) -> Non
         (PartitionSpec(PartitionField(source_id=5, field_id=1001, 
transform=TruncateTransform(2), name="long_trunc"))),
         (PartitionSpec(PartitionField(source_id=2, field_id=1001, 
transform=TruncateTransform(2), name="string_trunc"))),
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, 
transform=TruncateTransform(2), name="binary_trunc"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=YearTransform(), name="timestamp_year"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=YearTransform(), name="timestamptz_year"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=YearTransform(), name="date_year"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=MonthTransform(), name="timestamp_month"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=MonthTransform(), name="timestamptz_month"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=MonthTransform(), name="date_month"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=DayTransform(), name="timestamp_day"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=DayTransform(), name="timestamptz_day"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=DayTransform(), name="date_day"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=HourTransform(), name="timestamp_hour"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=HourTransform(), name="timestamptz_hour"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=HourTransform(), name="date_hour"))),
     ],
 )
 def test_unsupported_transform(
@@ -382,5 +376,186 @@ def test_unsupported_transform(
         properties={"format-version": "1"},
     )
 
-    with pytest.raises(ValueError, match="All transforms are not supported.*"):
+    with pytest.raises(
+        ValueError,
+        match="Not all partition types are supported for writes. Following 
partitions cannot be written using pyarrow: *",
+    ):
         tbl.append(arrow_table_with_null)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "transform,expected_rows",
+    [
+        pytest.param(YearTransform(), 2, id="year_transform"),
+        pytest.param(MonthTransform(), 3, id="month_transform"),
+        pytest.param(DayTransform(), 3, id="day_transform"),
+    ],
+)
+@pytest.mark.parametrize("part_col", ["date", "timestamp", "timestamptz"])
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_ymd_transform_partitioned(
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_with_null: pa.Table,
+    transform: Transform[Any, Any],
+    expected_rows: int,
+    part_col: str,
+    format_version: int,
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=transform, name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=partition_spec,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in TEST_DATA_WITH_NULL.keys():
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 
non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row 
for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "transform,expected_partitions",
+    [
+        pytest.param(YearTransform(), {53, 54, None}, id="year_transform"),
+        pytest.param(MonthTransform(), {647, 648, 649, None}, 
id="month_transform"),
+        pytest.param(
+            DayTransform(), {date(2023, 12, 31), date(2024, 1, 1), date(2024, 
1, 31), date(2024, 2, 1), None}, id="day_transform"
+        ),
+        pytest.param(HourTransform(), {473328, 473352, 474072, 474096, 474102, 
None}, id="hour_transform"),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_transform_partition_verify_partitions_count(
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_date_timestamps: pa.Table,
+    arrow_table_date_timestamps_schema: Schema,
+    transform: Transform[Any, Any],
+    expected_partitions: Set[Any],
+    format_version: int,
+) -> None:
+    # Given
+    part_col = "timestamptz"
+    identifier = 
f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
+    nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=transform, name=part_col),
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_date_timestamps],
+        partition_spec=partition_spec,
+        schema=arrow_table_date_timestamps_schema,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+    for col in arrow_table_date_timestamps.column_names:
+        assert df.where(f"{col} is not null").count() == 5, f"Expected 2 
non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row 
for {col} is null"
+
+    partitions_table = tbl.inspect.partitions()
+    assert partitions_table.num_rows == len(expected_partitions)
+    assert {part[part_col] for part in 
partitions_table["partition"].to_pylist()} == expected_partitions
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == len(expected_partitions)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_multiple_partitions(
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_date_timestamps: pa.Table,
+    arrow_table_date_timestamps_schema: Schema,
+    format_version: int,
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v{format_version}_with_multiple_partitions"
+    partition_spec = PartitionSpec(
+        PartitionField(
+            
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            field_id=1001,
+            transform=YearTransform(),
+            name="date_year",
+        ),
+        PartitionField(
+            
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            field_id=1000,
+            transform=HourTransform(),
+            name="timestamptz_hour",
+        ),
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_date_timestamps],
+        partition_spec=partition_spec,
+        schema=arrow_table_date_timestamps_schema,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+    for col in arrow_table_date_timestamps.column_names:
+        assert df.where(f"{col} is not null").count() == 5, f"Expected 2 
non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row 
for {col} is null"
+
+    partitions_table = tbl.inspect.partitions()
+    assert partitions_table.num_rows == 6
+    partitions = partitions_table["partition"].to_pylist()
+    assert {(part["date_year"], part["timestamptz_hour"]) for part in 
partitions} == {
+        (53, 473328),
+        (54, 473352),
+        (54, 474072),
+        (54, 474096),
+        (54, 474102),
+        (None, None),
+    }
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 6
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index b8bef4b9..3a9ffd60 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -17,7 +17,7 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 from uuid import UUID
 
 import mmh3 as mmh3
@@ -69,6 +69,7 @@ from pyiceberg.expressions.literals import (
     TimestampLiteral,
     literal,
 )
+from pyiceberg.partitioning import _to_partition_representation
 from pyiceberg.schema import Accessor
 from pyiceberg.transforms import (
     BucketTransform,
@@ -111,6 +112,9 @@ from pyiceberg.utils.datetime import (
     timestamptz_to_micros,
 )
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 @pytest.mark.parametrize(
     "test_input,test_type,expected",
@@ -1808,3 +1812,31 @@ def test_strict_binary(bound_reference_binary: 
BoundReference[str]) -> None:
     _test_projection(
         lhs=transform.strict_project(name="name", 
pred=BoundIn(term=bound_reference_binary, literals=set_of_literals)), rhs=None
     )
+
+
+@pytest.mark.parametrize(
+    "transform",
+    [
+        pytest.param(YearTransform(), id="year_transform"),
+        pytest.param(MonthTransform(), id="month_transform"),
+        pytest.param(DayTransform(), id="day_transform"),
+        pytest.param(HourTransform(), id="hour_transform"),
+    ],
+)
+@pytest.mark.parametrize(
+    "source_col, source_type", [("date", DateType()), ("timestamp", 
TimestampType()), ("timestamptz", TimestamptzType())]
+)
+def test_ymd_pyarrow_transforms(
+    arrow_table_date_timestamps: "pa.Table",
+    source_col: str,
+    source_type: PrimitiveType,
+    transform: Transform[Any, Any],
+) -> None:
+    if transform.can_transform(source_type):
+        assert 
transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist()
 == [
+            
transform.transform(source_type)(_to_partition_representation(source_type, v))
+            for v in arrow_table_date_timestamps[source_col].to_pylist()
+        ]
+    else:
+        with pytest.raises(ValueError):
+            
transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])

Reply via email to