This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 50c33aa0 feat: Support Bucket and Truncate transforms on write (#1345)
50c33aa0 is described below

commit 50c33aa0119d9e2478b3865d864ec23a7c45b1d7
Author: Sung Yun <107272191+sun...@users.noreply.github.com>
AuthorDate: Thu Jan 16 10:54:37 2025 -0500

    feat: Support Bucket and Truncate transforms on write (#1345)
    
    * introduce bucket transform
    
    * include pyiceberg-core
    
    * introduce bucket transform
    
    * include pyiceberg-core
    
    * resolve poetry conflict
    
    * support truncate transforms
    
    * Remove stale comment
    
    * fix poetry hash
    
    * avoid codespell error for truncate transform
    
    * adopt nits
---
 poetry.lock                                        |  18 ++-
 pyiceberg/transforms.py                            |  39 ++++-
 pyproject.toml                                     |   6 +
 .../test_writes/test_partitioned_writes.py         | 170 +++++++++++++++++++--
 tests/test_transforms.py                           |  46 +++++-
 5 files changed, 259 insertions(+), 20 deletions(-)
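
For readers following along, below is a minimal, illustrative sketch of the write path this commit enables: appending a PyArrow table to an Iceberg table partitioned by bucket and truncate transforms. It is not part of the commit; the catalog name, table identifier, and schema are placeholders, and it assumes the new optional extra is installed (pip install "pyiceberg[pyiceberg-core]") together with pyarrow.

    # Illustrative only -- catalog config, identifier, and schema are placeholders.
    import pyarrow as pa

    from pyiceberg.catalog import load_catalog
    from pyiceberg.partitioning import PartitionField, PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.transforms import BucketTransform, TruncateTransform
    from pyiceberg.types import LongType, NestedField, StringType

    catalog = load_catalog("default")  # assumes a catalog configured in .pyiceberg.yaml

    schema = Schema(
        NestedField(field_id=1, name="id", field_type=LongType(), required=False),
        NestedField(field_id=2, name="name", field_type=StringType(), required=False),
    )
    # Partition by a 16-way bucket on "id" and a 3-character truncate on "name".
    spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1001, transform=BucketTransform(num_buckets=16), name="id_bucket"),
        PartitionField(source_id=2, field_id=1002, transform=TruncateTransform(width=3), name="name_trunc"),
    )

    tbl = catalog.create_table("default.bucketed_demo", schema=schema, partition_spec=spec)
    tbl.append(pa.table({"id": [1, 2, 3], "name": ["alice", "bob", "carol"]}))  # now succeeds on the write path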

diff --git a/poetry.lock b/poetry.lock
index 1d17ba6b..1c94a5f2 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3717,6 +3717,21 @@ files = [
 [package.extras]
 windows-terminal = ["colorama (>=0.4.6)"]
 
+[[package]]
+name = "pyiceberg-core"
+version = "0.4.0"
+description = ""
+optional = true
+python-versions = "*"
+files = [
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5aec569271c96e18428d542f9b7007117a7232c06017f95cb239d42e952ad3b4"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e74773e58efa4df83aba6f6265cdd41e446fa66fa4e343ca86395fed9f209ae"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7675d21a54bf3753c740d8df78ad7efe33f438096844e479d4f3493f84830925"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7058ad935a40b1838e4cdc5febd768878c1a51f83dca005d5a52a7fa280a2489"},
+    {file = "pyiceberg_core-0.4.0-cp39-abi3-win_amd64.whl", hash = "sha256:a83eb4c2307ae3dd321a9360828fb043a4add2cc9797bef0bafa20894488fb07"},
+    {file = "pyiceberg_core-0.4.0.tar.gz", hash = "sha256:d2e6138707868477b806ed354aee9c476e437913a331cb9ad9ad46b4054cd11f"},
+]
+
 [[package]]
 name = "pyjwt"
 version = "2.10.1"
@@ -5346,6 +5361,7 @@ glue = ["boto3", "mypy-boto3-glue"]
 hive = ["thrift"]
 pandas = ["pandas", "pyarrow"]
 pyarrow = ["pyarrow"]
+pyiceberg-core = ["pyiceberg-core"]
 ray = ["pandas", "pyarrow", "ray", "ray"]
 rest-sigv4 = ["boto3"]
 s3fs = ["s3fs"]
@@ -5357,4 +5373,4 @@ zstandard = ["zstandard"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9, !=3.9.7"
-content-hash = "306213628bcc69346e14742843c8e6bccf19c2615886943c2e1482a954a388ec"
+content-hash = "cc789ef423714710f51e5452de7071642f4512511b1d205f77b952bb1df63a64"
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 84e1c942..22dcdfe8 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -85,6 +85,8 @@ from pyiceberg.utils.singleton import Singleton
 if TYPE_CHECKING:
     import pyarrow as pa
 
+    ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
+
 S = TypeVar("S")
 T = TypeVar("T")
 
@@ -193,6 +195,27 @@ class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
     @abstractmethod
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
 
+    def _pyiceberg_transform_wrapper(
+        self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
+    ) -> Callable[["ArrayLike"], "ArrayLike"]:
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e
+
+        def _transform(array: "ArrayLike") -> "ArrayLike":
+            if isinstance(array, pa.Array):
+                return transform_func(array, *args)
+            elif isinstance(array, pa.ChunkedArray):
+                result_chunks = []
+                for arr in array.iterchunks():
+                    result_chunks.append(transform_func(arr, *args))
+                return pa.chunked_array(result_chunks)
+            else:
+                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
+
+        return _transform
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -309,7 +332,13 @@ class BucketTransform(Transform[S, int]):
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 class TimeResolution(IntEnum):
@@ -827,7 +856,13 @@ class TruncateTransform(Transform[S, S]):
         return f"TruncateTransform(width={self._width})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 @singledispatch
diff --git a/pyproject.toml b/pyproject.toml
index 4b425141..5d2808db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
 sqlalchemy = { version = "^2.0.18", optional = true }
 getdaft = { version = ">=0.2.12", optional = true }
 cachetools = "^5.5.0"
+pyiceberg-core = { version = "^0.4.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pytest = "7.4.4"
@@ -842,6 +843,10 @@ ignore_missing_imports = true
 module = "daft.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "pyiceberg_core.*"
+ignore_missing_imports = true
+
 [[tool.mypy.overrides]]
 module = "pyparsing.*"
 ignore_missing_imports = true
@@ -1206,6 +1211,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"]
 sql-sqlite = ["sqlalchemy"]
 gcsfs = ["gcsfs"]
 rest-sigv4 = ["boto3"]
+pyiceberg-core = ["pyiceberg-core"]
 
 [tool.pytest.ini_options]
 markers = [
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 9e763285..1e6ea1b7 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
 ) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,
@@ -756,6 +762,55 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         tbl.append("not a df")
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_truncate_transform(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.truncate_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "spec",
@@ -767,18 +822,52 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
                 PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
             )
         ),
-        # none of non-identity is supported
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
-        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_identity_and_bucket_transform_spec(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.identity_and_bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
     ],
 )
@@ -801,11 +890,66 @@ def test_unsupported_transform(
 
     with pytest.raises(
         ValueError,
-        match="Not all partition types are supported for writes. Following 
partitions cannot be written using pyarrow: *",
+        match="FeatureUnsupported => Unsupported data type for truncate 
transform: LargeBinary",
     ):
         tbl.append(arrow_table_with_null)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec, expected_rows",
+    [
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_bucket_transform(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    spec: PartitionSpec,
+    expected_rows: int,
+    format_version: int,
+) -> None:
+    identifier = "default.bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "transform,expected_rows",
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 6d04a1e4..3088719a 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -18,10 +18,11 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 from uuid import UUID
 
 import mmh3 as mmh3
+import pyarrow as pa
 import pytest
 from pydantic import (
     BeforeValidator,
@@ -116,9 +117,6 @@ from pyiceberg.utils.datetime import (
     timestamptz_to_micros,
 )
 
-if TYPE_CHECKING:
-    import pyarrow as pa
-
 
 @pytest.mark.parametrize(
     "test_input,test_type,expected",
@@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms(
     else:
         with pytest.raises(ValueError):
             transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, num_buckets",
+    [
+        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+        (
+            IntegerType(),
+            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
+            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
+            10,
+        ),
+        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+    ],
+)
+def test_bucket_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    num_buckets: int,
+) -> None:
+    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, width",
+    [
+        (StringType(), pa.array(["developer", "iceberg"]), pa.array(["dev", "ice"]), 3),
+        (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
+    ],
+)
+def test_truncate_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    width: int,
+) -> None:
+    transform: Transform[Any, Any] = TruncateTransform(width=width)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)
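
As a quick illustration of the new API surface, here is a sketch of calling pyarrow_transform directly. The input/output values mirror the unit tests added above; it assumes pyarrow and the pyiceberg-core extra are installed.

    # Values taken from the new unit tests in tests/test_transforms.py.
    import pyarrow as pa

    from pyiceberg.transforms import BucketTransform, TruncateTransform
    from pyiceberg.types import IntegerType, StringType

    # Bucket integers into 10 buckets; returns an int32 array of bucket ids.
    bucket = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
    print(bucket(pa.array([1, 2])))  # -> [6, 2]

    # Truncate strings to a width of 3 characters.
    truncate = TruncateTransform(width=3).pyarrow_transform(StringType())
    print(truncate(pa.array(["developer", "iceberg"])))  # -> ["dev", "ice"]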
