This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch pyiceberg-0.6.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/pyiceberg-0.6.x by this push:
new 83c8c3ee [0.6.x] Backport PR #523 to cast data to iceberg table's
pyarrow schema (#559)
83c8c3ee is described below
commit 83c8c3ee09bd9506b7aeea3475d52c21eafbf3c2
Author: Honah J <[email protected]>
AuthorDate: Sat Mar 30 18:26:14 2024 -0700
[0.6.x] Backport PR #523 to cast data to iceberg table's pyarrow schema
(#559)
* Cast data to Iceberg Table's pyarrow schema (#523)
Backport to 0.6.1
* use schema_to_pyarrow directly for backporting
* remove print in test
---------
Co-authored-by: Kevin Liu <[email protected]>
---
pyiceberg/table/__init__.py | 24 +++++++++++++++++++++---
tests/catalog/test_sql.py | 30 ++++++++++++++++++++++++++++++
tests/table/test_init.py | 24 +++++++++++++++++++-----
3 files changed, 70 insertions(+), 8 deletions(-)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 1f5e3131..55795db3 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -132,7 +132,15 @@ TABLE_ROOT_ID = -1
_JAVA_LONG_MAX = 9223372036854775807
-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema")
-> None:
+ """
+ Check if the `table_schema` is compatible with `other_schema`.
+
+ Two schemas are considered compatible when they are equal in terms of the
Iceberg Schema type.
+
+ Raises:
+ ValueError: If the schemas are not compatible.
+ """
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids,
pyarrow_to_schema
name_mapping = table_schema.name_mapping
@@ -1044,7 +1052,12 @@ class Table:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
- _check_schema(self.schema(), other_schema=df.schema)
+ from pyiceberg.io.pyarrow import schema_to_pyarrow
+
+ _check_schema_compatible(self.schema(), other_schema=df.schema)
+ # cast if the two schemas are compatible but not equal
+ if schema_to_pyarrow(self.schema()) != df.schema:
+ df = df.cast(schema_to_pyarrow(self.schema()))
merge = _MergingSnapshotProducer(operation=Operation.APPEND,
table=self)
@@ -1079,7 +1092,12 @@ class Table:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
- _check_schema(self.schema(), other_schema=df.schema)
+ from pyiceberg.io.pyarrow import schema_to_pyarrow
+
+ _check_schema_compatible(self.schema(), other_schema=df.schema)
+ # cast if the two schemas are compatible but not equal
+ if schema_to_pyarrow(self.schema()) != df.schema:
+ df = df.cast(schema_to_pyarrow(self.schema()))
merge = _MergingSnapshotProducer(
operation=Operation.OVERWRITE if self.current_snapshot() is not
None else Operation.APPEND,
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 421b148b..44755ffb 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -191,6 +191,36 @@ def test_create_table_with_pyarrow_schema(
catalog.drop_table(random_identifier)
+@pytest.mark.parametrize(
+ 'catalog',
+ [
+ lazy_fixture('catalog_memory'),
+ lazy_fixture('catalog_sqlite'),
+ ],
+)
+def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier:
Identifier) -> None:
+ import pyarrow as pa
+
+ pyarrow_table = pa.Table.from_arrays(
+ [
+ pa.array([None, "A", "B", "C"]), # 'foo' column
+ pa.array([1, 2, 3, 4]), # 'bar' column
+ pa.array([True, None, False, True]), # 'baz' column
+ pa.array([None, "A", "B", "C"]), # 'large' column
+ ],
+ schema=pa.schema([
+ pa.field('foo', pa.string(), nullable=True),
+ pa.field('bar', pa.int32(), nullable=False),
+ pa.field('baz', pa.bool_(), nullable=True),
+ pa.field('large', pa.large_string(), nullable=True),
+ ]),
+ )
+ database_name, _table_name = random_identifier
+ catalog.create_namespace(database_name)
+ table = catalog.create_table(random_identifier, pyarrow_table.schema)
+ table.overwrite(pyarrow_table)
+
+
@pytest.mark.parametrize(
'catalog',
[
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d660759a..b99802c4 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -58,7 +58,7 @@ from pyiceberg.table import (
Table,
UpdateSchema,
_apply_table_update,
- _check_schema,
+ _check_schema_compatible,
_generate_snapshot_id,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
@@ -1004,7 +1004,7 @@ def test_schema_mismatch_type(table_schema_simple:
Schema) -> None:
"""
with pytest.raises(ValueError, match=expected):
- _check_schema(table_schema_simple, other_schema)
+ _check_schema_compatible(table_schema_simple, other_schema)
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1025,7 +1025,7 @@ def test_schema_mismatch_nullability(table_schema_simple:
Schema) -> None:
"""
with pytest.raises(ValueError, match=expected):
- _check_schema(table_schema_simple, other_schema)
+ _check_schema_compatible(table_schema_simple, other_schema)
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1045,7 +1045,7 @@ def
test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
"""
with pytest.raises(ValueError, match=expected):
- _check_schema(table_schema_simple, other_schema)
+ _check_schema_compatible(table_schema_simple, other_schema)
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1059,4 +1059,18 @@ def
test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
expected = r"PyArrow table contains more columns: new_field. Update the
schema first \(hint, use union_by_name\)."
with pytest.raises(ValueError, match=expected):
- _check_schema(table_schema_simple, other_schema)
+ _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+ # large_string type is compatible with string type
+ other_schema = pa.schema((
+ pa.field("foo", pa.large_string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=False),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ try:
+ _check_schema_compatible(table_schema_simple, other_schema)
+ except Exception:
+ pytest.fail("Unexpected Exception raised when calling `_check_schema`")