This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch pyiceberg-0.6.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/pyiceberg-0.6.x by this push:
new b9362eea Improve error message in case of a mismatch (#352) (#560)
b9362eea is described below
commit b9362eea005e2a437971ca72a64ba967e4cbbbd2
Author: Honah J <[email protected]>
AuthorDate: Sat Mar 30 17:45:17 2024 -0700
Improve error message in case of a mismatch (#352) (#560)
Backport to 0.6.1
Co-authored-by: Fokko Driesprong <[email protected]>
---
pyiceberg/io/pyarrow.py | 4 +++
pyiceberg/schema.py | 12 +++++++
pyiceberg/table/__init__.py | 39 +++++++++++++++++++++
pyiceberg/table/name_mapping.py | 6 +++-
tests/table/test_init.py | 78 +++++++++++++++++++++++++++++++++++++++++
5 files changed, 138 insertions(+), 1 deletion(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 02f72c7c..038414ee 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -655,6 +655,10 @@ def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = N
return visit_pyarrow(schema, visitor)
+def _pyarrow_to_schema_without_ids(schema: pa.Schema) -> Schema:
+ return visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
+
+
@singledispatch
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
"""Apply a pyarrow schema visitor to any point within a schema.
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 6dd174f3..e805895a 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property, partial, singledispatch
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -62,6 +63,11 @@ from pyiceberg.types import (
UUIDType,
)
+if TYPE_CHECKING:
+ from pyiceberg.table.name_mapping import (
+ NameMapping,
+ )
+
T = TypeVar("T")
P = TypeVar("P")
@@ -221,6 +227,12 @@ class Schema(IcebergBaseModel):
def highest_field_id(self) -> int:
return max(self._lazy_id_to_name.keys(), default=0)
+ @cached_property
+ def name_mapping(self) -> NameMapping:
+ from pyiceberg.table.name_mapping import create_mapping_from_schema
+
+ return create_mapping_from_schema(self)
+
def find_column_name(self, column_id: int) -> Optional[str]:
"""Find a column name given a column ID.
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index fd9192bc..1f5e3131 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -132,6 +132,41 @@ TABLE_ROOT_ID = -1
_JAVA_LONG_MAX = 9223372036854775807
+def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+ from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
+
+ name_mapping = table_schema.name_mapping
+ try:
+ task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping)
+ except ValueError as e:
+ other_schema = _pyarrow_to_schema_without_ids(other_schema)
+ additional_names = set(other_schema.column_names) - set(table_schema.column_names)
+ raise ValueError(
+ f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
+ ) from e
+
+ if table_schema.as_struct() != task_schema.as_struct():
+ from rich.console import Console
+ from rich.table import Table as RichTable
+
+ console = Console(record=True)
+
+ rich_table = RichTable(show_header=True, header_style="bold")
+ rich_table.add_column("")
+ rich_table.add_column("Table field")
+ rich_table.add_column("Dataframe field")
+
+ for lhs in table_schema.fields:
+ try:
+ rhs = task_schema.find_field(lhs.field_id)
+ rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
+ except ValueError:
+ rich_table.add_row("❌", str(lhs), "Missing")
+
+ console.print(rich_table)
+ raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
+
+
class TableProperties:
PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB
@@ -1009,6 +1044,8 @@ class Table:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
+ _check_schema(self.schema(), other_schema=df.schema)
+
merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
# skip writing data files if the dataframe is empty
@@ -1042,6 +1079,8 @@ class Table:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
+ _check_schema(self.schema(), other_schema=df.schema)
+
merge = _MergingSnapshotProducer(
operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
table=self,
diff --git a/pyiceberg/table/name_mapping.py b/pyiceberg/table/name_mapping.py
index ffe96359..9990d836 100644
--- a/pyiceberg/table/name_mapping.py
+++ b/pyiceberg/table/name_mapping.py
@@ -26,7 +26,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from collections import ChainMap
from functools import cached_property, singledispatch
-from typing import Any, Dict, Generic, List, TypeVar, Union
+from typing import Any, Dict, Generic, Iterator, List, TypeVar, Union
from pydantic import Field, conlist, field_validator, model_serializer
@@ -85,6 +85,10 @@ class NameMapping(IcebergRootModel[List[MappedField]]):
"""Return the number of mappings."""
return len(self.root)
+ def __iter__(self) -> Iterator[MappedField]:
+ """Iterate over the mapped fields."""
+ return iter(self.root)
+
def __str__(self) -> str:
"""Convert the name-mapping into a nicely formatted string."""
if len(self.root) == 0:
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d2c1082f..d660759a 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -19,6 +19,7 @@ import uuid
from copy import copy
from typing import Dict
+import pyarrow as pa
import pytest
from sortedcontainers import SortedList
@@ -57,6 +58,7 @@ from pyiceberg.table import (
Table,
UpdateSchema,
_apply_table_update,
+ _check_schema,
_generate_snapshot_id,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
@@ -982,3 +984,79 @@ def test_correct_schema() -> None:
_ = t.scan(snapshot_id=-1).projection()
assert "Snapshot not found: -1" in str(exc_info.value)
+
+
+def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.decimal128(18, 6), nullable=False),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ expected = r"""Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴─────────────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=True),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ❌ │ 2: bar: required int │ 2: bar: optional int │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("baz", pa.bool_(), nullable=True),
+ ))
+
+ expected = """Mismatch in fields:
+┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ ┃ Table field ┃ Dataframe field ┃
+┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
+│ ❌ │ 2: bar: required int │ Missing │
+│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
+└────┴──────────────────────────┴──────────────────────────┘
+"""
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema(table_schema_simple, other_schema)
+
+
+def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
+ other_schema = pa.schema((
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=True),
+ pa.field("baz", pa.bool_(), nullable=True),
+ pa.field("new_field", pa.date32(), nullable=True),
+ ))
+
+ expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
+
+ with pytest.raises(ValueError, match=expected):
+ _check_schema(table_schema_simple, other_schema)