This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch pyiceberg-0.6.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/pyiceberg-0.6.x by this push:
     new b9362eea Improve error message in case of a mismatch (#352) (#560)
b9362eea is described below

commit b9362eea005e2a437971ca72a64ba967e4cbbbd2
Author: Honah J <[email protected]>
AuthorDate: Sat Mar 30 17:45:17 2024 -0700

    Improve error message in case of a mismatch (#352) (#560)
    
    Backport to 0.6.1
    
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 pyiceberg/io/pyarrow.py         |  4 +++
 pyiceberg/schema.py             | 12 +++++++
 pyiceberg/table/__init__.py     | 39 +++++++++++++++++++++
 pyiceberg/table/name_mapping.py |  6 +++-
 tests/table/test_init.py        | 78 +++++++++++++++++++++++++++++++++++++++++
 5 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 02f72c7c..038414ee 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -655,6 +655,10 @@ def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = N
     return visit_pyarrow(schema, visitor)
 
 
+def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
+    return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
+
+
 @singledispatch
 def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
     """Apply a pyarrow schema visitor to any point within a schema.
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 6dd174f3..e805895a 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import cached_property, partial, singledispatch
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -62,6 +63,11 @@ from pyiceberg.types import (
     UUIDType,
 )
 
+if TYPE_CHECKING:
+    from pyiceberg.table.name_mapping import (
+        NameMapping,
+    )
+
 T = TypeVar("T")
 P = TypeVar("P")
 
@@ -221,6 +227,12 @@ class Schema(IcebergBaseModel):
     def highest_field_id(self) -> int:
         return max(self._lazy_id_to_name.keys(), default=0)
 
+    @cached_property
+    def name_mapping(self) -> NameMapping:
+        from pyiceberg.table.name_mapping import create_mapping_from_schema
+
+        return create_mapping_from_schema(self)
+
     def find_column_name(self, column_id: int) -> Optional[str]:
         """Find a column name given a column ID.
 
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index fd9192bc..1f5e3131 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -132,6 +132,41 @@ TABLE_ROOT_ID = -1
 _JAVA_LONG_MAX = 9223372036854775807
 
 
+def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
+
+    name_mapping = table_schema.name_mapping
+    try:
+        task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
+    except ValueError as e:
+        other_schema = _pyarrow_to_schema_without_ids(other_schema)
+        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+        raise ValueError(
+            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
+        ) from e
+
+    if table_schema.as_struct() != task_schema.as_struct():
+        from rich.console import Console
+        from rich.table import Table as RichTable
+
+        console = Console(record=True)
+
+        rich_table = RichTable(show_header=True, header_style="bold")
+        rich_table.add_column("")
+        rich_table.add_column("Table field")
+        rich_table.add_column("Dataframe field")
+
+        for lhs in table_schema.fields:
+            try:
+                rhs = task_schema.find_field(lhs.field_id)
+                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
+            except ValueError:
+                rich_table.add_row("❌", str(lhs), "Missing")
+
+        console.print(rich_table)
+        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+
+
 class TableProperties:
     PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
     PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024  # 128 MB
@@ -1009,6 +1044,8 @@ class Table:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
+        _check_schema(self.schema(), other_schema=df.schema)
+
         merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
 
         # skip writing data files if the dataframe is empty
@@ -1042,6 +1079,8 @@ class Table:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
+        _check_schema(self.schema(), other_schema=df.schema)
+
         merge = _MergingSnapshotProducer(
             operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
             table=self,
diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py
index ffe96359..9990d836 100644
--- a/pyiceberg/table/name_mapping.py
+++ b/pyiceberg/table/name_mapping.py
@@ -26,7 +26,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from collections import ChainMap
 from functools import cached_property, singledispatch
-from typing import Any, Dict, Generic, List, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, TypeVar, Union
 
 from pydantic import Field, conlist, field_validator, model_serializer
 
@@ -85,6 +85,10 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
         """Return the number of mappings."""
         return len(self.root)
 
+    def __iter__(self) -> Iterator[MappedField]:
+        """Iterate over the mapped fields."""
+        return iter(self.root)
+
     def __str__(self) -> str:
         """Convert the name-mapping into a nicely formatted string."""
         if len(self.root) == 0:
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d2c1082f..d660759a 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -19,6 +19,7 @@ import uuid
 from copy import copy
 from typing import Dict
 
+import pyarrow as pa
 import pytest
 from sortedcontainers import SortedList
 
@@ -57,6 +58,7 @@ from pyiceberg.table import (
     Table,
     UpdateSchema,
     _apply_table_update,
+    _check_schema,
     _generate_snapshot_id,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
@@ -982,3 +984,79 @@ def test_correct_schema() -> None:
         _ = t.scan(snapshot_id=-1).projection()
 
     assert "Snapshot not found: -1" in str(exc_info.value)
+
+
+def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.decimal128(18, 6), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = r"""Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field                 ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
+│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
+└────┴──────────────────────────┴─────────────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
+│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃    ┃ Table field              ┃ Dataframe field          ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
+│ ❌ │ 2: bar: required int     │ Missing                  │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
+    other_schema = pa.schema((
+        pa.field("foo", pa.string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=True),
+        pa.field("baz", pa.bool_(), nullable=True),
+        pa.field("new_field", pa.date32(), nullable=True),
+    ))
+
+    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+
+    with pytest.raises(ValueError, match=expected):
+        _check_schema(table_schema_simple, other_schema)

Reply via email to