This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 1ed3abdd Allow writing `pa.Table` that are either a subset of table schema or in arbitrary order, and support type promotion on write (#921)
1ed3abdd is described below
commit 1ed3abdd1aec480911eeec4f0f46a04efe53dc06
Author: Sung Yun <[email protected]>
AuthorDate: Wed Jul 17 02:04:52 2024 -0400
Allow writing `pa.Table` that are either a subset of table schema or in arbitrary order, and support type promotion on write (#921)
* merge
* thanks @HonahX :)
Co-authored-by: Honah J. <[email protected]>
* support promote
* revert promote
* use a visitor
* support promotion on write
* fix
* Thank you @Fokko !
Co-authored-by: Fokko Driesprong <[email protected]>
* revert
* add-files promotiontest
* support promote for add_files
* add tests for uuid
* add_files subset schema test
---------
Co-authored-by: Honah J. <[email protected]>
Co-authored-by: Fokko Driesprong <[email protected]>
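
    For illustration, a minimal sketch of the behaviour this commit enables when
    appending with PyArrow (the catalog name, table identifier, and column names
    below are hypothetical; it assumes an existing Iceberg table with fields
    `id: long (required)`, `name: string (optional)`, and `value: double (optional)`):

        import pyarrow as pa
        from pyiceberg.catalog import load_catalog

        catalog = load_catalog("default")           # hypothetical catalog configuration
        tbl = catalog.load_table("default.events")  # hypothetical table

        # Columns are a subset of the table schema, given out of order, and use
        # promotable Arrow types (int32 -> long, float32 -> double); the omitted
        # optional column `name` is written as null.
        df = pa.table({
            "value": pa.array([1.0, 2.5], type=pa.float32()),
            "id": pa.array([1, 2], type=pa.int32()),
        })

        tbl.append(df)  # previously rejected as a schema mismatch, accepted after this change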
---
pyiceberg/io/pyarrow.py | 81 +++++++--------
pyiceberg/schema.py | 100 +++++++++++++++++++
pyiceberg/table/__init__.py | 15 ++-
tests/conftest.py | 59 +++++++++++
tests/integration/test_add_files.py | 141 ++++++++++++++++++++++++---
tests/integration/test_writes/test_writes.py | 102 ++++++++++++++++++-
tests/io/test_pyarrow.py | 126 ++++++++++++++++++++++--
7 files changed, 545 insertions(+), 79 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 199133f7..cd6736fb 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -120,6 +120,7 @@ from pyiceberg.schema import (
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
+ _check_schema_compatible,
pre_order_visit,
promote,
prune_columns,
@@ -1407,7 +1408,7 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
-
+ value_array = self._cast_if_needed(list_type.element_field, value_array)
arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
@@ -1417,6 +1418,8 @@ class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Arra
self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
) -> Optional[pa.Array]:
if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
+ key_result = self._cast_if_needed(map_type.key_field, key_result)
+ value_result = self._cast_if_needed(map_type.value_field, value_result)
arrow_field = pa.map_(
self._construct_field(map_type.key_field, key_result.type),
self._construct_field(map_type.value_field, value_result.type),
@@ -1549,9 +1552,16 @@ class StatsAggregator:
expected_physical_type = _primitive_to_physical(iceberg_type)
if expected_physical_type != physical_type_string:
- raise ValueError(
- f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
- )
+ # Allow promotable physical types
+ # INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts
+ if (physical_type_string == "INT32" and expected_physical_type == "INT64") or (
+ physical_type_string == "FLOAT" and expected_physical_type == "DOUBLE"
+ ):
+ pass
+ else:
+ raise ValueError(
+ f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
+ )
self.primitive_type = iceberg_type
@@ -1896,16 +1906,6 @@ def data_file_statistics_from_parquet_metadata(
set the mode for column metrics collection
parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID
"""
- if parquet_metadata.num_columns != len(stats_columns):
- raise ValueError(
- f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
- )
-
- if parquet_metadata.num_columns != len(parquet_column_mapping):
- raise ValueError(
- f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
- )
-
column_sizes: Dict[int, int] = {}
value_counts: Dict[int, int] = {}
split_offsets: List[int] = []
@@ -1998,8 +1998,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)
def write_parquet(task: WriteTask) -> DataFile:
- table_schema = task.schema
-
+ table_schema = table_metadata.schema()
# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
@@ -2011,7 +2010,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
batches = [
_to_requested_schema(
requested_schema=file_schema,
- file_schema=table_schema,
+ file_schema=task.schema,
batch=batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
include_field_ids=True,
@@ -2070,47 +2069,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
return bin_packed_record_batches
-def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
+def _check_pyarrow_schema_compatible(
+ requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
+) -> None:
"""
- Check if the `table_schema` is compatible with `other_schema`.
+ Check if the `requested_schema` is compatible with `provided_schema`.
Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
Raises:
ValueError: If the schemas are not compatible.
"""
- name_mapping = table_schema.name_mapping
+ name_mapping = requested_schema.name_mapping
try:
- task_schema = pyarrow_to_schema(
- other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+ provided_schema = pyarrow_to_schema(
+ provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
- other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
- additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+ provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+ additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys())
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e
- if table_schema.as_struct() != task_schema.as_struct():
- from rich.console import Console
- from rich.table import Table as RichTable
-
- console = Console(record=True)
-
- rich_table = RichTable(show_header=True, header_style="bold")
- rich_table.add_column("")
- rich_table.add_column("Table field")
- rich_table.add_column("Dataframe field")
-
- for lhs in table_schema.fields:
- try:
- rhs = task_schema.find_field(lhs.field_id)
- rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
- except ValueError:
- rich_table.add_row("❌", str(lhs), "Missing")
-
- console.print(rich_table)
- raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+ _check_schema_compatible(requested_schema, provided_schema)
def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
@@ -2124,7 +2106,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
- _check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
+ _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
@@ -2205,7 +2187,7 @@ def _dataframe_to_data_files(
Returns:
An iterable that supplies datafiles that represent the table.
"""
- from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
+ from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask
counter = counter or itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()
@@ -2214,13 +2196,16 @@ def _dataframe_to_data_files(
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
+ name_mapping = table_metadata.schema().name_mapping
+ downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+ task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
if table_metadata.spec().is_unpartitioned():
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
- WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
+ WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)
@@ -2235,7 +2220,7 @@ def _dataframe_to_data_files(
task_id=next(counter),
record_batches=batches,
partition_key=partition.partition_key,
- schema=table_metadata.schema(),
+ schema=task_schema,
)
for partition in partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 77f1addb..cfe3fe3a 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
return read_type
else:
raise ResolveError(f"Cannot promote {file_type} to {read_type}")
+
+
+def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
+ """
+ Check if the `provided_schema` is compatible with `requested_schema`.
+
+ Both Schemas must have valid IDs and share the same ID for the same field names.
+
+ Two schemas are considered compatible when:
+ 1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
+ 2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
+ in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`
+
+ Raises:
+ ValueError: If the schemas are not compatible.
+ """
+ pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))
+
+
+class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
+ provided_schema: Schema
+
+ def __init__(self, provided_schema: Schema):
+ from rich.console import Console
+ from rich.table import Table as RichTable
+
+ self.provided_schema = provided_schema
+ self.rich_table = RichTable(show_header=True, header_style="bold")
+ self.rich_table.add_column("")
+ self.rich_table.add_column("Table field")
+ self.rich_table.add_column("Dataframe field")
+ self.console = Console(record=True)
+
+ def _is_field_compatible(self, lhs: NestedField) -> bool:
+ # Validate nullability first.
+ # An optional field can be missing in the provided schema
+ # But a required field must exist as a required field
+ try:
+ rhs = self.provided_schema.find_field(lhs.field_id)
+ except ValueError:
+ if lhs.required:
+ self.rich_table.add_row("❌", str(lhs), "Missing")
+ return False
+ else:
+ self.rich_table.add_row("✅", str(lhs), "Missing")
+ return True
+
+ if lhs.required and not rhs.required:
+ self.rich_table.add_row("❌", str(lhs), str(rhs))
+ return False
+
+ # Check type compatibility
+ if lhs.field_type == rhs.field_type:
+ self.rich_table.add_row("✅", str(lhs), str(rhs))
+ return True
+ # We only check that the parent node is also of the same type.
+ # We check the type of the child nodes when we traverse them later.
+ elif any(
+ (isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
+ for container_type in {StructType, MapType, ListType}
+ ):
+ self.rich_table.add_row("✅", str(lhs), str(rhs))
+ return True
+ else:
+ try:
+ # If type can be promoted to the requested schema
+ # it is considered compatible
+ promote(rhs.field_type, lhs.field_type)
+ self.rich_table.add_row("✅", str(lhs), str(rhs))
+ return True
+ except ResolveError:
+ self.rich_table.add_row("❌", str(lhs), str(rhs))
+ return False
+
+ def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
+ if not (result := struct_result()):
+ self.console.print(self.rich_table)
+ raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
+ return result
+
+ def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
+ results = [result() for result in field_results]
+ return all(results)
+
+ def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
+ return self._is_field_compatible(field) and field_result()
+
+ def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
+ return self._is_field_compatible(list_type.element_field) and element_result()
+
+ def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
+ return all([
+ self._is_field_compatible(map_type.key_field),
+ self._is_field_compatible(map_type.value_field),
+ key_result(),
+ value_result(),
+ ])
+
+ def primitive(self, primitive: PrimitiveType) -> bool:
+ return True
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index b43dc320..0b211e67 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -73,7 +73,6 @@ from pyiceberg.expressions.visitors import (
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
-from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -471,6 +470,8 @@ class Transaction:
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+ from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")
@@ -481,8 +482,8 @@ class Transaction:
f"Not all partition types are supported for writes. Following
partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us =
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
- _check_schema_compatible(
- self._table.schema(), other_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+ _check_pyarrow_schema_compatible(
+ self._table.schema(), provided_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
manifest_merge_enabled = PropertyUtil.property_as_bool(
@@ -528,6 +529,8 @@ class Transaction:
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+ from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files
+
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")
@@ -538,8 +541,8 @@ class Transaction:
f"Not all partition types are supported for writes. Following
partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us =
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
- _check_schema_compatible(
- self._table.schema(), other_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+ _check_pyarrow_schema_compatible(
+ self._table.schema(), provided_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
@@ -566,6 +569,8 @@ class Transaction:
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
+ from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
+
if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
== TableProperties.DELETE_MODE_MERGE_ON_READ
diff --git a/tests/conftest.py b/tests/conftest.py
index 91ab8f2e..7f9a2bcf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2506,3 +2506,62 @@ def table_schema_with_all_microseconds_timestamp_precision() -> Schema:
NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False),
NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False),
)
+
+
[email protected](scope="session")
+def table_schema_with_promoted_types() -> Schema:
+ """Iceberg table Schema with longs, doubles and uuid in simple and nested
types."""
+ return Schema(
+ NestedField(field_id=1, name="long", field_type=LongType(),
required=False),
+ NestedField(
+ field_id=2,
+ name="list",
+ field_type=ListType(element_id=4, element_type=LongType(), element_required=False),
+ required=True,
+ ),
+ NestedField(
+ field_id=3,
+ name="map",
+ field_type=MapType(
+ key_id=5,
+ key_type=StringType(),
+ value_id=6,
+ value_type=LongType(),
+ value_required=False,
+ ),
+ required=True,
+ ),
+ NestedField(field_id=7, name="double", field_type=DoubleType(), required=False),
+ NestedField(field_id=8, name="uuid", field_type=UUIDType(), required=False),
+ )
+
+
[email protected](scope="session")
+def pyarrow_schema_with_promoted_types() -> "pa.Schema":
+ """Pyarrow Schema with longs, doubles and uuid in simple and nested
types."""
+ import pyarrow as pa
+
+ return pa.schema((
+ pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long
+ pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long
+ pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long
+ pa.field("double", pa.float32(), nullable=True), # can support upcasting float to double
+ pa.field("uuid", pa.binary(length=16), nullable=True), # can support upcasting float to double
+ ))
+
+
[email protected](scope="session")
+def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types:
"pa.Schema") -> "pa.Table":
+ """Pyarrow table with longs, doubles and uuid in simple and nested
types."""
+ import pyarrow as pa
+
+ return pa.Table.from_pydict(
+ {
+ "long": [1, 9],
+ "list": [[1, 1], [2, 2]],
+ "map": [{"a": 1}, {"b": 2}],
+ "double": [1.1, 9.2],
+ "uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
+ },
+ schema=pyarrow_schema_with_promoted_types,
+ )
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index b8fd6d09..3703a9e0 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -30,6 +30,7 @@ from pytest_mock.plugin import MockerFixture
from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.io import FileIO
+from pyiceberg.io.pyarrow import _pyarrow_schema_ensure_large_types
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
@@ -38,6 +39,7 @@ from pyiceberg.types import (
BooleanType,
DateType,
IntegerType,
+ LongType,
NestedField,
StringType,
TimestamptzType,
@@ -505,7 +507,7 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
-| ✅ │ 2: bar: optional string │ 2: bar: optional string │
+│ ✅ │ 2: bar: optional string │ 2: bar: optional string │
│ ❌ │ 3: baz: optional int │ 3: baz: optional string │
│ ✅ │ 4: qux: optional date │ 4: qux: optional date │
└────┴──────────────────────────┴──────────────────────────┘
@@ -589,18 +591,7 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v
mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"})
identifier = f"default.timestamptz_ns_added{format_version}"
-
- try:
- session_catalog.drop_table(identifier=identifier)
- except NoSuchTableError:
- pass
-
- tbl = session_catalog.create_table(
- identifier=identifier,
- schema=nanoseconds_schema_iceberg,
- properties={"format-version": str(format_version)},
- partition_spec=PartitionSpec(),
- )
+ tbl = _create_table(session_catalog, identifier, format_version, schema=nanoseconds_schema_iceberg)
file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
# write parquet files
@@ -617,3 +608,127 @@ def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_v
),
):
tbl.add_files(file_paths=[file_path])
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_add_file_with_valid_nullability_diff(spark: SparkSession,
session_catalog: Catalog, format_version: int) -> None:
+ identifier =
f"default.test_table_with_valid_nullability_diff{format_version}"
+ table_schema = Schema(
+ NestedField(field_id=1, name="long", field_type=LongType(),
required=False),
+ )
+ other_schema = pa.schema((
+ pa.field("long", pa.int64(), nullable=False), # can support writing
required pyarrow field to optional Iceberg field
+ ))
+ arrow_table = pa.Table.from_pydict(
+ {
+ "long": [1, 9],
+ },
+ schema=other_schema,
+ )
+ tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema)
+
+ file_path = f"s3://warehouse/default/test_add_file_with_valid_nullability_diff/v{format_version}/test.parquet"
+ # write parquet files
+ fo = tbl.io.new_output(file_path)
+ with fo.create(overwrite=True) as fos:
+ with pq.ParquetWriter(fos, schema=other_schema) as writer:
+ writer.write_table(arrow_table)
+
+ tbl.add_files(file_paths=[file_path])
+ # table's long field should cast to be optional on read
+ written_arrow_table = tbl.scan().to_arrow()
+ assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),)))
+ lhs = spark.table(f"{identifier}").toPandas()
+ rhs = written_arrow_table.to_pandas()
+
+ for column in written_arrow_table.column_names:
+ for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+ assert left == right
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_add_files_with_valid_upcast(
+ spark: SparkSession,
+ session_catalog: Catalog,
+ format_version: int,
+ table_schema_with_promoted_types: Schema,
+ pyarrow_schema_with_promoted_types: pa.Schema,
+ pyarrow_table_with_promoted_types: pa.Table,
+) -> None:
+ identifier = f"default.test_table_with_valid_upcast{format_version}"
+ tbl = _create_table(session_catalog, identifier, format_version, schema=table_schema_with_promoted_types)
+
+ file_path = f"s3://warehouse/default/test_add_files_with_valid_upcast/v{format_version}/test.parquet"
+ # write parquet files
+ fo = tbl.io.new_output(file_path)
+ with fo.create(overwrite=True) as fos:
+ with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
+ writer.write_table(pyarrow_table_with_promoted_types)
+
+ tbl.add_files(file_paths=[file_path])
+ # table's long field should cast to long on read
+ written_arrow_table = tbl.scan().to_arrow()
+ assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
+ pa.schema((
+ pa.field("long", pa.int64(), nullable=True),
+ pa.field("list", pa.large_list(pa.int64()), nullable=False),
+ pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
+ pa.field("double", pa.float64(), nullable=True),
+ pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16
+ ))
+ )
+ lhs = spark.table(f"{identifier}").toPandas()
+ rhs = written_arrow_table.to_pandas()
+
+ for column in written_arrow_table.column_names:
+ for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+ if column == "map":
+ # Arrow returns a list of tuples, instead of a dict
+ right = dict(right)
+ if column == "list":
+ # Arrow returns an array, convert to list for equality check
+ left, right = list(left), list(right)
+ if column == "uuid":
+ # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+ # whereas PyIceberg represents UUID as bytes on read
+ left, right = left.replace("-", ""), right.hex()
+ assert left == right
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_add_files_subset_of_schema(spark: SparkSession, session_catalog:
Catalog, format_version: int) -> None:
+ identifier = f"default.test_table_subset_of_schema{format_version}"
+ tbl = _create_table(session_catalog, identifier, format_version)
+
+ file_path = f"s3://warehouse/default/test_add_files_subset_of_schema/v{format_version}/test.parquet"
+ arrow_table_without_some_columns = ARROW_TABLE.combine_chunks().drop(ARROW_TABLE.column_names[0])
+
+ # write parquet files
+ fo = tbl.io.new_output(file_path)
+ with fo.create(overwrite=True) as fos:
+ with pq.ParquetWriter(fos, schema=arrow_table_without_some_columns.schema) as writer:
+ writer.write_table(arrow_table_without_some_columns)
+
+ tbl.add_files(file_paths=[file_path])
+ written_arrow_table = tbl.scan().to_arrow()
+ assert tbl.scan().to_arrow() == pa.Table.from_pylist(
+ [
+ {
+ "foo": None, # Missing column is read as None on read
+ "bar": "bar_string",
+ "baz": 123,
+ "qux": date(2024, 3, 7),
+ }
+ ],
+ schema=_pyarrow_schema_ensure_large_types(ARROW_SCHEMA),
+ )
+
+ lhs = spark.table(f"{identifier}").toPandas()
+ rhs = written_arrow_table.to_pandas()
+
+ for column in written_arrow_table.column_names:
+ for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+ assert left == right
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 41bc6fb5..09fe654d 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -43,7 +43,7 @@ from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import TableProperties
from pyiceberg.transforms import IdentityTransform
-from pyiceberg.types import IntegerType, NestedField
+from pyiceberg.types import IntegerType, LongType, NestedField
from utils import _create_table
@@ -964,9 +964,10 @@ def test_sanitize_character_partitioned(catalog: Catalog) -> None:
assert len(tbl.scan().to_arrow()) == 22
[email protected]
@pytest.mark.parametrize("format_version", [1, 2])
-def table_write_subset_of_schema(session_catalog: Catalog,
arrow_table_with_null: pa.Table, format_version: int) -> None:
- identifier = "default.table_append_subset_of_schema"
+def test_table_write_subset_of_schema(session_catalog: Catalog,
arrow_table_with_null: pa.Table, format_version: int) -> None:
+ identifier = "default.test_table_write_subset_of_schema"
tbl = _create_table(session_catalog, identifier, {"format-version":
format_version}, [arrow_table_with_null])
arrow_table_without_some_columns =
arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0])
assert len(arrow_table_without_some_columns.columns) <
len(arrow_table_with_null.columns)
@@ -976,6 +977,101 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null
assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None:
+ identifier = "default.test_table_write_out_of_order_schema"
+ # rotate the schema fields by 1
+ fields = list(arrow_table_with_null.schema)
+ rotated_fields = fields[1:] + fields[:1]
+ rotated_schema = pa.schema(rotated_fields)
+ assert arrow_table_with_null.schema != rotated_schema
+ tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema)
+
+ tbl.overwrite(arrow_table_with_null)
+ tbl.append(arrow_table_with_null)
+ # overwrite and then append should produce twice the data
+ assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_table_write_schema_with_valid_nullability_diff(
+ spark: SparkSession, session_catalog: Catalog, arrow_table_with_null:
pa.Table, format_version: int
+) -> None:
+ identifier = "default.test_table_write_with_valid_nullability_diff"
+ table_schema = Schema(
+ NestedField(field_id=1, name="long", field_type=LongType(), required=False),
+ )
+ other_schema = pa.schema((
+ pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field
+ ))
+ arrow_table = pa.Table.from_pydict(
+ {
+ "long": [1, 9],
+ },
+ schema=other_schema,
+ )
+ tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema)
+ # table's long field should cast to be optional on read
+ written_arrow_table = tbl.scan().to_arrow()
+ assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),)))
+ lhs = spark.table(f"{identifier}").toPandas()
+ rhs = written_arrow_table.to_pandas()
+
+ for column in written_arrow_table.column_names:
+ for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+ assert left == right
+
+
[email protected]
[email protected]("format_version", [1, 2])
+def test_table_write_schema_with_valid_upcast(
+ spark: SparkSession,
+ session_catalog: Catalog,
+ format_version: int,
+ table_schema_with_promoted_types: Schema,
+ pyarrow_schema_with_promoted_types: pa.Schema,
+ pyarrow_table_with_promoted_types: pa.Table,
+) -> None:
+ identifier = "default.test_table_write_with_valid_upcast"
+
+ tbl = _create_table(
+ session_catalog,
+ identifier,
+ {"format-version": format_version},
+ [pyarrow_table_with_promoted_types],
+ schema=table_schema_with_promoted_types,
+ )
+ # table's long field should cast to long on read
+ written_arrow_table = tbl.scan().to_arrow()
+ assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
+ pa.schema((
+ pa.field("long", pa.int64(), nullable=True),
+ pa.field("list", pa.large_list(pa.int64()), nullable=False),
+ pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False),
+ pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double
+ pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16
+ ))
+ )
+ lhs = spark.table(f"{identifier}").toPandas()
+ rhs = written_arrow_table.to_pandas()
+
+ for column in written_arrow_table.column_names:
+ for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
+ if column == "map":
+ # Arrow returns a list of tuples, instead of a dict
+ right = dict(right)
+ if column == "list":
+ # Arrow returns an array, convert to list for equality check
+ left, right = list(left), list(right)
+ if column == "uuid":
+ # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545'
+ # whereas PyIceberg represents UUID as bytes on read
+ left, right = left.replace("-", ""), right.hex()
+ assert left == right
+
+
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_write_all_timestamp_precision(
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 37198b7e..d61a50bb 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -60,7 +60,7 @@ from pyiceberg.io.pyarrow import (
PyArrowFile,
PyArrowFileIO,
StatsAggregator,
- _check_schema_compatible,
+ _check_pyarrow_schema_compatible,
_ConvertToArrowSchema,
_determine_partitions,
_primitive_to_physical,
@@ -1742,7 +1742,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
"""
with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1763,7 +1763,20 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
"""
with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=False),
+ pa.field("baz", pa.bool_(), nullable=False),
+ ))
+
+ try:
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1783,21 +1796,114 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
"""
with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Schema) -> None:
+ schema = table_schema_nested.as_arrow()
+ schema = schema.remove(6).insert(
+ 6,
+ pa.field(
+ "person",
+ pa.struct([
+ pa.field("age", pa.int32(), nullable=False),
+ ]),
+ nullable=True,
+ ),
+ )
+ try:
+ _check_pyarrow_schema_compatible(table_schema_nested, schema)
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
+
+
+def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Schema) -> None:
+ other_schema = table_schema_nested.as_arrow()
+ other_schema = other_schema.remove(6).insert(
+ 6,
+ pa.field(
+ "person",
+ pa.struct([
+ pa.field("name", pa.string(), nullable=True),
+ ]),
+ nullable=True,
+ ),
+ )
+ expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ✅ │ 2: bar: required int │ 2: bar: required int │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+│ ✅ │ 4: qux: required list<string> │ 4: qux: required list<string> │
+│ ✅ │ 5: element: required string │ 5: element: required string │
+│ ✅ │ 6: quux: required map<string, │ 6: quux: required map<string, │
+│ │ map<string, int>> │ map<string, int>> │
+│ ✅ │ 7: key: required string │ 7: key: required string │
+│ ✅ │ 8: value: required map<string, │ 8: value: required map<string, │
+│ │ int> │ int> │
+│ ✅ │ 9: key: required string │ 9: key: required string │
+│ ✅ │ 10: value: required int │ 10: value: required int │
+│ ✅ │ 11: location: required │ 11: location: required │
+│ │ list<struct<13: latitude: optional │ list<struct<13: latitude: optional │
+│ │ float, 14: longitude: optional │ float, 14: longitude: optional │
+│ │ float>> │ float>> │
+│ ✅ │ 12: element: required struct<13: │ 12: element: required struct<13: │
+│ │ latitude: optional float, 14: │ latitude: optional float, 14: │
+│ │ longitude: optional float> │ longitude: optional float> │
+│ ✅ │ 13: latitude: optional float │ 13: latitude: optional float │
+│ ✅ │ 14: longitude: optional float │ 14: longitude: optional float │
+│ ✅ │ 15: person: optional struct<16: │ 15: person: optional struct<16: │
+│ │ name: optional string, 17: age: │ name: optional string> │
+│ │ required int> │ │
+│ ✅ │ 16: name: optional string │ 16: name: optional string │
+│ ❌ │ 17: age: required int │ Missing │
+└────┴────────────────────────────────────┴────────────────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_pyarrow_schema_compatible(table_schema_nested, other_schema)
+
+
+def test_schema_compatible_nested(table_schema_nested: Schema) -> None:
+ try:
+ _check_pyarrow_schema_compatible(table_schema_nested, table_schema_nested.as_arrow())
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
- pa.field("bar", pa.int32(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
pa.field("new_field", pa.date32(), nullable=True),
))
- expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+ with pytest.raises(
+ ValueError, match=r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+ ):
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
- with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
+
+def test_schema_compatible(table_schema_simple: Schema) -> None:
+ try:
+ _check_pyarrow_schema_compatible(table_schema_simple, table_schema_simple.as_arrow())
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
+
+
+def test_schema_projection(table_schema_simple: Schema) -> None:
+ # remove optional `baz` field from `table_schema_simple`
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=False),
+ ))
+ try:
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_schema_downcast(table_schema_simple: Schema) -> None:
@@ -1809,9 +1915,9 @@ def test_schema_downcast(table_schema_simple: Schema) -> None:
))
try:
- _check_schema_compatible(table_schema_simple, other_schema)
+ _check_pyarrow_schema_compatible(table_schema_simple, other_schema)
except Exception:
- pytest.fail("Unexpected Exception raised when calling `_check_schema`")
+ pytest.fail("Unexpected Exception raised when calling `_check_pyarrow_schema_compatible`")
def test_partition_for_demo() -> None: