This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch pyiceberg-0.6.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/pyiceberg-0.6.x by this push:
     new 83c8c3ee [0.6.x] Backport PR #523 to cast data to iceberg table's pyarrow schema (#559)
83c8c3ee is described below

commit 83c8c3ee09bd9506b7aeea3475d52c21eafbf3c2
Author: Honah J <[email protected]>
AuthorDate: Sat Mar 30 18:26:14 2024 -0700

    [0.6.x] Backport PR #523 to cast data to iceberg table's pyarrow schema (#559)
    
    * Cast data to Iceberg Table's pyarrow schema (#523)
    
    Backport to 0.6.1
    
    * use schema_to_pyarrow directly for backporting
    
    * remove print in test
    
    ---------
    
    Co-authored-by: Kevin Liu <[email protected]>
---
 pyiceberg/table/__init__.py | 24 +++++++++++++++++++++---
 tests/catalog/test_sql.py   | 30 ++++++++++++++++++++++++++++++
 tests/table/test_init.py    | 24 +++++++++++++++++++-----
 3 files changed, 70 insertions(+), 8 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 1f5e3131..55795db3 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -132,7 +132,15 @@ TABLE_ROOT_ID = -1
 _JAVA_LONG_MAX = 9223372036854775807
 
 
-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    """
+    Check if the `table_schema` is compatible with `other_schema`.
+
+    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
     from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
 
     name_mapping = table_schema.name_mapping
@@ -1044,7 +1052,12 @@ class Table:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        from pyiceberg.io.pyarrow import schema_to_pyarrow
+
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if schema_to_pyarrow(self.schema()) != df.schema:
+            df = df.cast(schema_to_pyarrow(self.schema()))
 
         merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
 
@@ -1079,7 +1092,12 @@ class Table:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        from pyiceberg.io.pyarrow import schema_to_pyarrow
+
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if schema_to_pyarrow(self.schema()) != df.schema:
+            df = df.cast(schema_to_pyarrow(self.schema()))
 
         merge = _MergingSnapshotProducer(
             operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 421b148b..44755ffb 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -191,6 +191,36 @@ def test_create_table_with_pyarrow_schema(
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
+    import pyarrow as pa
+
+    pyarrow_table = pa.Table.from_arrays(
+        [
+            pa.array([None, "A", "B", "C"]),  # 'foo' column
+            pa.array([1, 2, 3, 4]),  # 'bar' column
+            pa.array([True, None, False, True]),  # 'baz' column
+            pa.array([None, "A", "B", "C"]),  # 'large' column
+        ],
+        schema=pa.schema([
+            pa.field('foo', pa.string(), nullable=True),
+            pa.field('bar', pa.int32(), nullable=False),
+            pa.field('baz', pa.bool_(), nullable=True),
+            pa.field('large', pa.large_string(), nullable=True),
+        ]),
+    )
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_table.schema)
+    table.overwrite(pyarrow_table)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d660759a..b99802c4 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -58,7 +58,7 @@ from pyiceberg.table import (
     Table,
     UpdateSchema,
     _apply_table_update,
-    _check_schema,
+    _check_schema_compatible,
     _generate_snapshot_id,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
@@ -1004,7 +1004,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1025,7 +1025,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1045,7 +1045,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
 """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1059,4 +1059,18 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema`")

Reply via email to