This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 4148edb5 Partitioned Append on Identity Transform (#555)
4148edb5 is described below

commit 4148edb5e28ae88024a55e0b112238e65b873957
Author: Adrian Qin <[email protected]>
AuthorDate: Fri Apr 5 03:52:16 2024 -0400

    Partitioned Append on Identity Transform (#555)
    
    * partitioned append on identity transform
    
    * remove unnecessary fixture
    
    * added null/empty table tests; fixed part of PR comments
    
    * tests for unsupported transforms; unit tests for partition slicing algorithm
    
    * add a comprehensive partition unit test
    
    * clean up
    
    * move common fixtures utils to utils.py and conftest
    
    * pull partitioned table fixtures into tests for more real-time feedback of running tests
    
    * fix linting
    
    * license
    
    * save changes for switching codespaces
    
    * part of the comment fixes
    
    * fix one type error
    
    * add support for timetype
    
    * small fix for type hint
---
 pyiceberg/io/pyarrow.py                            |   4 +-
 pyiceberg/manifest.py                              |  27 +-
 pyiceberg/partitioning.py                          |  10 +-
 pyiceberg/table/__init__.py                        | 174 +++++++++-
 pyiceberg/typedef.py                               |   4 +
 tests/conftest.py                                  |  29 +-
 tests/integration/test_partitioning_key.py         |   2 +-
 .../test_writes/test_partitioned_writes.py         | 386 +++++++++++++++++++++
 tests/integration/{ => test_writes}/test_writes.py |  99 +-----
 tests/integration/test_writes/utils.py             |  85 +++++
 tests/table/test_init.py                           |  89 ++++-
 11 files changed, 763 insertions(+), 146 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 738cd77b..06d03e21 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1772,7 +1772,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, 
tasks: Iterator[WriteT
     )
 
     def write_parquet(task: WriteTask) -> DataFile:
-        file_path = 
f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
+        file_path = 
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
             with pq.ParquetWriter(fos, schema=arrow_file_schema, 
**parquet_writer_kwargs) as writer:
@@ -1787,7 +1787,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, 
tasks: Iterator[WriteT
             content=DataFileContent.DATA,
             file_path=file_path,
             file_format=FileFormat.PARQUET,
-            partition=Record(),
+            partition=task.partition_key.partition if task.partition_key else 
Record(),
             file_size_in_bytes=len(fo),
             # After this has been fixed:
             # https://github.com/apache/iceberg-python/issues/271
diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py
index 5277eed9..3b8138b6 100644
--- a/pyiceberg/manifest.py
+++ b/pyiceberg/manifest.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 import math
 from abc import ABC, abstractmethod
 from enum import Enum
-from functools import singledispatch
 from types import TracebackType
 from typing import (
     Any,
@@ -41,8 +40,6 @@ from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
-    DateType,
-    IcebergType,
     IntegerType,
     ListType,
     LongType,
@@ -51,9 +48,6 @@ from pyiceberg.types import (
     PrimitiveType,
     StringType,
     StructType,
-    TimestampType,
-    TimestamptzType,
-    TimeType,
 )
 
 UNASSIGNED_SEQ = -1
@@ -283,31 +277,12 @@ DATA_FILE_TYPE: Dict[int, StructType] = {
 }
 
 
-@singledispatch
-def partition_field_to_data_file_partition_field(partition_field_type: 
IcebergType) -> PrimitiveType:
-    raise TypeError(f"Unsupported partition field type: 
{partition_field_type}")
-
-
-@partition_field_to_data_file_partition_field.register(LongType)
-@partition_field_to_data_file_partition_field.register(DateType)
-@partition_field_to_data_file_partition_field.register(TimeType)
-@partition_field_to_data_file_partition_field.register(TimestampType)
-@partition_field_to_data_file_partition_field.register(TimestamptzType)
-def _(partition_field_type: PrimitiveType) -> IntegerType:
-    return IntegerType()
-
-
-@partition_field_to_data_file_partition_field.register(PrimitiveType)
-def _(partition_field_type: PrimitiveType) -> PrimitiveType:
-    return partition_field_type
-
-
 def data_file_with_partition(partition_type: StructType, format_version: 
TableVersion) -> StructType:
     data_file_partition_type = StructType(*[
         NestedField(
             field_id=field.field_id,
             name=field.name,
-            
field_type=partition_field_to_data_file_partition_field(field.field_type),
+            field_type=field.field_type,
             required=field.required,
         )
         for field in partition_type.fields
diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index 16f15882..a3cf2553 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import uuid
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from datetime import date, datetime
+from datetime import date, datetime, time
 from functools import cached_property, singledispatch
 from typing import (
     Any,
@@ -62,9 +62,10 @@ from pyiceberg.types import (
     StructType,
     TimestampType,
     TimestamptzType,
+    TimeType,
     UUIDType,
 )
-from pyiceberg.utils.datetime import date_to_days, datetime_to_micros
+from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, 
time_to_micros
 
 INITIAL_PARTITION_SPEC_ID = 0
 PARTITION_FIELD_ID_START: int = 1000
@@ -431,6 +432,11 @@ def _(type: IcebergType, value: Optional[date]) -> 
Optional[int]:
     return date_to_days(value) if value is not None else None
 
 
+@_to_partition_representation.register(TimeType)
+def _(type: IcebergType, value: Optional[time]) -> Optional[int]:
+    return time_to_micros(value) if value is not None else None
+
+
 @_to_partition_representation.register(UUIDType)
 def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
     return str(value) if value is not None else None
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index e183d827..2dbc32d8 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -16,13 +16,13 @@
 # under the License.
 from __future__ import annotations
 
-import datetime
 import itertools
 import uuid
 import warnings
 from abc import ABC, abstractmethod
 from copy import copy
 from dataclasses import dataclass
+from datetime import datetime
 from enum import Enum
 from functools import cached_property, singledispatch
 from itertools import chain
@@ -79,6 +79,8 @@ from pyiceberg.partitioning import (
     PARTITION_FIELD_ID_START,
     UNPARTITIONED_PARTITION_SPEC,
     PartitionField,
+    PartitionFieldValue,
+    PartitionKey,
     PartitionSpec,
     _PartitionNameGenerator,
     _visit_partition_field,
@@ -373,8 +375,11 @@ class Transaction:
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        if len(self._table.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
+        supported_transforms = {IdentityTransform}
+        if not all(type(field.transform) in supported_transforms for field in 
self.table_metadata.spec().fields):
+            raise ValueError(
+                f"All transforms are not supported, expected: 
{supported_transforms}, but get: {[str(field) for field in 
self.table_metadata.spec().fields if field.transform not in 
supported_transforms]}."
+            )
 
         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
         # cast if the two schemas are compatible but not equal
@@ -897,7 +902,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: 
TableMetadata, context: _Tabl
     if update.ref_name == MAIN_BRANCH:
         metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
         if "last_updated_ms" not in metadata_updates:
-            metadata_updates["last_updated_ms"] = 
datetime_to_millis(datetime.datetime.now().astimezone())
+            metadata_updates["last_updated_ms"] = 
datetime_to_millis(datetime.now().astimezone())
 
         metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
             SnapshotLogEntry(
@@ -2646,16 +2651,23 @@ def _add_and_move_fields(
 class WriteTask:
     write_uuid: uuid.UUID
     task_id: int
+    schema: Schema
     record_batches: List[pa.RecordBatch]
     sort_order_id: Optional[int] = None
-
-    # Later to be extended with partition information
+    partition_key: Optional[PartitionKey] = None
 
     def generate_data_file_filename(self, extension: str) -> str:
         # Mimics the behavior in the Java API:
         # 
https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
         return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
 
+    def generate_data_file_path(self, extension: str) -> str:
+        if self.partition_key:
+            file_path = 
f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
+            return file_path
+        else:
+            return self.generate_data_file_filename(extension)
+
 
 @dataclass(frozen=True)
 class AddFileTask:
@@ -2683,25 +2695,40 @@ def _dataframe_to_data_files(
     """
     from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file
 
-    if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 
0]) > 0:
-        raise ValueError("Cannot write to partitioned tables")
-
     counter = itertools.count(0)
     write_uuid = write_uuid or uuid.uuid4()
-
-    target_file_size = PropertyUtil.property_as_int(
+    target_file_size: int = PropertyUtil.property_as_int(  # type: ignore  # 
The property is set with non-None value.
         properties=table_metadata.properties,
         property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
         default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
     )
 
-    # This is an iter, so we don't have to materialize everything every time
-    # This will be more relevant when we start doing partitioned writes
-    yield from write_file(
-        io=io,
-        table_metadata=table_metadata,
-        tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches 
in bin_pack_arrow_table(df, target_file_size)]),  # type: ignore
-    )
+    if len(table_metadata.spec().fields) > 0:
+        partitions = _determine_partitions(spec=table_metadata.spec(), 
schema=table_metadata.schema(), arrow_table=df)
+        yield from write_file(
+            io=io,
+            table_metadata=table_metadata,
+            tasks=iter([
+                WriteTask(
+                    write_uuid=write_uuid,
+                    task_id=next(counter),
+                    record_batches=batches,
+                    partition_key=partition.partition_key,
+                    schema=table_metadata.schema(),
+                )
+                for partition in partitions
+                for batches in 
bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
+            ]),
+        )
+    else:
+        yield from write_file(
+            io=io,
+            table_metadata=table_metadata,
+            tasks=iter([
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), 
record_batches=batches, schema=table_metadata.schema())
+                for batches in bin_pack_arrow_table(df, target_file_size)
+            ]),
+        )
 
 
 def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: 
List[str], io: FileIO) -> Iterable[DataFile]:
@@ -3253,7 +3280,7 @@ class InspectTable:
                 additional_properties = None
 
             snapshots.append({
-                'committed_at': 
datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
+                'committed_at': 
datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
                 'snapshot_id': snapshot.snapshot_id,
                 'parent_id': snapshot.parent_snapshot_id,
                 'operation': str(operation),
@@ -3388,3 +3415,112 @@ class InspectTable:
             entries,
             schema=entries_schema,
         )
+
+
+@dataclass(frozen=True)
+class TablePartition:
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_partition_sort_order(partition_columns: list[str], reverse: bool = 
False) -> dict[str, Any]:
+    order = 'ascending' if not reverse else 'descending'
+    null_placement = 'at_start' if reverse else 'at_end'
+    return {'sort_keys': [(column_name, order) for column_name in 
partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: 
list[str]) -> pa.Table:
+    """Given a table, sort it by current partition scheme."""
+    # only works for identity for now
+    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
+    sorted_arrow_table = 
arrow_table.sort_by(sorting=sort_options['sort_keys'], 
null_placement=sort_options['null_placement'])
+    return sorted_arrow_table
+
+
+def get_partition_columns(
+    spec: PartitionSpec,
+    schema: Schema,
+) -> list[str]:
+    partition_cols = []
+    for partition_field in spec.fields:
+        column_name = schema.find_column_name(partition_field.source_id)
+        if not column_name:
+            raise ValueError(f"{partition_field=} could not be found in 
{schema}.")
+        partition_cols.append(column_name)
+    return partition_cols
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: 
x['offset'])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: 
arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, 
projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, 
partition_spec=partition_spec, schema=schema)
+        table_partitions.append(TablePartition(partition_key=partition_key, 
arrow_table_partition=partition_slice))
+    return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: 
pa.Table) -> List[TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into 
partitions with their keys.
+
+    Example:
+    Input:
+    An arrow table with partition key of ['n_legs', 'year'] and with data of
+    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", 
"Horse","Brittle stars", "Centipede"]}.
+    The algrithm:
+    Firstly we group the rows into partitions by sorting with sort order 
[('n_legs', 'descending'), ('year', 'descending')]
+    and null_placement of "at_end".
+    This gives the same table as raw input.
+    Then we sort_indices using reverse order of [('n_legs', 'descending'), 
('year', 'descending')]
+    and null_placement : "at_start".
+    This gives:
+    [8, 7, 4, 5, 6, 3, 1, 2, 0]
+    Based on this we get partition groups of indices:
+    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 
'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, 
{'offset': 0, 'length': 1}]
+    We then retrieve the partition keys by offsets.
+    And slice the arrow table by offsets and lengths of each partition.
+    """
+    import pyarrow as pa
+
+    partition_columns = get_partition_columns(spec=spec, schema=schema)
+    arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
+
+    reversing_sort_order_options = 
_get_partition_sort_order(partition_columns, reverse=True)
+    reversed_indices = pa.compute.sort_indices(arrow_table, 
**reversing_sort_order_options).to_pylist()
+
+    slice_instructions: list[dict[str, Any]] = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
+
+    table_partitions: list[TablePartition] = 
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
+
+    return table_partitions
diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py
index 4bed386c..6ccf9526 100644
--- a/pyiceberg/typedef.py
+++ b/pyiceberg/typedef.py
@@ -202,5 +202,9 @@ class Record(StructProtocol):
         """Return values of all the fields of the Record class except those 
specified in skip_fields."""
         return [self.__getattribute__(v) if hasattr(self, v) else None for v 
in self._position_to_field_name]
 
+    def __hash__(self) -> int:
+        """Return hash value of the Record class."""
+        return hash(str(self))
+
 
 TableVersion: TypeAlias = Literal[1, 2]
diff --git a/tests/conftest.py b/tests/conftest.py
index aa09517b..4a820fed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,7 +30,7 @@ import re
 import socket
 import string
 import uuid
-from datetime import date, datetime
+from datetime import date, datetime, timezone
 from pathlib import Path
 from random import choice
 from tempfile import TemporaryDirectory
@@ -1999,8 +1999,13 @@ TEST_DATA_WITH_NULL = {
     'long': [1, None, 9],
     'float': [0.0, None, 0.9],
     'double': [0.0, None, 0.9],
+    # 'time': [1_000_000, None, 3_000_000],  # Example times: 1s, none, and 3s 
past midnight #Spark does not support time fields
     'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 
19, 25, 00)],
-    'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 
1, 19, 25, 00)],
+    'timestamptz': [
+        datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+        None,
+        datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+    ],
     'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
     # Not supported by Spark
     # 'time': [time(1, 22, 0), None, time(19, 25, 0)],
@@ -2027,6 +2032,8 @@ def pa_schema() -> "pa.Schema":
         ("long", pa.int64()),
         ("float", pa.float32()),
         ("double", pa.float64()),
+        # Not supported by Spark
+        # ("time", pa.time64('us')),
         ("timestamp", pa.timestamp(unit="us")),
         ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
         ("date", pa.date32()),
@@ -2041,7 +2048,23 @@ def pa_schema() -> "pa.Schema":
 
 @pytest.fixture(scope="session")
 def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with all kinds of columns."""
     import pyarrow as pa
 
-    """Pyarrow table with all kinds of columns."""
     return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)
+
+
[email protected](scope="session")
+def arrow_table_without_data(pa_schema: "pa.Schema") -> "pa.Table":
+    """Pyarrow table without data."""
+    import pyarrow as pa
+
+    return pa.Table.from_pylist([], schema=pa_schema)
+
+
[email protected](scope="session")
+def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with only null values."""
+    import pyarrow as pa
+
+    return pa.Table.from_pylist([{}, {}], schema=pa_schema)
diff --git a/tests/integration/test_partitioning_key.py 
b/tests/integration/test_partitioning_key.py
index 12056bac..d89ecaf2 100644
--- a/tests/integration/test_partitioning_key.py
+++ b/tests/integration/test_partitioning_key.py
@@ -749,7 +749,7 @@ def test_partition_key(
     # key.to_path() generates the hive partitioning part of the to-write 
parquet file path
     assert key.to_path() == expected_hive_partition_path_slice
 
-    # Justify expected values are not made up but conform to spark behaviors
+    # Justify expected values are not made up but conforming to spark behaviors
     if spark_create_table_sql_for_justification is not None and 
spark_data_insert_sql_for_justification is not None:
         try:
             spark.sql(f"drop table {identifier}")
diff --git a/tests/integration/test_writes/test_partitioned_writes.py 
b/tests/integration/test_writes/test_partitioned_writes.py
new file mode 100644
index 00000000..d84b9745
--- /dev/null
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+
+import pyarrow as pa
+import pytest
+from pyspark.sql import SparkSession
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.transforms import (
+    BucketTransform,
+    DayTransform,
+    HourTransform,
+    IdentityTransform,
+    MonthTransform,
+    TruncateTransform,
+    YearTransform,
+)
+from tests.conftest import TEST_DATA_WITH_NULL
+from utils import TABLE_SCHEMA, _create_table
+
+
[email protected]
[email protected](
+    "part_col", ['int', 'bool', 'string', "string_long", "long", "float", 
"double", "date", 'timestamp', 'timestamptz', 'binary']
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_null_partitioned(
+    session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: 
pa.Table, part_col: str, format_version: int
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=IdentityTransform(), name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=partition_spec,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in TEST_DATA_WITH_NULL.keys():
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 
non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row 
for {col} is null"
+
+
[email protected]
[email protected](
+    "part_col", ['int', 'bool', 'string', "string_long", "long", "float", 
"double", "date", 'timestamp', 'timestamptz', 'binary']
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_without_data_partitioned(
+    session_catalog: Catalog, spark: SparkSession, arrow_table_without_data: 
pa.Table, part_col: str, format_version: int
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v{format_version}_without_data_partitioned_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=IdentityTransform(), name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_without_data],
+        partition_spec=partition_spec,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    for col in TEST_DATA_WITH_NULL.keys():
+        assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for 
{col}"
+        assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row 
for {col}"
+
+
[email protected]
[email protected](
+    "part_col", ['int', 'bool', 'string', "string_long", "long", "float", 
"double", "date", 'timestamp', 'timestamptz', 'binary']
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_only_nulls_partitioned(
+    session_catalog: Catalog, spark: SparkSession, 
arrow_table_with_only_nulls: pa.Table, part_col: str, format_version: int
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v{format_version}_with_only_nulls_partitioned_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=IdentityTransform(), name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_only_nulls],
+        partition_spec=partition_spec,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    for col in TEST_DATA_WITH_NULL.keys():
+        assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for 
{col}"
+        assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows 
for {col}"
+
+
[email protected]
[email protected](
+    "part_col", ['int', 'bool', 'string', "string_long", "long", "float", 
"double", "date", "timestamptz", "timestamp", "binary"]
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_appended_null_partitioned(
+    session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: 
pa.Table, part_col: str, format_version: int
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=IdentityTransform(), name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[],
+        partition_spec=partition_spec,
+    )
+    # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with 
lines[A,B,C,A,B,C]
+    tbl.append(arrow_table_with_null)
+    tbl.append(pa.concat_tables([arrow_table_with_null, 
arrow_table_with_null]))
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, 
got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    for col in TEST_DATA_WITH_NULL.keys():
+        df = spark.table(identifier)
+        assert df.where(f"{col} is not null").count() == 6, f"Expected 6 
non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows 
for {col}"
+    # expecting 6 files: first append with [A], [B], [C],  second append with 
[A, A], [B, B], [C, C]
+    rows = spark.sql(f"select partition from {identifier}.files").collect()
+    assert len(rows) == 6
+
+
[email protected]
[email protected](
+    "part_col", ['int', 'bool', 'string', "string_long", "long", "float", 
"double", "date", "timestamptz", "timestamp", "binary"]
+)
+def test_query_filter_v1_v2_append_null(
+    session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: 
pa.Table, part_col: str
+) -> None:
+    # Given
+    identifier = 
f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=IdentityTransform(), name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": "1"},
+        data=[],
+        partition_spec=partition_spec,
+    )
+    tbl.append(arrow_table_with_null)
+
+    # Then
+    assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"
+
+    # When
+    with tbl.transaction() as tx:
+        tx.upgrade_table_version(format_version=2)
+
+    tbl.append(arrow_table_with_null)
+
+    # Then
+    assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}"
+    for col in TEST_DATA_WITH_NULL.keys():  # type: ignore
+        df = spark.table(identifier)
+        assert df.where(f"{col} is not null").count() == 4, f"Expected 4 
non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows 
for {col}"
+
+
[email protected]
+def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, 
arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.arrow_table_summaries"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=PartitionSpec(PartitionField(source_id=4, 
field_id=1001, transform=IdentityTransform(), name="int")),
+        properties={'format-version': '2'},
+    )
+
+    tbl.append(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+
+    rows = spark.sql(
+        f"""
+        SELECT operation, summary
+        FROM {identifier}.snapshots
+        ORDER BY committed_at ASC
+    """
+    ).collect()
+
+    operations = [row.operation for row in rows]
+    assert operations == ['append', 'append']
+
+    summaries = [row.summary for row in rows]
+    assert summaries[0] == {
+        'changed-partition-count': '3',
+        'added-data-files': '3',
+        'added-files-size': '15029',
+        'added-records': '3',
+        'total-data-files': '3',
+        'total-delete-files': '0',
+        'total-equality-deletes': '0',
+        'total-files-size': '15029',
+        'total-position-deletes': '0',
+        'total-records': '3',
+    }
+
+    assert summaries[1] == {
+        'changed-partition-count': '3',
+        'added-data-files': '3',
+        'added-files-size': '15029',
+        'added-records': '3',
+        'total-data-files': '6',
+        'total-delete-files': '0',
+        'total-equality-deletes': '0',
+        'total-files-size': '30058',
+        'total-position-deletes': '0',
+        'total-records': '6',
+    }
+
+
[email protected]
+def test_data_files_with_table_partitioned_with_null(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: 
pa.Table
+) -> None:
+    identifier = "default.arrow_data_files"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=PartitionSpec(PartitionField(source_id=4, 
field_id=1001, transform=IdentityTransform(), name="int")),
+        properties={'format-version': '1'},
+    )
+
+    tbl.append(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+
+    # added_data_files_count, existing_data_files_count, 
deleted_data_files_count
+    rows = spark.sql(
+        f"""
+        SELECT added_data_files_count, existing_data_files_count, 
deleted_data_files_count
+        FROM {identifier}.all_manifests
+    """
+    ).collect()
+
+    assert [row.added_data_files_count for row in rows] == [3, 3, 3]
+    assert [row.existing_data_files_count for row in rows] == [
+        0,
+        0,
+        0,
+    ]
+    assert [row.deleted_data_files_count for row in rows] == [0, 0, 0]
+
+
[email protected]
+def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> 
None:
+    identifier = "default.arrow_data_files"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=PartitionSpec(PartitionField(source_id=4, 
field_id=1001, transform=IdentityTransform(), name="int")),
+        properties={'format-version': '1'},
+    )
+
+    with pytest.raises(ValueError, match="Expected PyArrow table, got: not a 
df"):
+        tbl.append("not a df")
+
+
[email protected]
[email protected](
+    "spec",
+    [
+        # mixed with non-identity is not supported
+        (
+            PartitionSpec(
+                PartitionField(source_id=4, field_id=1001, 
transform=BucketTransform(2), name="int_bucket"),
+                PartitionField(source_id=1, field_id=1002, 
transform=IdentityTransform(), name="bool"),
+            )
+        ),
+        # none of non-identity is supported
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, 
transform=BucketTransform(2), name="int_bucket"))),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, 
transform=BucketTransform(2), name="long_bucket"))),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=BucketTransform(2), name="date_bucket"))),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=BucketTransform(2), name="timestamp_bucket"))),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=BucketTransform(2), name="timestamptz_bucket"))),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, 
transform=BucketTransform(2), name="string_bucket"))),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, 
transform=BucketTransform(2), name="fixed_bucket"))),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, 
transform=BucketTransform(2), name="binary_bucket"))),
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, 
transform=TruncateTransform(2), name="int_trunc"))),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, 
transform=TruncateTransform(2), name="long_trunc"))),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, 
transform=TruncateTransform(2), name="string_trunc"))),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, 
transform=TruncateTransform(2), name="binary_trunc"))),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=YearTransform(), name="timestamp_year"))),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=YearTransform(), name="timestamptz_year"))),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=YearTransform(), name="date_year"))),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=MonthTransform(), name="timestamp_month"))),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=MonthTransform(), name="timestamptz_month"))),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=MonthTransform(), name="date_month"))),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=DayTransform(), name="timestamp_day"))),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=DayTransform(), name="timestamptz_day"))),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=DayTransform(), name="date_day"))),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, 
transform=HourTransform(), name="timestamp_hour"))),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, 
transform=HourTransform(), name="timestamptz_hour"))),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, 
transform=HourTransform(), name="date_hour"))),
+    ],
+)
+def test_unsupported_transform(
+    spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, 
arrow_table_with_null: pa.Table
+) -> None:
+    identifier = "default.unsupported_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = session_catalog.create_table(
+        identifier=identifier,
+        schema=TABLE_SCHEMA,
+        partition_spec=spec,
+        properties={'format-version': '1'},
+    )
+
+    with pytest.raises(ValueError, match="All transforms are not supported.*"):
+        tbl.append(arrow_table_with_null)
diff --git a/tests/integration/test_writes.py 
b/tests/integration/test_writes/test_writes.py
similarity index 88%
rename from tests/integration/test_writes.py
rename to tests/integration/test_writes/test_writes.py
index e950fb43..62d3bb11 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -18,10 +18,9 @@
 import math
 import os
 import time
-import uuid
 from datetime import date, datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict
 from urllib.parse import urlparse
 
 import pyarrow as pa
@@ -36,93 +35,9 @@ from pytest_mock.plugin import MockerFixture
 from pyiceberg.catalog import Catalog
 from pyiceberg.catalog.sql import SqlCatalog
 from pyiceberg.exceptions import NoSuchTableError
-from pyiceberg.schema import Schema
-from pyiceberg.table import Table, TableProperties, _dataframe_to_data_files
-from pyiceberg.typedef import Properties
-from pyiceberg.types import (
-    BinaryType,
-    BooleanType,
-    DateType,
-    DoubleType,
-    FixedType,
-    FloatType,
-    IntegerType,
-    LongType,
-    NestedField,
-    StringType,
-    TimestampType,
-    TimestamptzType,
-)
-
-TEST_DATA_WITH_NULL = {
-    'bool': [False, None, True],
-    'string': ['a', None, 'z'],
-    # Go over the 16 bytes to kick in truncation
-    'string_long': ['a' * 22, None, 'z' * 22],
-    'int': [1, None, 9],
-    'long': [1, None, 9],
-    'float': [0.0, None, 0.9],
-    'double': [0.0, None, 0.9],
-    'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 
19, 25, 00)],
-    'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 
1, 19, 25, 00)],
-    'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
-    # Not supported by Spark
-    # 'time': [time(1, 22, 0), None, time(19, 25, 0)],
-    # Not natively supported by Arrow
-    # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, 
uuid.UUID('11111111-1111-1111-1111-111111111111').bytes],
-    'binary': [b'\01', None, b'\22'],
-    'fixed': [
-        uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
-        None,
-        uuid.UUID('11111111-1111-1111-1111-111111111111').bytes,
-    ],
-}
-
-TABLE_SCHEMA = Schema(
-    NestedField(field_id=1, name="bool", field_type=BooleanType(), 
required=False),
-    NestedField(field_id=2, name="string", field_type=StringType(), 
required=False),
-    NestedField(field_id=3, name="string_long", field_type=StringType(), 
required=False),
-    NestedField(field_id=4, name="int", field_type=IntegerType(), 
required=False),
-    NestedField(field_id=5, name="long", field_type=LongType(), 
required=False),
-    NestedField(field_id=6, name="float", field_type=FloatType(), 
required=False),
-    NestedField(field_id=7, name="double", field_type=DoubleType(), 
required=False),
-    NestedField(field_id=8, name="timestamp", field_type=TimestampType(), 
required=False),
-    NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), 
required=False),
-    NestedField(field_id=10, name="date", field_type=DateType(), 
required=False),
-    # NestedField(field_id=11, name="time", field_type=TimeType(), 
required=False),
-    # NestedField(field_id=12, name="uuid", field_type=UuidType(), 
required=False),
-    NestedField(field_id=12, name="binary", field_type=BinaryType(), 
required=False),
-    NestedField(field_id=13, name="fixed", field_type=FixedType(16), 
required=False),
-)
-
-
[email protected](scope="session")
-def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table:
-    """PyArrow table with all kinds of columns"""
-    return pa.Table.from_pylist([], schema=pa_schema)
-
-
[email protected](scope="session")
-def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table:
-    """PyArrow table with all kinds of columns"""
-    return pa.Table.from_pylist([{}, {}], schema=pa_schema)
-
-
-def _create_table(
-    session_catalog: Catalog, identifier: str, properties: Properties, data: 
Optional[List[pa.Table]] = None
-) -> Table:
-    try:
-        session_catalog.drop_table(identifier=identifier)
-    except NoSuchTableError:
-        pass
-
-    tbl = session_catalog.create_table(identifier=identifier, 
schema=TABLE_SCHEMA, properties=properties)
-
-    if data:
-        for d in data:
-            tbl.append(d)
-
-    return tbl
+from pyiceberg.table import TableProperties, _dataframe_to_data_files
+from tests.conftest import TEST_DATA_WITH_NULL
+from utils import _create_table
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -219,7 +134,7 @@ def test_query_filter_without_data(spark: SparkSession, 
col: str, format_version
     identifier = f"default.arrow_table_v{format_version}_without_data"
     df = spark.table(identifier)
     assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}"
-    assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for 
{col}"
+    assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for 
{col}"
 
 
 @pytest.mark.integration
@@ -228,8 +143,8 @@ def test_query_filter_without_data(spark: SparkSession, 
col: str, format_version
 def test_query_filter_only_nulls(spark: SparkSession, col: str, 
format_version: int) -> None:
     identifier = f"default.arrow_table_v{format_version}_with_only_nulls"
     df = spark.table(identifier)
-    assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}"
-    assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for 
{col}"
+    assert df.where(f"{col} is null").count() == 2, f"Expected 2 rows for 
{col}"
+    assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for 
{col}"
 
 
 @pytest.mark.integration
diff --git a/tests/integration/test_writes/utils.py 
b/tests/integration/test_writes/utils.py
new file mode 100644
index 00000000..792e2518
--- /dev/null
+++ b/tests/integration/test_writes/utils.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+from typing import List, Optional
+
+import pyarrow as pa
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.typedef import Properties
+from pyiceberg.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    DoubleType,
+    FixedType,
+    FloatType,
+    IntegerType,
+    LongType,
+    NestedField,
+    StringType,
+    TimestampType,
+    TimestamptzType,
+)
+
+TABLE_SCHEMA = Schema(
+    NestedField(field_id=1, name="bool", field_type=BooleanType(), 
required=False),
+    NestedField(field_id=2, name="string", field_type=StringType(), 
required=False),
+    NestedField(field_id=3, name="string_long", field_type=StringType(), 
required=False),
+    NestedField(field_id=4, name="int", field_type=IntegerType(), 
required=False),
+    NestedField(field_id=5, name="long", field_type=LongType(), 
required=False),
+    NestedField(field_id=6, name="float", field_type=FloatType(), 
required=False),
+    NestedField(field_id=7, name="double", field_type=DoubleType(), 
required=False),
+    # NestedField(field_id=8, name="time", field_type=TimeType(), 
required=False), # Spark does not support time fields
+    NestedField(field_id=8, name="timestamp", field_type=TimestampType(), 
required=False),
+    NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), 
required=False),
+    NestedField(field_id=10, name="date", field_type=DateType(), 
required=False),
+    # NestedField(field_id=11, name="time", field_type=TimeType(), 
required=False),
+    # NestedField(field_id=12, name="uuid", field_type=UuidType(), 
required=False),
+    NestedField(field_id=11, name="binary", field_type=BinaryType(), 
required=False),
+    NestedField(field_id=12, name="fixed", field_type=FixedType(16), 
required=False),
+)
+
+
+def _create_table(
+    session_catalog: Catalog,
+    identifier: str,
+    properties: Properties,
+    data: Optional[List[pa.Table]] = None,
+    partition_spec: Optional[PartitionSpec] = None,
+) -> Table:
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    if partition_spec:
+        tbl = session_catalog.create_table(
+            identifier=identifier, schema=TABLE_SCHEMA, properties=properties, 
partition_spec=partition_spec
+        )
+    else:
+        tbl = session_catalog.create_table(identifier=identifier, 
schema=TABLE_SCHEMA, properties=properties)
+
+    if data:
+        for d in data:
+            tbl.append(d)
+
+    return tbl
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index f1191295..2bc78f31 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -64,6 +64,7 @@ from pyiceberg.table import (
     UpdateSchema,
     _apply_table_update,
     _check_schema_compatible,
+    _determine_partitions,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
     update_table_metadata,
@@ -82,7 +83,11 @@ from pyiceberg.table.sorting import (
     SortField,
     SortOrder,
 )
-from pyiceberg.transforms import BucketTransform, IdentityTransform
+from pyiceberg.transforms import (
+    BucketTransform,
+    IdentityTransform,
+)
+from pyiceberg.typedef import Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1139,3 +1144,85 @@ def test_serialize_commit_table_request() -> None:
 
     deserialized_request = 
CommitTableRequest.model_validate_json(request.model_dump_json())
     assert request == deserialized_request
+
+
+def test_partition_for_demo() -> None:
+    import pyarrow as pa
+
+    test_pa_schema = pa.schema([('year', pa.int64()), ("n_legs", pa.int64()), 
("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name='year', field_type=StringType(), 
required=False),
+        NestedField(field_id=2, name='n_legs', field_type=IntegerType(), 
required=True),
+        NestedField(field_id=3, name='animal', field_type=StringType(), 
required=False),
+        schema_id=1,
+    )
+    test_data = {
+        'year': [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
+        'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+        'animal': ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", 
"Horse", "Brittle stars", "Centipede"],
+    }
+    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, 
transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, 
transform=IdentityTransform(), name="year_identity"),
+    )
+    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    assert {table_partition.partition_key.partition for table_partition in 
result} == {
+        Record(n_legs_identity=2, year_identity=2020),
+        Record(n_legs_identity=100, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2022),
+        Record(n_legs_identity=2, year_identity=2022),
+        Record(n_legs_identity=5, year_identity=2019),
+    }
+    assert (
+        pa.concat_tables([table_partition.arrow_table_partition for 
table_partition in result]).num_rows == arrow_table.num_rows
+    )
+
+
+def test_identity_partition_on_multi_columns() -> None:
+    import pyarrow as pa
+
+    test_pa_schema = pa.schema([('born_year', pa.int64()), ("n_legs", 
pa.int64()), ("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name='born_year', field_type=StringType(), 
required=False),
+        NestedField(field_id=2, name='n_legs', field_type=IntegerType(), 
required=True),
+        NestedField(field_id=3, name='animal', field_type=StringType(), 
required=False),
+        schema_id=1,
+    )
+    # 5 partitions, 6 unique row values, 12 rows
+    test_rows = [
+        (2021, 4, "Dog"),
+        (2022, 4, "Horse"),
+        (2022, 4, "Another Horse"),
+        (2021, 100, "Centipede"),
+        (None, 4, "Kirin"),
+        (2021, None, "Fish"),
+    ] * 2
+    expected = {Record(n_legs_identity=test_rows[i][1], 
year_identity=test_rows[i][0]) for i in range(len(test_rows))}
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, 
transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, 
transform=IdentityTransform(), name="year_identity"),
+    )
+    import random
+
+    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
+    for _ in range(1000):
+        random.shuffle(test_rows)
+        test_data = {
+            'born_year': [row[0] for row in test_rows],
+            'n_legs': [row[1] for row in test_rows],
+            'animal': [row[2] for row in test_rows],
+        }
+        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+
+        result = _determine_partitions(partition_spec, test_schema, 
arrow_table)
+
+        assert {table_partition.partition_key.partition for table_partition in 
result} == expected
+        concatenated_arrow_table = 
pa.concat_tables([table_partition.arrow_table_partition for table_partition in 
result])
+        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
+        assert concatenated_arrow_table.sort_by([
+            ('born_year', 'ascending'),
+            ('n_legs', 'ascending'),
+            ('animal', 'ascending'),
+        ]) == arrow_table.sort_by([('born_year', 'ascending'), ('n_legs', 
'ascending'), ('animal', 'ascending')])

Reply via email to