This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 65a03d26 Support Appends with TimeTransform Partitions (#784)
65a03d26 is described below
commit 65a03d2667ac073778b03d99d6580149a2abb326
Author: Sung Yun <[email protected]>
AuthorDate: Fri May 31 16:11:35 2024 -0400
Support Appends with TimeTransform Partitions (#784)
* checkpoint
* checkpoint2
* todo: sort with pyarrow_transform vals
* checkpoint
* checkpoint
* fix
* tests
* more tests
* adopt review feedback
* comment
* checkpoint
* checkpoint2
* todo: sort with pyarrow_transform vals
* checkpoint
* checkpoint
* fix
* tests
* more tests
* adopt review feedback
* comment
* rebase
---
pyiceberg/partitioning.py | 2 +-
pyiceberg/table/__init__.py | 67 +++----
pyiceberg/transforms.py | 99 +++++++++-
tests/conftest.py | 43 +++++
.../test_writes/test_partitioned_writes.py | 201 +++++++++++++++++++--
tests/test_transforms.py | 34 +++-
6 files changed, 392 insertions(+), 54 deletions(-)
diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index 481207db..da52d5df 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -387,7 +387,7 @@ class PartitionKey:
for raw_partition_field_value in self.raw_partition_field_values:
partition_fields =
self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
- raise ValueError("partition_fields must contain exactly one
field.")
+ raise ValueError(f"Cannot have redundant partitions:
{partition_fields}")
partition_field = partition_fields[0]
iceberg_typed_key_values[partition_field.name] =
partition_record_value(
partition_field=partition_field,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index aa108de0..f160ab24 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -392,10 +392,11 @@ class Transaction:
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")
- supported_transforms = {IdentityTransform}
- if not all(type(field.transform) in supported_transforms for field in
self.table_metadata.spec().fields):
+ if unsupported_partitions := [
+ field for field in self.table_metadata.spec().fields if not
field.transform.supports_pyarrow_transform
+ ]:
raise ValueError(
- f"All transforms are not supported, expected:
{supported_transforms}, but get: {[str(field) for field in
self.table_metadata.spec().fields if field.transform not in
supported_transforms]}."
+ f"Not all partition types are supported for writes. Following
partitions cannot be written using pyarrow: {unsupported_partitions}."
)
_check_schema_compatible(self._table.schema(), other_schema=df.schema)
@@ -3643,33 +3644,6 @@ class TablePartition:
arrow_table_partition: pa.Table
-def _get_partition_sort_order(partition_columns: list[str], reverse: bool =
False) -> dict[str, Any]:
- order = "ascending" if not reverse else "descending"
- null_placement = "at_start" if reverse else "at_end"
- return {"sort_keys": [(column_name, order) for column_name in
partition_columns], "null_placement": null_placement}
-
-
-def group_by_partition_scheme(arrow_table: pa.Table, partition_columns:
list[str]) -> pa.Table:
- """Given a table, sort it by current partition scheme."""
- # only works for identity for now
- sort_options = _get_partition_sort_order(partition_columns, reverse=False)
- sorted_arrow_table =
arrow_table.sort_by(sorting=sort_options["sort_keys"],
null_placement=sort_options["null_placement"])
- return sorted_arrow_table
-
-
-def get_partition_columns(
- spec: PartitionSpec,
- schema: Schema,
-) -> list[str]:
- partition_cols = []
- for partition_field in spec.fields:
- column_name = schema.find_column_name(partition_field.source_id)
- if not column_name:
- raise ValueError(f"{partition_field=} could not be found in
{schema}.")
- partition_cols.append(column_name)
- return partition_cols
-
-
def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
@@ -3724,13 +3698,30 @@ def _determine_partitions(spec: PartitionSpec, schema:
Schema, arrow_table: pa.T
"""
import pyarrow as pa
- partition_columns = get_partition_columns(spec=spec, schema=schema)
- arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
-
- reversing_sort_order_options =
_get_partition_sort_order(partition_columns, reverse=True)
- reversed_indices = pa.compute.sort_indices(arrow_table,
**reversing_sort_order_options).to_pylist()
-
- slice_instructions: list[dict[str, Any]] = []
+ partition_columns: List[Tuple[PartitionField, NestedField]] = [
+ (partition_field, schema.find_field(partition_field.source_id)) for
partition_field in spec.fields
+ ]
+ partition_values_table = pa.table({
+ str(partition.field_id):
partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+ for partition, field in partition_columns
+ })
+
+ # Sort by partitions
+ sort_indices = pa.compute.sort_indices(
+ partition_values_table,
+ sort_keys=[(col, "ascending") for col in
partition_values_table.column_names],
+ null_placement="at_end",
+ ).to_pylist()
+ arrow_table = arrow_table.take(sort_indices)
+
+ # Get slice_instructions to group by partitions
+ partition_values_table = partition_values_table.take(sort_indices)
+ reversed_indices = pa.compute.sort_indices(
+ partition_values_table,
+ sort_keys=[(col, "descending") for col in
partition_values_table.column_names],
+ null_placement="at_start",
+ ).to_pylist()
+ slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
@@ -3741,6 +3732,6 @@ def _determine_partitions(spec: PartitionSpec, schema:
Schema, arrow_table: pa.T
last = reversed_indices[ptr]
ptr = ptr + group_size
- table_partitions: list[TablePartition] =
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
+ table_partitions: List[TablePartition] =
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
return table_partitions
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 6dcae59e..38cc6221 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -20,7 +20,7 @@ import struct
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import singledispatch
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
from typing import Literal as LiteralType
from uuid import UUID
@@ -82,6 +82,9 @@ from pyiceberg.utils.decimal import decimal_to_bytes,
truncate_decimal
from pyiceberg.utils.parsing import ParseNumberFromBrackets
from pyiceberg.utils.singleton import Singleton
+if TYPE_CHECKING:
+ import pyarrow as pa
+
S = TypeVar("S")
T = TypeVar("T")
@@ -175,6 +178,13 @@ class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
return self.root == other.root
return False
+ @property
+ def supports_pyarrow_transform(self) -> bool:
+ return False
+
+ @abstractmethod
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]": ...
+
class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
@@ -290,6 +300,9 @@ class BucketTransform(Transform[S, int]):
"""Return the string representation of the BucketTransform class."""
return f"BucketTransform(num_buckets={self._num_buckets})"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ raise NotImplementedError()
+
class TimeResolution(IntEnum):
YEAR = 6
@@ -349,6 +362,10 @@ class TimeTransform(Transform[S, int], Generic[S],
Singleton):
def preserves_order(self) -> bool:
return True
+ @property
+ def supports_pyarrow_transform(self) -> bool:
+ return True
+
class YearTransform(TimeTransform[S]):
"""Transforms a datetime value into a year value.
@@ -391,6 +408,21 @@ class YearTransform(TimeTransform[S]):
"""Return the string representation of the YearTransform class."""
return "YearTransform()"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ if isinstance(source, DateType):
+ epoch = datetime.EPOCH_DATE
+ elif isinstance(source, TimestampType):
+ epoch = datetime.EPOCH_TIMESTAMP
+ elif isinstance(source, TimestamptzType):
+ epoch = datetime.EPOCH_TIMESTAMPTZ
+ else:
+ raise ValueError(f"Cannot apply year transform for type: {source}")
+
+ return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not
None else None
+
class MonthTransform(TimeTransform[S]):
"""Transforms a datetime value into a month value.
@@ -433,6 +465,27 @@ class MonthTransform(TimeTransform[S]):
"""Return the string representation of the MonthTransform class."""
return "MonthTransform()"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ if isinstance(source, DateType):
+ epoch = datetime.EPOCH_DATE
+ elif isinstance(source, TimestampType):
+ epoch = datetime.EPOCH_TIMESTAMP
+ elif isinstance(source, TimestamptzType):
+ epoch = datetime.EPOCH_TIMESTAMPTZ
+ else:
+ raise ValueError(f"Cannot apply month transform for type:
{source}")
+
+ def month_func(v: pa.Array) -> pa.Array:
+ return pc.add(
+ pc.multiply(pc.years_between(pa.scalar(epoch), v),
pa.scalar(12)),
+ pc.add(pc.month(v), pa.scalar(-1)),
+ )
+
+ return lambda v: month_func(v) if v is not None else None
+
class DayTransform(TimeTransform[S]):
"""Transforms a datetime value into a day value.
@@ -478,6 +531,21 @@ class DayTransform(TimeTransform[S]):
"""Return the string representation of the DayTransform class."""
return "DayTransform()"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ if isinstance(source, DateType):
+ epoch = datetime.EPOCH_DATE
+ elif isinstance(source, TimestampType):
+ epoch = datetime.EPOCH_TIMESTAMP
+ elif isinstance(source, TimestamptzType):
+ epoch = datetime.EPOCH_TIMESTAMPTZ
+ else:
+ raise ValueError(f"Cannot apply day transform for type: {source}")
+
+ return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None
else None
+
class HourTransform(TimeTransform[S]):
"""Transforms a datetime value into a hour value.
@@ -515,6 +583,19 @@ class HourTransform(TimeTransform[S]):
"""Return the string representation of the HourTransform class."""
return "HourTransform()"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ import pyarrow as pa
+ import pyarrow.compute as pc
+
+ if isinstance(source, TimestampType):
+ epoch = datetime.EPOCH_TIMESTAMP
+ elif isinstance(source, TimestamptzType):
+ epoch = datetime.EPOCH_TIMESTAMPTZ
+ else:
+ raise ValueError(f"Cannot apply hour transform for type: {source}")
+
+ return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not
None else None
+
def _base64encode(buffer: bytes) -> str:
"""Convert bytes to base64 string."""
@@ -585,6 +666,13 @@ class IdentityTransform(Transform[S, S]):
"""Return the string representation of the IdentityTransform class."""
return "IdentityTransform()"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ return lambda v: v
+
+ @property
+ def supports_pyarrow_transform(self) -> bool:
+ return True
+
class TruncateTransform(Transform[S, S]):
"""A transform for truncating a value to a specified width.
@@ -725,6 +813,9 @@ class TruncateTransform(Transform[S, S]):
"""Return the string representation of the TruncateTransform class."""
return f"TruncateTransform(width={self._width})"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ raise NotImplementedError()
+
@singledispatch
def _human_string(value: Any, _type: IcebergType) -> str:
@@ -807,6 +898,9 @@ class UnknownTransform(Transform[S, T]):
"""Return the string representation of the UnknownTransform class."""
return f"UnknownTransform(transform={repr(self._transform)})"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ raise NotImplementedError()
+
class VoidTransform(Transform[S, None], Singleton):
"""A transform that always returns None."""
@@ -835,6 +929,9 @@ class VoidTransform(Transform[S, None], Singleton):
"""Return the string representation of the VoidTransform class."""
return "VoidTransform()"
+ def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array],
pa.Array]":
+ raise NotImplementedError()
+
def _truncate_number(
name: str, pred: BoundLiteralPredicate[L], transform:
Callable[[Optional[L]], Optional[L]]
diff --git a/tests/conftest.py b/tests/conftest.py
index 01915b7d..d3f23689 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema")
-> "pa.Table":
import pyarrow as pa
return pa.Table.from_pylist([{}, {}], schema=pa_schema)
+
+
[email protected](scope="session")
+def arrow_table_date_timestamps() -> "pa.Table":
+ """Pyarrow table with only date, timestamp and timestamptz values."""
+ import pyarrow as pa
+
+ return pa.Table.from_pydict(
+ {
+ "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31),
date(2024, 2, 1), date(2024, 2, 1), None],
+ "timestamp": [
+ datetime(2023, 12, 31, 0, 0, 0),
+ datetime(2024, 1, 1, 0, 0, 0),
+ datetime(2024, 1, 31, 0, 0, 0),
+ datetime(2024, 2, 1, 0, 0, 0),
+ datetime(2024, 2, 1, 6, 0, 0),
+ None,
+ ],
+ "timestamptz": [
+ datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
+ datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+ datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
+ datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
+ datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
+ None,
+ ],
+ },
+ schema=pa.schema([
+ ("date", pa.date32()),
+ ("timestamp", pa.timestamp(unit="us")),
+ ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
+ ]),
+ )
+
+
[email protected](scope="session")
+def arrow_table_date_timestamps_schema() -> Schema:
+ """Pyarrow table Schema with only date, timestamp and timestamptz
values."""
+ return Schema(
+ NestedField(field_id=1, name="date", field_type=DateType(),
required=False),
+ NestedField(field_id=2, name="timestamp", field_type=TimestampType(),
required=False),
+ NestedField(field_id=3, name="timestamptz",
field_type=TimestamptzType(), required=False),
+ )
diff --git a/tests/integration/test_writes/test_partitioned_writes.py
b/tests/integration/test_writes/test_partitioned_writes.py
index 5cb03e59..76d559ca 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -16,6 +16,10 @@
# under the License.
# pylint:disable=redefined-outer-name
+
+from datetime import date
+from typing import Any, Set
+
import pyarrow as pa
import pytest
from pyspark.sql import SparkSession
@@ -23,12 +27,14 @@ from pyspark.sql import SparkSession
from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
HourTransform,
IdentityTransform,
MonthTransform,
+ Transform,
TruncateTransform,
YearTransform,
)
@@ -351,18 +357,6 @@ def test_invalid_arguments(spark: SparkSession,
session_catalog: Catalog) -> Non
(PartitionSpec(PartitionField(source_id=5, field_id=1001,
transform=TruncateTransform(2), name="long_trunc"))),
(PartitionSpec(PartitionField(source_id=2, field_id=1001,
transform=TruncateTransform(2), name="string_trunc"))),
(PartitionSpec(PartitionField(source_id=11, field_id=1001,
transform=TruncateTransform(2), name="binary_trunc"))),
- (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=YearTransform(), name="timestamp_year"))),
- (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=YearTransform(), name="timestamptz_year"))),
- (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=YearTransform(), name="date_year"))),
- (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=MonthTransform(), name="timestamp_month"))),
- (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=MonthTransform(), name="timestamptz_month"))),
- (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=MonthTransform(), name="date_month"))),
- (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=DayTransform(), name="timestamp_day"))),
- (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=DayTransform(), name="timestamptz_day"))),
- (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=DayTransform(), name="date_day"))),
- (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=HourTransform(), name="timestamp_hour"))),
- (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=HourTransform(), name="timestamptz_hour"))),
- (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=HourTransform(), name="date_hour"))),
],
)
def test_unsupported_transform(
@@ -382,5 +376,186 @@ def test_unsupported_transform(
properties={"format-version": "1"},
)
- with pytest.raises(ValueError, match="All transforms are not supported.*"):
+ with pytest.raises(
+ ValueError,
+ match="Not all partition types are supported for writes. Following
partitions cannot be written using pyarrow: *",
+ ):
tbl.append(arrow_table_with_null)
+
+
[email protected]
[email protected](
+ "transform,expected_rows",
+ [
+ pytest.param(YearTransform(), 2, id="year_transform"),
+ pytest.param(MonthTransform(), 3, id="month_transform"),
+ pytest.param(DayTransform(), 3, id="day_transform"),
+ ],
+)
[email protected]("part_col", ["date", "timestamp", "timestamptz"])
[email protected]("format_version", [1, 2])
+def test_append_ymd_transform_partitioned(
+ session_catalog: Catalog,
+ spark: SparkSession,
+ arrow_table_with_null: pa.Table,
+ transform: Transform[Any, Any],
+ expected_rows: int,
+ part_col: str,
+ format_version: int,
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
+ nested_field = TABLE_SCHEMA.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=transform, name=part_col)
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[arrow_table_with_null],
+ partition_spec=partition_spec,
+ )
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+ for col in TEST_DATA_WITH_NULL.keys():
+ assert df.where(f"{col} is not null").count() == 2, f"Expected 2
non-null rows for {col}"
+ assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row
for {col} is null"
+
+ assert tbl.inspect.partitions().num_rows == expected_rows
+ files_df = spark.sql(
+ f"""
+ SELECT *
+ FROM {identifier}.files
+ """
+ )
+ assert files_df.count() == expected_rows
+
+
[email protected]
[email protected](
+ "transform,expected_partitions",
+ [
+ pytest.param(YearTransform(), {53, 54, None}, id="year_transform"),
+ pytest.param(MonthTransform(), {647, 648, 649, None},
id="month_transform"),
+ pytest.param(
+ DayTransform(), {date(2023, 12, 31), date(2024, 1, 1), date(2024,
1, 31), date(2024, 2, 1), None}, id="day_transform"
+ ),
+ pytest.param(HourTransform(), {473328, 473352, 474072, 474096, 474102,
None}, id="hour_transform"),
+ ],
+)
[email protected]("format_version", [1, 2])
+def test_append_transform_partition_verify_partitions_count(
+ session_catalog: Catalog,
+ spark: SparkSession,
+ arrow_table_date_timestamps: pa.Table,
+ arrow_table_date_timestamps_schema: Schema,
+ transform: Transform[Any, Any],
+ expected_partitions: Set[Any],
+ format_version: int,
+) -> None:
+ # Given
+ part_col = "timestamptz"
+ identifier =
f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
+ nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=transform, name=part_col),
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[arrow_table_date_timestamps],
+ partition_spec=partition_spec,
+ schema=arrow_table_date_timestamps_schema,
+ )
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+ for col in arrow_table_date_timestamps.column_names:
+ assert df.where(f"{col} is not null").count() == 5, f"Expected 2
non-null rows for {col}"
+ assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row
for {col} is null"
+
+ partitions_table = tbl.inspect.partitions()
+ assert partitions_table.num_rows == len(expected_partitions)
+ assert {part[part_col] for part in
partitions_table["partition"].to_pylist()} == expected_partitions
+ files_df = spark.sql(
+ f"""
+ SELECT *
+ FROM {identifier}.files
+ """
+ )
+ assert files_df.count() == len(expected_partitions)
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_append_multiple_partitions(
+ session_catalog: Catalog,
+ spark: SparkSession,
+ arrow_table_date_timestamps: pa.Table,
+ arrow_table_date_timestamps_schema: Schema,
+ format_version: int,
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v{format_version}_with_multiple_partitions"
+ partition_spec = PartitionSpec(
+ PartitionField(
+
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+ field_id=1001,
+ transform=YearTransform(),
+ name="date_year",
+ ),
+ PartitionField(
+
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+ field_id=1000,
+ transform=HourTransform(),
+ name="timestamptz_hour",
+ ),
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[arrow_table_date_timestamps],
+ partition_spec=partition_spec,
+ schema=arrow_table_date_timestamps_schema,
+ )
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+ for col in arrow_table_date_timestamps.column_names:
+ assert df.where(f"{col} is not null").count() == 5, f"Expected 2
non-null rows for {col}"
+ assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row
for {col} is null"
+
+ partitions_table = tbl.inspect.partitions()
+ assert partitions_table.num_rows == 6
+ partitions = partitions_table["partition"].to_pylist()
+ assert {(part["date_year"], part["timestamptz_hour"]) for part in
partitions} == {
+ (53, 473328),
+ (54, 473352),
+ (54, 474072),
+ (54, 474096),
+ (54, 474102),
+ (None, None),
+ }
+ files_df = spark.sql(
+ f"""
+ SELECT *
+ FROM {identifier}.files
+ """
+ )
+ assert files_df.count() == 6
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index b8bef4b9..3a9ffd60 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -17,7 +17,7 @@
# pylint: disable=eval-used,protected-access,redefined-outer-name
from datetime import date
from decimal import Decimal
-from typing import Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
from uuid import UUID
import mmh3 as mmh3
@@ -69,6 +69,7 @@ from pyiceberg.expressions.literals import (
TimestampLiteral,
literal,
)
+from pyiceberg.partitioning import _to_partition_representation
from pyiceberg.schema import Accessor
from pyiceberg.transforms import (
BucketTransform,
@@ -111,6 +112,9 @@ from pyiceberg.utils.datetime import (
timestamptz_to_micros,
)
+if TYPE_CHECKING:
+ import pyarrow as pa
+
@pytest.mark.parametrize(
"test_input,test_type,expected",
@@ -1808,3 +1812,31 @@ def test_strict_binary(bound_reference_binary:
BoundReference[str]) -> None:
_test_projection(
lhs=transform.strict_project(name="name",
pred=BoundIn(term=bound_reference_binary, literals=set_of_literals)), rhs=None
)
+
+
[email protected](
+ "transform",
+ [
+ pytest.param(YearTransform(), id="year_transform"),
+ pytest.param(MonthTransform(), id="month_transform"),
+ pytest.param(DayTransform(), id="day_transform"),
+ pytest.param(HourTransform(), id="hour_transform"),
+ ],
+)
[email protected](
+ "source_col, source_type", [("date", DateType()), ("timestamp",
TimestampType()), ("timestamptz", TimestamptzType())]
+)
+def test_ymd_pyarrow_transforms(
+ arrow_table_date_timestamps: "pa.Table",
+ source_col: str,
+ source_type: PrimitiveType,
+ transform: Transform[Any, Any],
+) -> None:
+ if transform.can_transform(source_type):
+ assert
transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist()
== [
+
transform.transform(source_type)(_to_partition_representation(source_type, v))
+ for v in arrow_table_date_timestamps[source_col].to_pylist()
+ ]
+ else:
+ with pytest.raises(ValueError):
+
transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])