This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new dceedfac Check if schema is compatible in `add_files` API (#907)
dceedfac is described below
commit dceedfac4ec072ee4da99bf02dc93c1d27be45a9
Author: Sung Yun <[email protected]>
AuthorDate: Thu Jul 11 20:32:14 2024 -0400
Check if schema is compatible in `add_files` API (#907)
Co-authored-by: Fokko Driesprong <[email protected]>
---
pyiceberg/io/pyarrow.py | 45 ++++++++++++++++++
pyiceberg/table/__init__.py | 62 ++++---------------------
tests/integration/test_add_files.py | 85 ++++++++++++++++++++++++++--------
tests/io/test_pyarrow.py | 91 ++++++++++++++++++++++++++++++++++++
tests/table/test_init.py | 92 -------------------------------------
5 files changed, 211 insertions(+), 164 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 142e9e5f..56f22425 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2032,6 +2032,49 @@ def bin_pack_arrow_table(tbl: pa.Table,
target_file_size: int) -> Iterator[List[
return bin_packed_record_batches
+def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema,
downcast_ns_timestamp_to_us: bool = False) -> None:
+ """
+ Check if the `table_schema` is compatible with `other_schema`.
+
+ Two schemas are considered compatible when they are equal in terms of the
Iceberg Schema type.
+
+ Raises:
+ ValueError: If the schemas are not compatible.
+ """
+ name_mapping = table_schema.name_mapping
+ try:
+ task_schema = pyarrow_to_schema(
+ other_schema, name_mapping=name_mapping,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+ )
+ except ValueError as e:
+ other_schema = _pyarrow_to_schema_without_ids(other_schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
+ additional_names = set(other_schema.column_names) -
set(table_schema.column_names)
+ raise ValueError(
+ f"PyArrow table contains more columns: {',
'.join(sorted(additional_names))}. Update the schema first (hint, use
union_by_name)."
+ ) from e
+
+ if table_schema.as_struct() != task_schema.as_struct():
+ from rich.console import Console
+ from rich.table import Table as RichTable
+
+ console = Console(record=True)
+
+ rich_table = RichTable(show_header=True, header_style="bold")
+ rich_table.add_column("")
+ rich_table.add_column("Table field")
+ rich_table.add_column("Dataframe field")
+
+ for lhs in table_schema.fields:
+ try:
+ rhs = task_schema.find_field(lhs.field_id)
+ rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs),
str(rhs))
+ except ValueError:
+ rich_table.add_row("❌", str(lhs), "Missing")
+
+ console.print(rich_table)
+ raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+
+
def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata,
file_paths: Iterator[str]) -> Iterator[DataFile]:
for file_path in file_paths:
input_file = io.new_input(file_path)
@@ -2043,6 +2086,8 @@ def parquet_files_to_data_files(io: FileIO,
table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs.
`add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
+ _check_schema_compatible(schema,
parquet_metadata.schema.to_arrow_schema())
+
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
stats_columns=compute_statistics_plan(schema,
table_metadata.properties),
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 5342d370..62440c47 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -73,7 +73,7 @@ from pyiceberg.expressions.visitors import (
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
-from pyiceberg.io.pyarrow import _dataframe_to_data_files,
expression_to_pyarrow, project_table
+from pyiceberg.io.pyarrow import _check_schema_compatible,
_dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -166,54 +166,8 @@ if TYPE_CHECKING:
ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1
-DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
_JAVA_LONG_MAX = 9223372036854775807
-
-
-def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema")
-> None:
- """
- Check if the `table_schema` is compatible with `other_schema`.
-
- Two schemas are considered compatible when they are equal in terms of the
Iceberg Schema type.
-
- Raises:
- ValueError: If the schemas are not compatible.
- """
- from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids,
pyarrow_to_schema
-
- downcast_ns_timestamp_to_us =
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
- name_mapping = table_schema.name_mapping
- try:
- task_schema = pyarrow_to_schema(
- other_schema, name_mapping=name_mapping,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
- )
- except ValueError as e:
- other_schema = _pyarrow_to_schema_without_ids(other_schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
- additional_names = set(other_schema.column_names) -
set(table_schema.column_names)
- raise ValueError(
- f"PyArrow table contains more columns: {',
'.join(sorted(additional_names))}. Update the schema first (hint, use
union_by_name)."
- ) from e
-
- if table_schema.as_struct() != task_schema.as_struct():
- from rich.console import Console
- from rich.table import Table as RichTable
-
- console = Console(record=True)
-
- rich_table = RichTable(show_header=True, header_style="bold")
- rich_table.add_column("")
- rich_table.add_column("Table field")
- rich_table.add_column("Dataframe field")
-
- for lhs in table_schema.fields:
- try:
- rhs = task_schema.find_field(lhs.field_id)
- rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs),
str(rhs))
- except ValueError:
- rich_table.add_row("❌", str(lhs), "Missing")
-
- console.print(rich_table)
- raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
class TableProperties:
@@ -526,8 +480,10 @@ class Transaction:
raise ValueError(
f"Not all partition types are supported for writes. Following
partitions cannot be written using pyarrow: {unsupported_partitions}."
)
-
- _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+ downcast_ns_timestamp_to_us =
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+ _check_schema_compatible(
+ self._table.schema(), other_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+ )
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
@@ -585,8 +541,10 @@ class Transaction:
raise ValueError(
f"Not all partition types are supported for writes. Following
partitions cannot be written using pyarrow: {unsupported_partitions}."
)
-
- _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+ downcast_ns_timestamp_to_us =
Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+ _check_schema_compatible(
+ self._table.schema(), other_schema=df.schema,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+ )
# cast if the two schemas are compatible but not equal
table_arrow_schema = self._table.schema().as_arrow()
if table_arrow_schema != df.schema:
diff --git a/tests/integration/test_add_files.py
b/tests/integration/test_add_files.py
index 825d17e9..984c7d11 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -17,6 +17,7 @@
# pylint:disable=redefined-outer-name
import os
+import re
from datetime import date
from typing import Iterator
@@ -463,6 +464,57 @@ def test_add_files_snapshot_properties(spark:
SparkSession, session_catalog: Cat
assert summary["snapshot_prop_a"] == "test_prop_a"
[email protected]
+def test_add_files_fails_on_schema_mismatch(spark: SparkSession,
session_catalog: Catalog, format_version: int) -> None:
+ identifier = f"default.table_schema_mismatch_fails_v{format_version}"
+
+ tbl = _create_table(session_catalog, identifier, format_version)
+ WRONG_SCHEMA = pa.schema([
+ ("foo", pa.bool_()),
+ ("bar", pa.string()),
+ ("baz", pa.string()), # should be integer
+ ("qux", pa.date32()),
+ ])
+ file_path =
f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet"
+ # write parquet files
+ fo = tbl.io.new_output(file_path)
+ with fo.create(overwrite=True) as fos:
+ with pq.ParquetWriter(fos, schema=WRONG_SCHEMA) as writer:
+ writer.write_table(
+ pa.Table.from_pylist(
+ [
+ {
+ "foo": True,
+ "bar": "bar_string",
+ "baz": "123",
+ "qux": date(2024, 3, 7),
+ },
+ {
+ "foo": True,
+ "bar": "bar_string",
+ "baz": "124",
+ "qux": date(2024, 3, 7),
+ },
+ ],
+ schema=WRONG_SCHEMA,
+ )
+ )
+
+ expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
+│ ✅ │ 2: bar: optional string │ 2: bar: optional string │
+│ ❌ │ 3: baz: optional int │ 3: baz: optional string │
+│ ✅ │ 4: qux: optional date │ 4: qux: optional date │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ tbl.add_files(file_paths=[file_path])
+
+
@pytest.mark.integration
def test_add_files_with_large_and_regular_schema(spark: SparkSession,
session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_with_large_types{format_version}"
@@ -518,7 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark:
SparkSession, session_ca
assert table_schema == arrow_schema_large
-def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog,
format_version: int, mocker: MockerFixture) -> None:
+def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog,
format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux",
TimestamptzType()))
nanoseconds_schema = pa.schema([
@@ -549,25 +601,18 @@ def
test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_versi
partition_spec=PartitionSpec(),
)
- file_paths =
[f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet"
for i in range(5)]
+ file_path =
f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
# write parquet files
- for file_path in file_paths:
- fo = tbl.io.new_output(file_path)
- with fo.create(overwrite=True) as fos:
- with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
- writer.write_table(arrow_table)
+ fo = tbl.io.new_output(file_path)
+ with fo.create(overwrite=True) as fos:
+ with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
+ writer.write_table(arrow_table)
# add the parquet files as data files
- tbl.add_files(file_paths=file_paths)
-
- assert tbl.scan().to_arrow() == pa.concat_tables(
- [
- arrow_table.cast(
- pa.schema([
- ("quux", pa.timestamp("us", tz="UTC")),
- ]),
- safe=False,
- )
- ]
- * 5
- )
+ with pytest.raises(
+ TypeError,
+ match=re.escape(
+ "Iceberg does not yet support 'ns' timestamp precision. Use
'downcast-ns-timestamp-to-us-on-write' configuration property to automatically
downcast 'ns' to 'us' on write."
+ ),
+ ):
+ tbl.add_files(file_paths=[file_path])
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 1b946899..326eeff1 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -60,6 +60,7 @@ from pyiceberg.io.pyarrow import (
PyArrowFile,
PyArrowFileIO,
StatsAggregator,
+ _check_schema_compatible,
_ConvertToArrowSchema,
_determine_partitions,
_primitive_to_physical,
@@ -1722,6 +1723,96 @@ def test_bin_pack_arrow_table(arrow_table_with_null:
pa.Table) -> None:
assert len(list(bin_packed)) == 5
+def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.decimal128(18, 6), nullable=False),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ expected = r"""Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴─────────────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=True),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ❌ │ 2: bar: required int │ 2: bar: optional int │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ❌ │ 2: bar: required int │ Missing │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=True),
+ pa.field("baz", pa.bool_(), nullable=True),
+ pa.field("new_field", pa.date32(), nullable=True),
+ ))
+
+ expected = r"PyArrow table contains more columns: new_field. Update the
schema first \(hint, use union_by_name\)."
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+ # large_string type is compatible with string type
+ other_schema = pa.schema((
+ pa.field("foo", pa.large_string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=False),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ try:
+ _check_schema_compatible(table_schema_simple, other_schema)
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_schema_compatible`")
+
+
def test_partition_for_demo() -> None:
test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()),
("animal", pa.string())])
test_schema = Schema(
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 31a8bbf4..7a5ea86d 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -19,7 +19,6 @@ import uuid
from copy import copy
from typing import Any, Dict
-import pyarrow as pa
import pytest
from pydantic import ValidationError
from sortedcontainers import SortedList
@@ -63,7 +62,6 @@ from pyiceberg.table import (
TableIdentifier,
UpdateSchema,
_apply_table_update,
- _check_schema_compatible,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
update_table_metadata,
@@ -1122,96 +1120,6 @@ def test_correct_schema() -> None:
assert "Snapshot not found: -1" in str(exc_info.value)
-def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
- other_schema = pa.schema((
- pa.field("foo", pa.string(), nullable=True),
- pa.field("bar", pa.decimal128(18, 6), nullable=False),
- pa.field("baz", pa.bool_(), nullable=True),
- ))
-
- expected = r"""Mismatch in fields:
-┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
-┃ ┃ Table field ┃ Dataframe field ┃
-┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
-│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │
-│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
-└────┴──────────────────────────┴─────────────────────────────────┘
-"""
-
- with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
-
-
-def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
- other_schema = pa.schema((
- pa.field("foo", pa.string(), nullable=True),
- pa.field("bar", pa.int32(), nullable=True),
- pa.field("baz", pa.bool_(), nullable=True),
- ))
-
- expected = """Mismatch in fields:
-┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
-┃ ┃ Table field ┃ Dataframe field ┃
-┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
-│ ❌ │ 2: bar: required int │ 2: bar: optional int │
-│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
-└────┴──────────────────────────┴──────────────────────────┘
-"""
-
- with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
-
-
-def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
- other_schema = pa.schema((
- pa.field("foo", pa.string(), nullable=True),
- pa.field("baz", pa.bool_(), nullable=True),
- ))
-
- expected = """Mismatch in fields:
-┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
-┃ ┃ Table field ┃ Dataframe field ┃
-┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
-│ ❌ │ 2: bar: required int │ Missing │
-│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
-└────┴──────────────────────────┴──────────────────────────┘
-"""
-
- with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
-
-
-def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
- other_schema = pa.schema((
- pa.field("foo", pa.string(), nullable=True),
- pa.field("bar", pa.int32(), nullable=True),
- pa.field("baz", pa.bool_(), nullable=True),
- pa.field("new_field", pa.date32(), nullable=True),
- ))
-
- expected = r"PyArrow table contains more columns: new_field. Update the
schema first \(hint, use union_by_name\)."
-
- with pytest.raises(ValueError, match=expected):
- _check_schema_compatible(table_schema_simple, other_schema)
-
-
-def test_schema_downcast(table_schema_simple: Schema) -> None:
- # large_string type is compatible with string type
- other_schema = pa.schema((
- pa.field("foo", pa.large_string(), nullable=True),
- pa.field("bar", pa.int32(), nullable=False),
- pa.field("baz", pa.bool_(), nullable=True),
- ))
-
- try:
- _check_schema_compatible(table_schema_simple, other_schema)
- except Exception:
- pytest.fail("Unexpected Exception raised when calling `_check_schema`")
-
-
def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None:
# metadata properties are all strings
for k, v in example_table_metadata_v2["properties"].items():