This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 32e8f88e support PyArrow timestamptz with Etc/UTC (#910)
32e8f88e is described below

commit 32e8f88ebf8e45ae0a7f60a848ea44044a9564ef
Author: Sung Yun <[email protected]>
AuthorDate: Fri Jul 12 15:26:00 2024 -0400

    support PyArrow timestamptz with Etc/UTC (#910)
    
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 pyiceberg/io/pyarrow.py                            |  51 ++++++---
 pyiceberg/table/__init__.py                        |   8 --
 tests/conftest.py                                  | 116 ++++++++++++++++++++-
 tests/integration/test_add_files.py                |   1 +
 .../test_writes/test_partitioned_writes.py         |  14 +--
 tests/integration/test_writes/test_writes.py       |  81 +++++---------
 tests/io/test_pyarrow.py                           |  33 ++++++
 7 files changed, 218 insertions(+), 86 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 91745a58..1ef9fc9b 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -174,6 +174,7 @@ LIST_ELEMENT_NAME = "element"
 MAP_KEY_NAME = "key"
 MAP_VALUE_NAME = "value"
 DOC = "doc"
+UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}
 
 T = TypeVar("T")
 
@@ -937,7 +938,7 @@ class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
             else:
                 raise TypeError(f"Unsupported precision for timestamp type: 
{primitive.unit}")
 
-            if primitive.tz == "UTC" or primitive.tz == "+00:00":
+            if primitive.tz in UTC_ALIASES:
                 return TimestamptzType()
             elif primitive.tz is None:
                 return TimestampType()
@@ -1073,7 +1074,7 @@ def _task_to_record_batches(
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
                     batch = arrow_table.to_batches()[0]
-            yield to_requested_schema(projected_schema, file_project_schema, 
batch, downcast_ns_timestamp_to_us=True)
+            yield _to_requested_schema(projected_schema, file_project_schema, 
batch, downcast_ns_timestamp_to_us=True)
             current_index += len(batch)
 
 
@@ -1278,7 +1279,7 @@ def project_batches(
             total_row_count += len(batch)
 
 
-def to_requested_schema(
+def _to_requested_schema(
     requested_schema: Schema,
     file_schema: Schema,
     batch: pa.RecordBatch,
@@ -1296,16 +1297,17 @@ def to_requested_schema(
 
 
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, 
Optional[pa.Array]]):
-    file_schema: Schema
+    _file_schema: Schema
     _include_field_ids: bool
+    _downcast_ns_timestamp_to_us: bool
 
     def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool 
= False, include_field_ids: bool = False) -> None:
-        self.file_schema = file_schema
+        self._file_schema = file_schema
         self._include_field_ids = include_field_ids
-        self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
+        self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
 
     def _cast_if_needed(self, field: NestedField, values: pa.Array) -> 
pa.Array:
-        file_field = self.file_schema.find_field(field.field_id)
+        file_field = self._file_schema.find_field(field.field_id)
 
         if field.field_type.is_primitive:
             if field.field_type != file_field.field_type:
@@ -1313,14 +1315,31 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
                     schema_to_pyarrow(promote(file_field.field_type, 
field.field_type), include_field_ids=self._include_field_ids)
                 )
             elif (target_type := schema_to_pyarrow(field.field_type, 
include_field_ids=self._include_field_ids)) != values.type:
-                # Downcasting of nanoseconds to microseconds
-                if (
-                    pa.types.is_timestamp(target_type)
-                    and target_type.unit == "us"
-                    and pa.types.is_timestamp(values.type)
-                    and values.type.unit == "ns"
-                ):
-                    return values.cast(target_type, safe=False)
+                if field.field_type == TimestampType():
+                    # Downcasting of nanoseconds to microseconds
+                    if (
+                        pa.types.is_timestamp(target_type)
+                        and not target_type.tz
+                        and pa.types.is_timestamp(values.type)
+                        and not values.type.tz
+                    ):
+                        if target_type.unit == "us" and values.type.unit == 
"ns" and self._downcast_ns_timestamp_to_us:
+                            return values.cast(target_type, safe=False)
+                        elif target_type.unit == "us" and values.type.unit in 
{"s", "ms"}:
+                            return values.cast(target_type)
+                    raise ValueError(f"Unsupported schema projection from 
{values.type} to {target_type}")
+                elif field.field_type == TimestamptzType():
+                    if (
+                        pa.types.is_timestamp(target_type)
+                        and target_type.tz == "UTC"
+                        and pa.types.is_timestamp(values.type)
+                        and values.type.tz in UTC_ALIASES
+                    ):
+                        if target_type.unit == "us" and values.type.unit == 
"ns" and self._downcast_ns_timestamp_to_us:
+                            return values.cast(target_type, safe=False)
+                        elif target_type.unit == "us" and values.type.unit in 
{"s", "ms", "us"}:
+                            return values.cast(target_type)
+                    raise ValueError(f"Unsupported schema projection from 
{values.type} to {target_type}")
         return values
 
     def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> 
pa.Field:
@@ -1970,7 +1989,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
 
         downcast_ns_timestamp_to_us = 
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         batches = [
-            to_requested_schema(
+            _to_requested_schema(
                 requested_schema=file_schema,
                 file_schema=table_schema,
                 batch=batch,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 62440c47..b43dc320 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -484,10 +484,6 @@ class Transaction:
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, 
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         manifest_merge_enabled = PropertyUtil.property_as_bool(
             self.table_metadata.properties,
@@ -545,10 +541,6 @@ class Transaction:
         _check_schema_compatible(
             self._table.schema(), other_schema=df.schema, 
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self._table.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
 
         self.delete(delete_filter=overwrite_filter, 
snapshot_properties=snapshot_properties)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 95e1128a..6b1a2b43 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table":
 
 
 @pytest.fixture(scope="session")
-def arrow_table_date_timestamps_schema() -> Schema:
-    """Pyarrow table Schema with only date, timestamp and timestamptz 
values."""
+def table_date_timestamps_schema() -> Schema:
+    """Iceberg table Schema with only date, timestamp and timestamptz 
values."""
     return Schema(
         NestedField(field_id=1, name="date", field_type=DateType(), 
required=False),
         NestedField(field_id=2, name="timestamp", field_type=TimestampType(), 
required=False),
         NestedField(field_id=3, name="timestamptz", 
field_type=TimestamptzType(), required=False),
     )
+
+
[email protected](scope="session")
+def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema":
+    """Pyarrow Schema with all supported timestamp types."""
+    import pyarrow as pa
+
+    return pa.schema([
+        ("timestamp_s", pa.timestamp(unit="s")),
+        ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
+        ("timestamp_ms", pa.timestamp(unit="ms")),
+        ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
+        ("timestamp_us", pa.timestamp(unit="us")),
+        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ns", pa.timestamp(unit="ns")),
+        ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
+        ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")),
+        ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")),
+        ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")),
+    ])
+
+
[email protected](scope="session")
+def 
arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions:
 "pa.Schema") -> "pa.Table":
+    """Pyarrow table with all supported timestamp types."""
+    import pandas as pd
+    import pyarrow as pa
+
+    test_data = pd.DataFrame({
+        "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 
3, 1, 19, 25, 00)],
+        "timestamptz_s": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, 
datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_ms": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, 
datetime(2023, 3, 1, 19, 25, 00)],
+        "timestamptz_us": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamp_ns": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, 
second=0, microsecond=12, nanosecond=6),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, 
second=0, microsecond=12, nanosecond=7),
+        ],
+        "timestamptz_ns": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_us_etc_utc": [
+            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+        ],
+        "timestamptz_ns_z": [
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, 
second=0, microsecond=12, nanosecond=6, tz="UTC"),
+            None,
+            pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, 
second=0, microsecond=12, nanosecond=7, tz="UTC"),
+        ],
+        "timestamptz_s_0000": [
+            datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc),
+            None,
+            datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc),
+        ],
+    })
+    return pa.Table.from_pandas(test_data, 
schema=arrow_table_schema_with_all_timestamp_precisions)
+
+
[email protected](scope="session")
+def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> 
"pa.Schema":
+    """Pyarrow Schema with all microseconds timestamp."""
+    import pyarrow as pa
+
+    return pa.schema([
+        ("timestamp_s", pa.timestamp(unit="us")),
+        ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ms", pa.timestamp(unit="us")),
+        ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_us", pa.timestamp(unit="us")),
+        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamp_ns", pa.timestamp(unit="us")),
+        ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")),
+        ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")),
+    ])
+
+
[email protected](scope="session")
+def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
+    """Iceberg table Schema with only date, timestamp and timestamptz 
values."""
+    return Schema(
+        NestedField(field_id=1, name="timestamp_s", 
field_type=TimestampType(), required=False),
+        NestedField(field_id=2, name="timestamptz_s", 
field_type=TimestamptzType(), required=False),
+        NestedField(field_id=3, name="timestamp_ms", 
field_type=TimestampType(), required=False),
+        NestedField(field_id=4, name="timestamptz_ms", 
field_type=TimestamptzType(), required=False),
+        NestedField(field_id=5, name="timestamp_us", 
field_type=TimestampType(), required=False),
+        NestedField(field_id=6, name="timestamptz_us", 
field_type=TimestamptzType(), required=False),
+        NestedField(field_id=7, name="timestamp_ns", 
field_type=TimestampType(), required=False),
+        NestedField(field_id=8, name="timestamptz_ns", 
field_type=TimestamptzType(), required=False),
+        NestedField(field_id=9, name="timestamptz_us_etc_utc", 
field_type=TimestamptzType(), required=False),
+        NestedField(field_id=10, name="timestamptz_ns_z", 
field_type=TimestamptzType(), required=False),
+        NestedField(field_id=11, name="timestamptz_s_0000", 
field_type=TimestamptzType(), required=False),
+    )
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 984c7d11..b8fd6d09 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -570,6 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
     assert table_schema == arrow_schema_large
 
 
[email protected]
 def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, 
format_version: int, mocker: MockerFixture) -> None:
     nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", 
TimestamptzType()))
 
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 12da9c92..b199f002 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -461,7 +461,7 @@ def test_append_transform_partition_verify_partitions_count(
     session_catalog: Catalog,
     spark: SparkSession,
     arrow_table_date_timestamps: pa.Table,
-    arrow_table_date_timestamps_schema: Schema,
+    table_date_timestamps_schema: Schema,
     transform: Transform[Any, Any],
     expected_partitions: Set[Any],
     format_version: int,
@@ -469,7 +469,7 @@ def test_append_transform_partition_verify_partitions_count(
     # Given
     part_col = "timestamptz"
     identifier = 
f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
-    nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+    nested_field = table_date_timestamps_schema.find_field(part_col)
     partition_spec = PartitionSpec(
         PartitionField(source_id=nested_field.field_id, field_id=1001, 
transform=transform, name=part_col),
     )
@@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count(
         properties={"format-version": str(format_version)},
         data=[arrow_table_date_timestamps],
         partition_spec=partition_spec,
-        schema=arrow_table_date_timestamps_schema,
+        schema=table_date_timestamps_schema,
     )
 
     # Then
@@ -510,20 +510,20 @@ def test_append_multiple_partitions(
     session_catalog: Catalog,
     spark: SparkSession,
     arrow_table_date_timestamps: pa.Table,
-    arrow_table_date_timestamps_schema: Schema,
+    table_date_timestamps_schema: Schema,
     format_version: int,
 ) -> None:
     # Given
     identifier = 
f"default.arrow_table_v{format_version}_with_multiple_partitions"
     partition_spec = PartitionSpec(
         PartitionField(
-            
source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            source_id=table_date_timestamps_schema.find_field("date").field_id,
             field_id=1001,
             transform=YearTransform(),
             name="date_year",
         ),
         PartitionField(
-            
source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            
source_id=table_date_timestamps_schema.find_field("timestamptz").field_id,
             field_id=1000,
             transform=HourTransform(),
             name="timestamptz_hour",
@@ -537,7 +537,7 @@ def test_append_multiple_partitions(
         properties={"format-version": str(format_version)},
         data=[arrow_table_date_timestamps],
         partition_spec=partition_spec,
-        schema=arrow_table_date_timestamps_schema,
+        schema=table_date_timestamps_schema,
     )
 
     # Then
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index af626718..41bc6fb5 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -18,11 +18,12 @@
 import math
 import os
 import time
-from datetime import date, datetime, timezone
+from datetime import date, datetime
 from pathlib import Path
 from typing import Any, Dict
 from urllib.parse import urlparse
 
+import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
@@ -977,69 +978,43 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
 
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
-def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: 
Catalog, format_version: int) -> None:
+def test_write_all_timestamp_precision(
+    mocker: MockerFixture,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    format_version: int,
+    arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+    arrow_table_with_all_timestamp_precisions: pa.Table,
+    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+) -> None:
     identifier = "default.table_all_timestamp_precision"
-    arrow_table_schema_with_all_timestamp_precisions = pa.schema([
-        ("timestamp_s", pa.timestamp(unit="s")),
-        ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")),
-        ("timestamp_ms", pa.timestamp(unit="ms")),
-        ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")),
-        ("timestamp_us", pa.timestamp(unit="us")),
-        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
-        ("timestamp_ns", pa.timestamp(unit="ns")),
-        ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")),
-    ])
-    TEST_DATA_WITH_NULL = {
-        "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 
3, 1, 19, 25, 00)],
-        "timestamptz_s": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-        "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, 
datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_ms": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-        "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, 
datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_us": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-        "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None, 
datetime(2023, 3, 1, 19, 25, 00)],
-        "timestamptz_ns": [
-            datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
-            None,
-            datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
-        ],
-    }
-    input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, 
schema=arrow_table_schema_with_all_timestamp_precisions)
     mocker.patch.dict(os.environ, 
values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})
 
     tbl = _create_table(
         session_catalog,
         identifier,
         {"format-version": format_version},
-        data=[input_arrow_table],
+        data=[arrow_table_with_all_timestamp_precisions],
         schema=arrow_table_schema_with_all_timestamp_precisions,
     )
-    tbl.overwrite(input_arrow_table)
+    tbl.overwrite(arrow_table_with_all_timestamp_precisions)
     written_arrow_table = tbl.scan().to_arrow()
 
-    expected_schema_in_all_us = pa.schema([
-        ("timestamp_s", pa.timestamp(unit="us")),
-        ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")),
-        ("timestamp_ms", pa.timestamp(unit="us")),
-        ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")),
-        ("timestamp_us", pa.timestamp(unit="us")),
-        ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")),
-        ("timestamp_ns", pa.timestamp(unit="us")),
-        ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")),
-    ])
-    assert written_arrow_table.schema == expected_schema_in_all_us
-    assert written_arrow_table == 
input_arrow_table.cast(expected_schema_in_all_us)
+    assert written_arrow_table.schema == 
arrow_table_schema_with_all_microseconds_timestamp_precisions
+    assert written_arrow_table == 
arrow_table_with_all_timestamp_precisions.cast(
+        arrow_table_schema_with_all_microseconds_timestamp_precisions, 
safe=False
+    )
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            if pd.isnull(left):
+                assert pd.isnull(right)
+            else:
+                # Check only upto microsecond precision since Spark loaded 
dtype is timezone unaware
+                # and supports upto microsecond precision
+                assert left.timestamp() == right.timestamp(), f"Difference in 
column {column}: {left} != {right}"
 
 
 @pytest.mark.integration
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 326eeff1..37198b7e 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -65,6 +65,7 @@ from pyiceberg.io.pyarrow import (
     _determine_partitions,
     _primitive_to_physical,
     _read_deletes,
+    _to_requested_schema,
     bin_pack_arrow_table,
     expression_to_pyarrow,
     project_table,
@@ -1889,3 +1890,35 @@ def test_identity_partition_on_multi_columns() -> None:
             ("n_legs", "ascending"),
             ("animal", "ascending"),
         ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", 
"ascending"), ("animal", "ascending")])
+
+
+def test__to_requested_schema_timestamps(
+    arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+    arrow_table_with_all_timestamp_precisions: pa.Table,
+    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+    table_schema_with_all_microseconds_timestamp_precision: Schema,
+) -> None:
+    requested_schema = table_schema_with_all_microseconds_timestamp_precision
+    file_schema = requested_schema
+    batch = arrow_table_with_all_timestamp_precisions.to_batches()[0]
+    result = _to_requested_schema(requested_schema, file_schema, batch, 
downcast_ns_timestamp_to_us=True, include_field_ids=False)
+
+    expected = arrow_table_with_all_timestamp_precisions.cast(
+        arrow_table_schema_with_all_microseconds_timestamp_precisions, 
safe=False
+    ).to_batches()[0]
+    assert result == expected
+
+
+def test__to_requested_schema_timestamps_without_downcast_raises_exception(
+    arrow_table_schema_with_all_timestamp_precisions: pa.Schema,
+    arrow_table_with_all_timestamp_precisions: pa.Table,
+    arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema,
+    table_schema_with_all_microseconds_timestamp_precision: Schema,
+) -> None:
+    requested_schema = table_schema_with_all_microseconds_timestamp_precision
+    file_schema = requested_schema
+    batch = arrow_table_with_all_timestamp_precisions.to_batches()[0]
+    with pytest.raises(ValueError) as exc_info:
+        _to_requested_schema(requested_schema, file_schema, batch, 
downcast_ns_timestamp_to_us=False, include_field_ids=False)
+
+    assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" 
in str(exc_info.value)

Reply via email to