This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ed3abdd Allow writing `pa.Table` that are either a subset of table schema or in arbitrary order, and support type promotion on write (#921)
1ed3abdd is described below

commit 1ed3abdd1aec480911eeec4f0f46a04efe53dc06
Author: Sung Yun <[email protected]>
AuthorDate: Wed Jul 17 02:04:52 2024 -0400

    Allow writing `pa.Table` that are either a subset of table schema or in arbitrary order, and support type promotion on write (#921)
    
    * merge
    
    * thanks @HonahX :)
    
    Co-authored-by: Honah J. <[email protected]>
    
    * support promote
    
    * revert promote
    
    * use a visitor
    
    * support promotion on write
    
    * fix
    
    * Thank you @Fokko !
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    
    * revert
    
    * add-files promotiontest
    
    * support promote for add_files
    
    * add tests for uuid
    
    * add_files subset schema test
    
    ---------
    
    Co-authored-by: Honah J. <[email protected]>
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 pyiceberg/io/pyarrow.py                      |  81 +++++++--------
 pyiceberg/schema.py                          | 100 +++++++++++++++++++
 pyiceberg/table/__init__.py                  |  15 ++-
 tests/conftest.py                            |  59 +++++++++++
 tests/integration/test_add_files.py          | 141 ++++++++++++++++++++++++---
 tests/integration/test_writes/test_writes.py | 102 ++++++++++++++++++-
 tests/io/test_pyarrow.py                     | 126 ++++++++++++++++++++++--
 7 files changed, 545 insertions(+), 79 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 199133f7..cd6736fb 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -120,6 +120,7 @@ from pyiceberg.schema import (
     Schema,
     SchemaVisitorPerPrimitiveType,
     SchemaWithPartnerVisitor,
+    _check_schema_compatible,
     pre_order_visit,
     promote,
     prune_columns,
@@ -1407,7 +1408,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
                 # This can be removed once this has been fixed:
                 # https://github.com/apache/arrow/issues/38809
                list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
-
+            value_array = self._cast_if_needed(list_type.element_field, value_array)
             arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
             return list_array.cast(arrow_field)
         else:
@@ -1417,6 +1418,8 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
         self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
     ) -> Optional[pa.Array]:
         if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
+            key_result = self._cast_if_needed(map_type.key_field, key_result)
+            value_result = self._cast_if_needed(map_type.value_field, value_result)
             arrow_field = pa.map_(
                 self._construct_field(map_type.key_field, key_result.type),
                 self._construct_field(map_type.value_field, value_result.type),
@@ -1549,9 +1552,16 @@ class StatsAggregator:
 
         expected_physical_type = _primitive_to_physical(iceberg_type)
         if expected_physical_type != physical_type_string:
-            raise ValueError(
-                f"Unexpected physical type {physical_type_string} for 
{iceberg_type}, expected {expected_physical_type}"
-            )
+            # Allow promotable physical types
+            # INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts
+            if (physical_type_string == "INT32" and expected_physical_type == 
"INT64") or (
+                physical_type_string == "FLOAT" and expected_physical_type == 
"DOUBLE"
+            ):
+                pass
+            else:
+                raise ValueError(
+                    f"Unexpected physical type {physical_type_string} for 
{iceberg_type}, expected {expected_physical_type}"
+                )
 
         self.primitive_type = iceberg_type
 
@@ -1896,16 +1906,6 @@ def data_file_statistics_from_parquet_metadata(
             set the mode for column metrics collection
         parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID
     """
-    if parquet_metadata.num_columns != len(stats_columns):
-        raise ValueError(
-            f"Number of columns in statistics configuration 
({len(stats_columns)}) is different from the number of columns in pyarrow table 
({parquet_metadata.num_columns})"
-        )
-
-    if parquet_metadata.num_columns != len(parquet_column_mapping):
-        raise ValueError(
-            f"Number of columns in column mapping 
({len(parquet_column_mapping)}) is different from the number of columns in 
pyarrow table ({parquet_metadata.num_columns})"
-        )
-
     column_sizes: Dict[int, int] = {}
     value_counts: Dict[int, int] = {}
     split_offsets: List[int] = []
@@ -1998,8 +1998,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
     )
 
     def write_parquet(task: WriteTask) -> DataFile:
-        table_schema = task.schema
-
+        table_schema = table_metadata.schema()
         # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
@@ -2011,7 +2010,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
         batches = [
             _to_requested_schema(
                 requested_schema=file_schema,
-                file_schema=table_schema,
+                file_schema=task.schema,
                 batch=batch,
                 downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
                 include_field_ids=True,
@@ -2070,47 +2069,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
     return bin_packed_record_batches
 
 
-def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
+def _check_pyarrow_schema_compatible(
+    requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
+) -> None:
     """
-    Check if the `table_schema` is compatible with `other_schema`.
+    Check if the `requested_schema` is compatible with `provided_schema`.
 
     Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
 
     Raises:
         ValueError: If the schemas are not compatible.
     """
-    name_mapping = table_schema.name_mapping
+    name_mapping = requested_schema.name_mapping
     try:
-        task_schema = pyarrow_to_schema(
-            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        provided_schema = pyarrow_to_schema(
+            provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
     except ValueError as e:
-        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
-        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+        provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+        additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys())
         raise ValueError(
             f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
         ) from e
 
-    if table_schema.as_struct() != task_schema.as_struct():
-        from rich.console import Console
-        from rich.table import Table as RichTable
-
-        console = Console(record=True)
-
-        rich_table = RichTable(show_header=True, header_style="bold")
-        rich_table.add_column("")
-        rich_table.add_column("Table field")
-        rich_table.add_column("Dataframe field")
-
-        for lhs in table_schema.fields:
-            try:
-                rhs = task_schema.find_field(lhs.field_id)
-                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), 
str(rhs))
-            except ValueError:
-                rich_table.add_row("❌", str(lhs), "Missing")
-
-        console.print(rich_table)
-        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+    _check_schema_compatible(requested_schema, provided_schema)
 
 
 def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
@@ -2124,7 +2106,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
                 f"Cannot add file {file_path} because it has field IDs. 
`add_files` only supports addition of files without field_ids"
             )
         schema = table_metadata.schema()
-        _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
+        _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
 
         statistics = data_file_statistics_from_parquet_metadata(
             parquet_metadata=parquet_metadata,
@@ -2205,7 +2187,7 @@ def _dataframe_to_data_files(
     Returns:
         An iterable that supplies datafiles that represent the table.
     """
-    from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
+    from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask
 
     counter = counter or itertools.count(0)
     write_uuid = write_uuid or uuid.uuid4()
@@ -2214,13 +2196,16 @@ def _dataframe_to_data_files(
         property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
         default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
     )
+    name_mapping = table_metadata.schema().name_mapping
+    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+    task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
 
     if table_metadata.spec().is_unpartitioned():
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
             tasks=iter([
-                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
                 for batches in bin_pack_arrow_table(df, target_file_size)
             ]),
         )
@@ -2235,7 +2220,7 @@ def _dataframe_to_data_files(
                     task_id=next(counter),
                     record_batches=batches,
                     partition_key=partition.partition_key,
-                    schema=table_metadata.schema(),
+                    schema=task_schema,
                 )
                 for partition in partitions
                 for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 77f1addb..cfe3fe3a 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
         return read_type
     else:
         raise ResolveError(f"Cannot promote {file_type} to {read_type}")
+
+
+def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
+    """
+    Check if the `provided_schema` is compatible with `requested_schema`.
+
+    Both Schemas must have valid IDs and share the same ID for the same field names.
+
+    Two schemas are considered compatible when:
+    1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
+    2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
+       in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
+    pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))
+
+
+class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
+    provided_schema: Schema
+
+    def __init__(self, provided_schema: Schema):
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        self.provided_schema = provided_schema
+        self.rich_table = RichTable(show_header=True, header_style="bold")
+        self.rich_table.add_column("")
+        self.rich_table.add_column("Table field")
+        self.rich_table.add_column("Dataframe field")
+        self.console = Console(record=True)
+
+    def _is_field_compatible(self, lhs: NestedField) -> bool:
+        # Validate nullability first.
+        # An optional field can be missing in the provided schema
+        # But a required field must exist as a required field
+        try:
+            rhs = self.provided_schema.find_field(lhs.field_id)
+        except ValueError:
+            if lhs.required:
+                self.rich_table.add_row("❌", str(lhs), "Missing")
+                return False
+            else:
+                self.rich_table.add_row("✅", str(lhs), "Missing")
+                return True
+
+        if lhs.required and not rhs.required:
+            self.rich_table.add_row("❌", str(lhs), str(rhs))
+            return False
+
+        # Check type compatibility
+        if lhs.field_type == rhs.field_type:
+            self.rich_table.add_row("✅", str(lhs), str(rhs))
+            return True
+        # We only check that the parent node is also of the same type.
+        # We check the type of the child nodes when we traverse them later.
+        elif any(
+            (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
+            for container_type in {StructType, MapType, ListType}
+        ):
+            self.rich_table.add_row("✅", str(lhs), str(rhs))
+            return True
+        else:
+            try:
+                # If type can be promoted to the requested schema
+                # it is considered compatible
+                promote(rhs.field_type, lhs.field_type)
+                self.rich_table.add_row("✅", str(lhs), str(rhs))
+                return True
+            except ResolveError:
+                self.rich_table.add_row("❌", str(lhs), str(rhs))
+                return False
+
+    def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
+        if not (result := struct_result()):
+            self.console.print(self.rich_table)
+            raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
+        return result
+
+    def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
+        results = [result() for result in field_results]
+        return all(results)
+
+    def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
+        return self._is_field_compatible(field) and field_result()
+
+    def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
+        return self._is_field_compatible(list_type.element_field) and element_result()
+
+    def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
+        return all([
+            self._is_field_compatible(map_type.key_field),
+            self._is_field_compatible(map_type.value_field),
+            key_result(),
+            value_result(),
+        ])
+
+    def primitive(self, primitive: PrimitiveType) -> bool:
+        return True
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index b43dc320..0b211e67 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -73,7 +73,6 @@ from pyiceberg.expressions.visitors import (
     manifest_evaluator,
 )
 from pyiceberg.io import FileIO, OutputFile, load_file_io
-from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
 from pyiceberg.manifest import (
     POSITIONAL_DELETE_SCHEMA,
     DataFile,
@@ -471,6 +470,8 @@ class Transaction:
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For writes PyArrow needs to be 
installed") from e
 
+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
@@ -481,8 +482,8 @@ class Transaction:
                 f"Not all partition types are supported for writes. Following 
partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-        _check_schema_compatible(
-            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        _check_pyarrow_schema_compatible(
+            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
 
         manifest_merge_enabled = PropertyUtil.property_as_bool(
@@ -528,6 +529,8 @@ class Transaction:
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For writes PyArrow needs to be 
installed") from e
 
+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
@@ -538,8 +541,8 @@ class Transaction:
                 f"Not all partition types are supported for writes. Following 
partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-        _check_schema_compatible(
-            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        _check_pyarrow_schema_compatible(
+            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )
 
         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
@@ -566,6 +569,8 @@ class Transaction:
             delete_filter: A boolean expression to delete rows from a table
             snapshot_properties: Custom properties to be added to the snapshot summary
         """
+        from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
+
         if (
             self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
             == TableProperties.DELETE_MODE_MERGE_ON_READ
diff --git a/tests/conftest.py b/tests/conftest.py
index 91ab8f2e..7f9a2bcf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2506,3 +2506,62 @@ def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
         NestedField(field_id=10, name="timestamptz_ns_z", 
field_type=TimestamptzType(), required=False),
         NestedField(field_id=11, name="timestamptz_s_0000", 
field_type=TimestamptzType(), required=False),
     )
+
+
[email protected](scope="session")
+def table_schema_with_promoted_types() -> Schema:
+    """Iceberg table Schema with longs, doubles and uuid in simple and nested 
types."""
+    return Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), 
required=False),
+        NestedField(
+            field_id=2,
+            name="list",
+            field_type=ListType(element_id=4, element_type=LongType(), element_required=False),
+            required=True,
+        ),
+        NestedField(
+            field_id=3,
+            name="map",
+            field_type=MapType(
+                key_id=5,
+                key_type=StringType(),
+                value_id=6,
+                value_type=LongType(),
+                value_required=False,
+            ),
+            required=True,
+        ),
+        NestedField(field_id=7, name="double", field_type=DoubleType(), 
required=False),
+        NestedField(field_id=8, name="uuid", field_type=UUIDType(), 
required=False),
+    )
+
+
[email protected](scope="session")
+def pyarrow_schema_with_promoted_types() -> "pa.Schema":
+    """Pyarrow Schema with longs, doubles and uuid in simple and nested 
types."""
+    import pyarrow as pa
+
+    return pa.schema((
+        pa.field("long", pa.int32(), nullable=True),  # can support upcasting 
integer to long
+        pa.field("list", pa.list_(pa.int32()), nullable=False),  # can support 
upcasting integer to long
+        pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False),  # 
can support upcasting integer to long
+        pa.field("double", pa.float32(), nullable=True),  # can support 
upcasting float to double
+        pa.field("uuid", pa.binary(length=16), nullable=True),  # can support 
upcasting float to double
+    ))
+
+
[email protected](scope="session")
+def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Schema") -> "pa.Table":
+    """Pyarrow table with longs, doubles and uuid in simple and nested types."""
+    import pyarrow as pa
+
+    return pa.Table.from_pydict(
+        {
+            "long": [1, 9],
+            "list": [[1, 1], [2, 2]],
+            "map": [{"a": 1}, {"b": 2}],
+            "double": [1.1, 9.2],
+            "uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", 
b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
+        },
+        schema=pyarrow_schema_with_promoted_types,
+    )
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index b8fd6d09..3703a9e0 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -30,6 +30,7 @@ from pytest_mock.plugin import MockerFixture
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.io import FileIO
+from pyiceberg.io.pyarrow import _pyarrow_schema_ensure_large_types
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
@@ -38,6 +39,7 @@ from pyiceberg.types import (
     BooleanType,
     DateType,
     IntegerType,
+    LongType,
     NestedField,
     StringType,
     TimestamptzType,
@@ -505,7 +507,7 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog
 ┃    ┃ Table field              ┃ Dataframe field          ┃
 ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
 │ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
-| ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
+│ ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
 │ ❌ │ 3: baz: optional int     │ 3: baz: optional string  │
 │ ✅ │ 4: qux: optional date    │ 4: qux: optional date    │
 └────┴──────────────────────────┴──────────────────────────┘
@@ -589,18 +591,7 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v
     mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})
 
     identifier = f"default.timestamptz_ns_added{format_version}"
-
-    try:
-        session_catalog.drop_table(identifier=identifier)
-    except NoSuchTableError:
-        pass
-
-    tbl = session_catalog.create_table(
-        identifier=identifier,
-        schema=nanoseconds_schema_iceberg,
-        properties={"format-version": str(format_version)},
-        partition_spec=PartitionSpec(),
-    )
+    tbl = _create_table(session_catalog, identifier, format_version, schema=nanoseconds_schema_iceberg)
 
     file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
     # write parquet files
@@ -617,3 +608,127 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v
         ),
     ):
         tbl.add_files(file_paths=[file_path])
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_add_file_with_valid_nullability_diff(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.test_table_with_valid_nullability_diff{format_version}"
+    table_schema = Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), required=False),
+    )
+    other_schema = pa.schema((
+        pa.field("long", pa.int64(), nullable=False),  # can support writing 
required pyarrow field to optional Iceberg field
+    ))
+    arrow_table = pa.Table.from_pydict(
+        {
+            "long": [1, 9],
+        },
+        schema=other_schema,
+    )
+    tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema)
+
+    file_path = f"s3://warehouse/default/test_add_file_with_valid_nullability_diff/v{format_version}/test.parquet"
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=other_schema) as writer:
+            writer.write_table(arrow_table)
+
+    tbl.add_files(file_paths=[file_path])
+    # table's long field should cast to be optional on read
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),)))
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_add_files_with_valid_upcast(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    format_version: int,
+    table_schema_with_promoted_types: Schema,
+    pyarrow_schema_with_promoted_types: pa.Schema,
+    pyarrow_table_with_promoted_types: pa.Table,
+) -> None:
+    identifier = f"default.test_table_with_valid_upcast{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_promoted_types)
+
+    file_path = f"s3://warehouse/default/test_add_files_with_valid_upcast/v{format_version}/test.parquet"
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
+            writer.write_table(pyarrow_table_with_promoted_types)
+
+    tbl.add_files(file_paths=[file_path])
+    # table's long field should cast to long on read
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
+        pa.schema((
+            pa.field("long", pa.int64(), nullable=True),
+            pa.field("list", pa.large_list(pa.int64()), nullable=False),
+            pa.field("map", pa.map_(pa.large_string(), pa.int64()), 
nullable=False),
+            pa.field("double", pa.float64(), nullable=True),
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # can UUID 
is read as fixed length binary of length 16
+        ))
+    )
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            if column == "map":
+                # Arrow returns a list of tuples, instead of a dict
+                right = dict(right)
+            if column == "list":
+                # Arrow returns an array, convert to list for equality check
+                left, right = list(left), list(right)
+            if column == "uuid":
+                # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+                # whereas PyIceberg represents UUID as bytes on read
+                left, right = left.replace("-", ""), right.hex()
+            assert left == right
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_add_files_subset_of_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.test_table_subset_of_schema{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_path = f"s3://warehouse/default/test_add_files_subset_of_schema/v{format_version}/test.parquet"
+    arrow_table_without_some_columns = ARROW_TABLE.combine_chunks().drop(ARROW_TABLE.column_names[0])
+
+    # write parquet files
+    fo = tbl.io.new_output(file_path)
+    with fo.create(overwrite=True) as fos:
+        with pq.ParquetWriter(fos, schema=arrow_table_without_some_columns.schema) as writer:
+            writer.write_table(arrow_table_without_some_columns)
+
+    tbl.add_files(file_paths=[file_path])
+    written_arrow_table = tbl.scan().to_arrow()
+    assert tbl.scan().to_arrow() == pa.Table.from_pylist(
+        [
+            {
+                "foo": None,  # Missing column is read as None on read
+                "bar": "bar_string",
+                "baz": 123,
+                "qux": date(2024, 3, 7),
+            }
+        ],
+        schema=_pyarrow_schema_ensure_large_types(ARROW_SCHEMA),
+    )
+
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 41bc6fb5..09fe654d 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -43,7 +43,7 @@ from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import TableProperties
 from pyiceberg.transforms import IdentityTransform
-from pyiceberg.types import IntegerType, NestedField
+from pyiceberg.types import IntegerType, LongType, NestedField
 from utils import _create_table
 
 
@@ -964,9 +964,10 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
     assert len(tbl.scan().to_arrow()) == 22
 
 
[email protected]
 @pytest.mark.parametrize("format_version", [1, 2])
-def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
-    identifier = "default.table_append_subset_of_schema"
+def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_subset_of_schema"
     tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null])
     arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
     assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns)
@@ -976,6 +977,101 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
     assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2
 
 
[email protected]
[email protected]("format_version", [1, 2])
+def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+    identifier = "default.test_table_write_out_of_order_schema"
+    # rotate the schema fields by 1
+    fields = list(arrow_table_with_null.schema)
+    rotated_fields = fields[1:] + fields[:1]
+    rotated_schema = pa.schema(rotated_fields)
+    assert arrow_table_with_null.schema != rotated_schema
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)
+
+    tbl.overwrite(arrow_table_with_null)
+    tbl.append(arrow_table_with_null)
+    # overwrite and then append should produce twice the data
+    assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_table_write_schema_with_valid_nullability_diff(
+    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
+) -> None:
+    identifier = "default.test_table_write_with_valid_nullability_diff"
+    table_schema = Schema(
+        NestedField(field_id=1, name="long", field_type=LongType(), 
required=False),
+    )
+    other_schema = pa.schema((
+        pa.field("long", pa.int64(), nullable=False),  # can support writing 
required pyarrow field to optional Iceberg field
+    ))
+    arrow_table = pa.Table.from_pydict(
+        {
+            "long": [1, 9],
+        },
+        schema=other_schema,
+    )
+    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema)
+    # table's long field should cast to be optional on read
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),)))
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            assert left == right
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_table_write_schema_with_valid_upcast(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    format_version: int,
+    table_schema_with_promoted_types: Schema,
+    pyarrow_schema_with_promoted_types: pa.Schema,
+    pyarrow_table_with_promoted_types: pa.Table,
+) -> None:
+    identifier = "default.test_table_write_with_valid_upcast"
+
+    tbl = _create_table(
+        session_catalog,
+        identifier,
+        {"format-version": format_version},
+        [pyarrow_table_with_promoted_types],
+        schema=table_schema_with_promoted_types,
+    )
+    # table's long field should cast to long on read
+    written_arrow_table = tbl.scan().to_arrow()
+    assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
+        pa.schema((
+            pa.field("long", pa.int64(), nullable=True),
+            pa.field("list", pa.large_list(pa.int64()), nullable=False),
+            pa.field("map", pa.map_(pa.large_string(), pa.int64()), 
nullable=False),
+            pa.field("double", pa.float64(), nullable=True),  # can support 
upcasting float to double
+            pa.field("uuid", pa.binary(length=16), nullable=True),  # can UUID 
is read as fixed length binary of length 16
+        ))
+    )
+    lhs = spark.table(f"{identifier}").toPandas()
+    rhs = written_arrow_table.to_pandas()
+
+    for column in written_arrow_table.column_names:
+        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+            if column == "map":
+                # Arrow returns a list of tuples, instead of a dict
+                right = dict(right)
+            if column == "list":
+                # Arrow returns an array, convert to list for equality check
+                left, right = list(left), list(right)
+            if column == "uuid":
+                # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+                # whereas PyIceberg represents UUID as bytes on read
+                left, right = left.replace("-", ""), right.hex()
+            assert left == right
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
 def test_write_all_timestamp_precision(
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 37198b7e..d61a50bb 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -60,7 +60,7 @@ from pyiceberg.io.pyarrow import (
     PyArrowFile,
     PyArrowFileIO,
     StatsAggregator,
-    _check_schema_compatible,
+    _check_pyarrow_schema_compatible,
     _ConvertToArrowSchema,
     _determine_partitions,
     _primitive_to_physical,
@@ -1742,7 +1742,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1763,7 +1763,20 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=False),
+    ))
+
+    try:
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling 
`_check_pyarrow_schema_compatible`")
 
 
 def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1783,21 +1796,114 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None:
+    schema = table_schema_nested.as_arrow()
+    schema = schema.remove(6).insert(
+        6,
+        pa.field(
+            "person",
+            pa.struct([
+                pa.field("age", pa.int32(), nullable=False),
+            ]),
+            nullable=True,
+        ),
+    )
+    try:
+        _check_pyarrow_schema_compatible(table_schema_nested, schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling 
`_check_pyarrow_schema_compatible`")
+
+
+def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None:
+    other_schema = table_schema_nested.as_arrow()
+    other_schema = other_schema.remove(6).insert(
+        6,
+        pa.field(
+            "person",
+            pa.struct([
+                pa.field("name", pa.string(), nullable=True),
+            ]),
+            nullable=True,
+        ),
+    )
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field                        ┃ Dataframe field                    ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string            │ 1: foo: optional string            │
+│ ✅ │ 2: bar: required int               │ 2: bar: required int               │
+│ ✅ │ 3: baz: optional boolean           │ 3: baz: optional boolean           │
+│ ✅ │ 4: qux: required list<string>      │ 4: qux: required list<string>      │
+│ ✅ │ 5: element: required string        │ 5: element: required string        │
+│ ✅ │ 6: quux: required map<string,      │ 6: quux: required map<string,      │
+│    │ map<string, int>>                  │ map<string, int>>                  
│
+│ ✅ │ 7: key: required string            │ 7: key: required string            │
+│ ✅ │ 8: value: required map<string,     │ 8: value: required map<string,     │
+│    │ int>                               │ int>                               
│
+│ ✅ │ 9: key: required string            │ 9: key: required string            │
+│ ✅ │ 10: value: required int            │ 10: value: required int            │
+│ ✅ │ 11: location: required             │ 11: location: required             │
+│    │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional 
│
+│    │ float, 14: longitude: optional     │ float, 14: longitude: optional     
│
+│    │ float>>                            │ float>>                            
│
+│ ✅ │ 12: element: required struct<13:   │ 12: element: required struct<13:   │
+│    │ latitude: optional float, 14:      │ latitude: optional float, 14:      
│
+│    │ longitude: optional float>         │ longitude: optional float>         
│
+│ ✅ │ 13: latitude: optional float       │ 13: latitude: optional float       │
+│ ✅ │ 14: longitude: optional float      │ 14: longitude: optional float      │
+│ ✅ │ 15: person: optional struct<16:    │ 15: person: optional struct<16:    │
+│    │ name: optional string, 17: age:    │ name: optional string>             │
+│    │ required int>                      │                                    │
+│ ✅ │ 16: name: optional string          │ 16: name: optional string          │
+│ ❌ │ 17: age: required int              │ Missing                            │
+└────┴────────────────────────────────────┴────────────────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_pyarrow_schema_compatible(table_schema_nested, other_schema)
+
+
+def test_schema_compatible_nested(table_schema_nested: Schema) -> None:
+    try:
+        _check_pyarrow_schema_compatible(table_schema_nested, table_schema_nested.as_arrow())
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling 
`_check_pyarrow_schema_compatible`")
 
 
 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     other_schema = pa.schema((
         pa.field("foo", pa.string(), nullable=True),
-        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
         pa.field("baz", pa.bool_(), nullable=True),
         pa.field("new_field", pa.date32(), nullable=True),
     ))
 
-    expected = r"PyArrow table contains more columns: new_field. Update the 
schema first \(hint, use union_by_name\)."
+    with pytest.raises(
+        ValueError, match=r"PyArrow table contains more columns: new_field. 
Update the schema first \(hint, use union_by_name\)."
+    ):
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
 
-    with pytest.raises(ValueError, match=expected):
-        _check_schema_compatible(table_schema_simple, other_schema)
+
+def test_schema_compatible(table_schema_simple: Schema) -> None:
+    try:
+        _check_pyarrow_schema_compatible(table_schema_simple, table_schema_simple.as_arrow())
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling 
`_check_pyarrow_schema_compatible`")
+
+
+def test_schema_projection(table_schema_simple: Schema) -> None:
+    # remove optional `baz` field from `table_schema_simple`
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+    ))
+    try:
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling 
`_check_pyarrow_schema_compatible`")
 
 
 def test_schema_downcast(table_schema_simple: Schema) -> None:
@@ -1809,9 +1915,9 @@ def test_schema_downcast(table_schema_simple: Schema) -> None:
     ))
 
     try:
-        _check_schema_compatible(table_schema_simple, other_schema)
+        _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
     except Exception:
-        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
+        pytest.fail("Unexpected Exception raised when calling 
`_check_pyarrow_schema_compatible`")
 
 
 def test_partition_for_demo() -> None:
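
Below is a minimal usage sketch of the write behavior enabled by this commit; it is illustrative only (the catalog name and table identifier are placeholders, and the target table is assumed to already exist with optional "id" long and "score" double columns):

    import pyarrow as pa

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")  # placeholder catalog name
    tbl = catalog.load_table("default.events")  # placeholder table identifier

    # The PyArrow table may cover only a subset of the Iceberg schema, list its
    # columns in a different order than the table schema, and use narrower
    # (promotable) Arrow types: int32 is promoted to long and float to double
    # on write.
    df = pa.Table.from_pydict(
        {
            "score": pa.array([0.5, 0.9], type=pa.float32()),
            "id": pa.array([1, 2], type=pa.int32()),
        }
    )
    tbl.append(df)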

