This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new cd7fb50  Add UnionByName functionality (#296)
cd7fb50 is described below

commit cd7fb502900a717d6b902a398b267eb10e4faa9b
Author: Fokko Driesprong <[email protected]>
AuthorDate: Fri Jan 26 14:36:20 2024 +0100

    Add UnionByName functionality (#296)
    
    * Add UnionByName functionality
    
    * Thanks Honah!
    
    * Add `_id`
    
    Co-authored-by: Honah J. <[email protected]>
    
    * Fix
    
    ---------
    
    Co-authored-by: Honah J. <[email protected]>
---
 mkdocs/docs/api.md          |  41 +++
 pyiceberg/table/__init__.py | 188 ++++++++++++-
 pyiceberg/types.py          |  12 +
 tests/test_schema.py        | 671 +++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 905 insertions(+), 7 deletions(-)

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 9d97d4f..6f79835 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -293,6 +293,47 @@ with table.transaction() as transaction:
     # ... Update properties etc
 ```
 
+### Union by Name
+
+Using `.union_by_name()` you can merge another schema into an existing schema without having to worry about field-IDs:
+
+```python
+from pyiceberg.catalog import load_catalog
+from pyiceberg.schema import Schema
+from pyiceberg.types import NestedField, StringType, DoubleType, LongType
+
+catalog = load_catalog()
+
+schema = Schema(
+    NestedField(1, "city", StringType(), required=False),
+    NestedField(2, "lat", DoubleType(), required=False),
+    NestedField(3, "long", DoubleType(), required=False),
+)
+
+table = catalog.create_table("default.locations", schema)
+
+new_schema = Schema(
+    NestedField(1, "city", StringType(), required=False),
+    NestedField(2, "lat", DoubleType(), required=False),
+    NestedField(3, "long", DoubleType(), required=False),
+    NestedField(10, "population", LongType(), required=False),
+)
+
+with table.update_schema() as update:
+    update.union_by_name(new_schema)
+```
+
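+Now the table has the union of the two schemas, as `print(table.schema())` shows. Note that `population` gets the next free field-ID (4); the field-IDs in `new_schema` are ignored because fields are matched by name: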
+
+```
+table {
+  1: city: optional string
+  2: lat: optional double
+  3: long: optional double
+  4: population: optional long
+}
+```
+
 ### Add column
 
 Using `add_column` you can add a column, without having to worry about the field-id:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 057dd84..221a609 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -69,11 +69,14 @@ from pyiceberg.manifest import (
 )
 from pyiceberg.partitioning import PartitionSpec
 from pyiceberg.schema import (
+    PartnerAccessor,
     Schema,
     SchemaVisitor,
+    SchemaWithPartnerVisitor,
     assign_fresh_schema_ids,
     promote,
     visit,
+    visit_with_partner,
 )
 from pyiceberg.table.metadata import (
     INITIAL_SEQUENCE_NUMBER,
@@ -1379,7 +1382,7 @@ class Move:
 
 
 class UpdateSchema:
-    _table: Table
+    _table: Optional[Table]
     _schema: Schema
     _last_column_id: itertools.count[int]
     _identifier_field_names: Set[str]
@@ -1398,14 +1401,23 @@ class UpdateSchema:
 
     def __init__(
         self,
-        table: Table,
+        table: Optional[Table],
         transaction: Optional[Transaction] = None,
         allow_incompatible_changes: bool = False,
         case_sensitive: bool = True,
+        schema: Optional[Schema] = None,
     ) -> None:
         self._table = table
-        self._schema = table.schema()
-        self._last_column_id = itertools.count(table.metadata.last_column_id + 
1)
+
+        if isinstance(schema, Schema):
+            self._schema = schema
+            self._last_column_id = itertools.count(1 + schema.highest_field_id)
+        elif table is not None:
+            self._schema = table.schema()
+            self._last_column_id = itertools.count(1 + 
table.metadata.last_column_id)
+        else:
+            raise ValueError("Either provide a table or a schema")
+
         self._identifier_field_names = self._schema.identifier_field_names()
 
         self._adds = {}
@@ -1449,6 +1461,15 @@ class UpdateSchema:
         self._case_sensitive = case_sensitive
         return self
 
+    def union_by_name(self, new_schema: Schema) -> UpdateSchema:
+        """Merge `new_schema` into the pending schema, matching fields by name rather than by field-ID.
+
+        Fields that are missing from the current schema are added via `add_column`; fields that
+        already exist are made optional, type-promoted, or get their doc updated when `new_schema`
+        asks for it.
+
+        Args:
+            new_schema: The schema to union into the current one.
+
+        Returns:
+            This UpdateSchema, so calls can be chained.
+        """
+        visit_with_partner(
+            new_schema,
+            # -1 is the sentinel partner ID for the root of the existing schema.
+            -1,
+            UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive),  # type: ignore
+            PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive),
+        )
+        return self
+
     def add_column(
         self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: 
Optional[str] = None, required: bool = False
     ) -> UpdateSchema:
@@ -1816,6 +1837,9 @@ class UpdateSchema:
 
     def commit(self) -> None:
         """Apply the pending changes and commit."""
+        if self._table is None:
+            raise ValueError("Requires a table to commit to")
+
         new_schema = self._apply()
 
         existing_schema_id = next((schema.schema_id for schema in 
self._table.metadata.schemas if schema == new_schema), None)
@@ -1862,7 +1886,8 @@ class UpdateSchema:
 
             field_ids.add(field.field_id)
 
-        return Schema(*struct.fields, schema_id=1 + 
max(self._table.schemas().keys()), identifier_field_ids=field_ids)
+        next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table 
is not None else self._schema.schema_id)
+        return Schema(*struct.fields, schema_id=next_schema_id, 
identifier_field_ids=field_ids)
 
     def assign_new_column_id(self) -> int:
         return next(self._last_column_id)
@@ -1995,6 +2020,159 @@ class _ApplyChanges(SchemaVisitor[Optional[IcebergType]]):
         return primitive
 
 
+class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
+    """Union a new schema into an existing one by matching fields on their names.
+
+    The visitor walks the *new* schema, while `PartnerIdByNameAccessor` supplies the field-ID of the
+    same-named node in the existing schema (the partner), or None when there is no such node.
+    Every visit method returns True when the visited node has no partner (it is missing from the
+    existing schema), so the enclosing struct can add it; otherwise it returns False. All changes
+    are recorded on `update_schema`.
+    """
+
+    update_schema: UpdateSchema
+    existing_schema: Schema
+    case_sensitive: bool
+
+    def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
+        self.update_schema = update_schema
+        self.existing_schema = existing_schema
+        self.case_sensitive = case_sensitive
+
+    def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
+        return struct_result
+
+    def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
+        """Add the struct's missing fields and update the ones that already exist.
+
+        `missing_positions[i]` is the visit result of `struct.fields[i]`, i.e. True when that field has
+        no partner in the existing struct.
+        """
+        if partner_id is None:
+            # The whole struct is new; the parent struct adds it (including its children) at once.
+            return True
+
+        partner_struct = self._find_field_type(partner_id)
+
+        if not partner_struct.is_struct:
+            raise ValueError(f"Expected a struct, got: {partner_struct}")
+
+        for field, missing in zip(struct.fields, missing_positions):
+            if missing:
+                self._add_column(partner_id, field)
+            elif nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive):
+                self._update_column(field, nested_field)
+
+        return False
+
+    def _add_column(self, parent_id: int, field: NestedField) -> None:
+        """Add `field`, with its full type, as a child of the existing field `parent_id`."""
+        if parent_name := self.existing_schema.find_column_name(parent_id):
+            path: Tuple[str, ...] = (parent_name, field.name)
+        else:
+            # The parent has no column name (presumably the root sentinel -1): add at the top level.
+            path = (field.name,)
+
+        self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)
+
+    def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
+        """Widen `existing_field` towards `field`: relax nullability, change the primitive type, copy the doc.
+
+        Type changes that are not allowed promotions are expected to be rejected by
+        `UpdateSchema.update_column`.
+
+        Raises:
+            ValueError: If `existing_field` has no column name in the existing schema.
+        """
+        full_name = self.existing_schema.find_column_name(existing_field.field_id)
+
+        if full_name is None:
+            raise ValueError(f"Could not find field: {existing_field}")
+
+        if field.optional and existing_field.required:
+            self.update_schema.make_column_optional(full_name)
+
+        # Nested types are reconciled when their own children are visited.
+        if field.field_type.is_primitive and field.field_type != existing_field.field_type:
+            self.update_schema.update_column(full_name, field_type=field.field_type)
+
+        # Propagate a new or changed doc; a missing doc never clears the existing one.
+        if field.doc is not None and field.doc != existing_field.doc:
+            self.update_schema.update_column(full_name, doc=field.doc)
+
+    def _find_field_type(self, field_id: int) -> IcebergType:
+        """Return the type of the existing field `field_id`; -1 denotes the root struct."""
+        if field_id == -1:
+            return self.existing_schema.as_struct()
+        else:
+            return self.existing_schema.find_field(field_id).field_type
+
+    def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
+        return partner_id is None
+
+    def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
+        """Update the element of an existing list; a list without a partner is reported as missing."""
+        if list_partner_id is None:
+            return True
+
+        if element_missing:
+            raise ValueError("Error traversing schemas: element is missing, but list is present")
+
+        partner_list_type = self._find_field_type(list_partner_id)
+        if not isinstance(partner_list_type, ListType):
+            raise ValueError(f"Expected list-type, got: {partner_list_type}")
+
+        self._update_column(list_type.element_field, partner_list_type.element_field)
+
+        return False
+
+    def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
+        """Update the key and value of an existing map; a map without a partner is reported as missing."""
+        if map_partner_id is None:
+            return True
+
+        if key_missing:
+            raise ValueError("Error traversing schemas: key is missing, but map is present")
+
+        if value_missing:
+            raise ValueError("Error traversing schemas: value is missing, but map is present")
+
+        partner_map_type = self._find_field_type(map_partner_id)
+        if not isinstance(partner_map_type, MapType):
+            raise ValueError(f"Expected map-type, got: {partner_map_type}")
+
+        self._update_column(map_type.key_field, partner_map_type.key_field)
+        self._update_column(map_type.value_field, partner_map_type.value_field)
+
+        return False
+
+    def primitive(self, primitive: PrimitiveType, primitive_partner_id: Optional[int]) -> bool:
+        return primitive_partner_id is None
+
+
+class PartnerIdByNameAccessor(PartnerAccessor[int]):
+    """Find the partner of each node of a new schema by looking up names in `partner_schema`.
+
+    Partners are field-IDs in `partner_schema`: -1 stands for the schema root, and None means the
+    node has no counterpart in `partner_schema`.
+    """
+
+    partner_schema: Schema
+    case_sensitive: bool
+
+    def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
+        self.partner_schema = partner_schema
+        self.case_sensitive = case_sensitive
+
+    def schema_partner(self, partner: Optional[int]) -> Optional[int]:
+        # The root of the partner schema is always represented by the sentinel -1.
+        return -1
+
+    def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
+        """Return the ID of the child of struct `partner_field_id` named `field_name`, or None.
+
+        The `field_id` of the new field is deliberately ignored: matching is purely by name.
+
+        Raises:
+            ValueError: If the partner field exists but is not a struct.
+        """
+        if partner_field_id is None:
+            return None
+
+        if partner_field_id == -1:
+            struct = self.partner_schema.as_struct()
+        else:
+            struct = self.partner_schema.find_field(partner_field_id).field_type
+            if not struct.is_struct:
+                raise ValueError(f"Expected StructType: {struct}")
+
+        if field := struct.field_by_name(name=field_name, case_sensitive=self.case_sensitive):
+            return field.field_id
+
+        return None
+
+    def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
+        """Return the element-ID of the existing list `partner_list_id`, or None if there is no partner."""
+        if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)):
+            if not isinstance(field.field_type, ListType):
+                raise ValueError(f"Expected ListType: {field}")
+            return field.field_type.element_field.field_id
+        return None
+
+    def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+        """Return the key-ID of the existing map `partner_map_id`, or None if there is no partner."""
+        map_type = self._find_map_type(partner_map_id)
+        return map_type.key_field.field_id if map_type is not None else None
+
+    def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+        """Return the value-ID of the existing map `partner_map_id`, or None if there is no partner."""
+        map_type = self._find_map_type(partner_map_id)
+        return map_type.value_field.field_id if map_type is not None else None
+
+    def _find_map_type(self, partner_map_id: Optional[int]) -> Optional[MapType]:
+        """Return the MapType of the existing field `partner_map_id`, or None when there is no partner.
+
+        Raises:
+            ValueError: If the partner field exists but is not a map.
+        """
+        if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
+            if not isinstance(field.field_type, MapType):
+                raise ValueError(f"Expected MapType: {field}")
+            return field.field_type
+        return None
+
+
 def _add_fields(fields: Tuple[NestedField, ...], adds: 
Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
     adds = adds or []
     return fields + tuple(adds)
diff --git a/pyiceberg/types.py b/pyiceberg/types.py
index 5e7c193..eb21512 100644
--- a/pyiceberg/types.py
+++ b/pyiceberg/types.py
@@ -350,6 +350,18 @@ class StructType(IcebergType):
                 return field
         return None
 
+    def field_by_name(self, name: str, case_sensitive: bool = True) -> Optional[NestedField]:
+        """Return the first direct child field called `name`, or None if there is none.
+
+        Args:
+            name: The field name (not a dotted path; only this struct's own fields are searched).
+            case_sensitive: When True (the default) names must match exactly; when False they are
+                compared case-insensitively.
+        """
+        if case_sensitive:
+            for field in self.fields:
+                if field.name == name:
+                    return field
+        else:
+            name_lower = name.lower()
+            for field in self.fields:
+                if field.name.lower() == name_lower:
+                    return field
+        return None
+
     def __str__(self) -> str:
         """Return the string representation of the StructType class."""
         return f"struct<{', '.join(map(str, self.fields))}>"
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 8e34423..a5487b7 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -16,12 +16,12 @@
 # under the License.
 
 from textwrap import dedent
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import pytest
 
 from pyiceberg import schema
-from pyiceberg.exceptions import ResolveError
+from pyiceberg.exceptions import ResolveError, ValidationError
 from pyiceberg.expressions import Accessor
 from pyiceberg.schema import (
     Schema,
@@ -30,6 +30,7 @@ from pyiceberg.schema import (
     prune_columns,
     sanitize_column_names,
 )
+from pyiceberg.table import UpdateSchema
 from pyiceberg.typedef import EMPTY_DICT, StructProtocol
 from pyiceberg.types import (
     BinaryType,
@@ -45,6 +46,7 @@ from pyiceberg.types import (
     LongType,
     MapType,
     NestedField,
+    PrimitiveType,
     StringType,
     StructType,
     TimestampType,
@@ -912,3 +914,668 @@ def test_promotion(file_type: IcebergType, read_type: IcebergType) -> None:
     else:
         with pytest.raises(ResolveError):
             promote(file_type, read_type)
+
+
+@pytest.fixture()
+def primitive_fields() -> List[NestedField]:
+    """One optional field (id 1) per type in TEST_PRIMITIVE_TYPES, named after its type."""
+    return [
+        NestedField(field_id=1, name=str(primitive_type), field_type=primitive_type, required=False)
+        for primitive_type in TEST_PRIMITIVE_TYPES
+    ]
+
+
+def test_add_top_level_primitives(primitive_fields: List[NestedField]) -> None:
+    """Each primitive field is added unchanged to an empty schema."""
+    for primitive_field in primitive_fields:
+        new_schema = Schema(primitive_field)
+        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        assert applied == new_schema
+
+
+def test_add_top_level_list_of_primitives(primitive_fields: List[NestedField]) -> None:
+    """A top-level list of each primitive type is added unchanged to an empty schema."""
+    for primitive_type in TEST_PRIMITIVE_TYPES:
+        new_schema = Schema(
+            NestedField(
+                field_id=1,
+                name="aList",
+                field_type=ListType(element_id=2, element_type=primitive_type, element_required=False),
+                required=False,
+            )
+        )
+        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_top_level_map_of_primitives(primitive_fields: List[NestedField]) -> None:
+    """A top-level map keyed and valued by each primitive type is added unchanged to an empty schema."""
+    for primitive_type in TEST_PRIMITIVE_TYPES:
+        new_schema = Schema(
+            NestedField(
+                field_id=1,
+                name="aMap",
+                field_type=MapType(
+                    key_id=2, key_type=primitive_type, value_id=3, value_type=primitive_type, value_required=False
+                ),
+                required=False,
+            )
+        )
+        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_top_struct_of_primitives(primitive_fields: List[NestedField]) -> None:
+    """A top-level struct wrapping one primitive field is added unchanged, for every primitive type."""
+    for primitive_type in TEST_PRIMITIVE_TYPES:
+        new_schema = Schema(
+            NestedField(
+                field_id=1,
+                name="aStruct",
+                field_type=StructType(NestedField(field_id=2, name="primitive", field_type=primitive_type, required=False)),
+                required=False,
+            )
+        )
+        applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+        assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_primitive(primitive_fields: List[NestedField]) -> None:
+    """A primitive field is added inside an existing, empty struct, for every primitive type."""
+    for primitive_type in TEST_PRIMITIVE_TYPES:
+        current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False))
+        new_schema = Schema(
+            NestedField(
+                field_id=1,
+                name="aStruct",
+                field_type=StructType(NestedField(field_id=2, name="primitive", field_type=primitive_type, required=False)),
+                required=False,
+            )
+        )
+        applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+        assert applied.as_struct() == new_schema.as_struct()
+
+
+def _primitive_fields(types: List[PrimitiveType], start_id: int = 0) -> List[NestedField]:
+    """Build one optional field per type, named after the type, with consecutive IDs from `start_id`."""
+    return [
+        NestedField(field_id=field_id, name=str(iceberg_type), field_type=iceberg_type, required=False)
+        for field_id, iceberg_type in enumerate(types, start=start_id)
+    ]
+
+
+def test_add_nested_primitives(primitive_fields: List[NestedField]) -> None:
+    """Fields of all primitive types are added at once inside an existing, empty struct."""
+    current_schema = Schema(NestedField(field_id=1, name="aStruct", field_type=StructType(), required=False))
+    new_schema = Schema(
+        NestedField(
+            field_id=1, name="aStruct", field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)), required=False
+        )
+    )
+    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+    assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_lists(primitive_fields: List[NestedField]) -> None:
+    """A deeply nested list-of-lists is added unchanged to an empty schema."""
+    new_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="aList",
+            field_type=ListType(
+                element_id=2,
+                element_type=ListType(
+                    element_id=3,
+                    element_type=ListType(
+                        element_id=4,
+                        element_type=ListType(
+                            element_id=5,
+                            element_type=ListType(
+                                element_id=6,
+                                element_type=ListType(
+                                    element_id=7,
+                                    element_type=ListType(
+                                        element_id=8,
+                                        element_type=ListType(element_id=9, element_type=DecimalType(precision=11, scale=20)),
+                                        element_required=False,
+                                    ),
+                                    element_required=False,
+                                ),
+                                element_required=False,
+                            ),
+                            element_required=False,
+                        ),
+                        element_required=False,
+                    ),
+                    element_required=False,
+                ),
+                element_required=False,
+            ),
+            required=False,
+        )
+    )
+    applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+    assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_struct(primitive_fields: List[NestedField]) -> None:
+    """A deeply nested struct-of-structs is added unchanged to an empty schema."""
+    new_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="struct1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="struct2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="struct3",
+                            field_type=StructType(
+                                NestedField(
+                                    field_id=4,
+                                    name="struct4",
+                                    field_type=StructType(
+                                        NestedField(
+                                            field_id=5,
+                                            name="struct5",
+                                            field_type=StructType(
+                                                NestedField(
+                                                    field_id=6,
+                                                    name="struct6",
+                                                    field_type=StructType(
+                                                        NestedField(field_id=7, name="aString", field_type=StringType())
+                                                    ),
+                                                    required=False,
+                                                )
+                                            ),
+                                            required=False,
+                                        )
+                                    ),
+                                    required=False,
+                                )
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+    applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+    assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_maps(primitive_fields: List[NestedField]) -> None:
+    """A deeply nested map-of-maps is added unchanged to an empty schema."""
+    new_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="struct",
+            field_type=MapType(
+                key_id=2,
+                value_id=3,
+                key_type=StringType(),
+                value_type=MapType(
+                    key_id=4,
+                    value_id=5,
+                    key_type=StringType(),
+                    value_type=MapType(
+                        key_id=6,
+                        value_id=7,
+                        key_type=StringType(),
+                        value_type=MapType(
+                            key_id=8,
+                            value_id=9,
+                            key_type=StringType(),
+                            value_type=MapType(
+                                key_id=10,
+                                value_id=11,
+                                key_type=StringType(),
+                                value_type=MapType(key_id=12, value_id=13, key_type=StringType(), value_type=StringType()),
+                                value_required=False,
+                            ),
+                            value_required=False,
+                        ),
+                        value_required=False,
+                    ),
+                    value_required=False,
+                ),
+                value_required=False,
+            ),
+            required=False,
+        )
+    )
+    applied = UpdateSchema(None, schema=Schema()).union_by_name(new_schema)._apply()
+    assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_detect_invalid_top_level_list() -> None:
+    """Changing a list's element type from string to double is rejected."""
+    current_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="aList",
+            field_type=ListType(element_id=2, element_type=StringType(), element_required=False),
+            required=False,
+        )
+    )
+    new_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="aList",
+            field_type=ListType(element_id=2, element_type=DoubleType(), element_required=False),
+            required=False,
+        )
+    )
+
+    with pytest.raises(ValidationError, match="Cannot change column type: aList.element: string -> double"):
+        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+
+def test_detect_invalid_top_level_maps() -> None:
+    """Changing a map's key type from string to uuid is rejected."""
+    current_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="aMap",
+            field_type=MapType(key_id=2, value_id=3, key_type=StringType(), value_type=StringType(), value_required=False),
+            required=False,
+        )
+    )
+    new_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="aMap",
+            field_type=MapType(key_id=2, value_id=3, key_type=UUIDType(), value_type=StringType(), value_required=False),
+            required=False,
+        )
+    )
+
+    with pytest.raises(ValidationError, match="Cannot change column type: aMap.key: string -> uuid"):
+        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+
+def test_promote_float_to_double() -> None:
+    """float -> double is an allowed promotion."""
+    current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False))
+    new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False))
+
+    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+    assert applied.as_struct() == new_schema.as_struct()
+    assert len(applied.fields) == 1
+    assert isinstance(applied.fields[0].field_type, DoubleType)
+
+
+def test_detect_invalid_promotion_double_to_float() -> None:
+    """double -> float narrows the type and is rejected."""
+    current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DoubleType(), required=False))
+    new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=FloatType(), required=False))
+
+    with pytest.raises(ValidationError, match="Cannot change column type: aCol: double -> float"):
+        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+
+# decimal(P,S) is a fixed-point decimal with precision P and scale S. It can only be promoted by
+# widening the precision P (which must stay at 38 or less); the scale S must stay the same.
+def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None:
+    """decimal(20, 1) -> decimal(22, 1) is an allowed promotion."""
+    current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=20, scale=1), required=False))
+    new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=DecimalType(precision=22, scale=1), required=False))
+
+    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+    assert applied.as_struct() == new_schema.as_struct()
+    assert len(applied.fields) == 1
+    field = applied.fields[0]
+    decimal_type = field.field_type
+    assert isinstance(decimal_type, DecimalType)
+    assert decimal_type.precision == 22
+    assert decimal_type.scale == 1
+
+
+def test_add_nested_structs(primitive_fields: List[NestedField]) -> None:
+    """A new field inside a struct nested in a list is appended with the next free field-ID."""
+    # Named current_schema (not `schema`) so the imported `schema` module is not shadowed.
+    current_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="struct1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="struct2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="list",
+                            field_type=ListType(
+                                element_id=4,
+                                element_type=StructType(
+                                    NestedField(field_id=5, name="value", field_type=StringType(), required=False)
+                                ),
+                                element_required=False,
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+    new_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="struct1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="struct2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="list",
+                            field_type=ListType(
+                                element_id=4,
+                                element_type=StructType(
+                                    NestedField(field_id=5, name="time", field_type=TimeType(), required=False)
+                                ),
+                                element_required=False,
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+    applied = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+    expected = Schema(
+        NestedField(
+            field_id=1,
+            name="struct1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="struct2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="list",
+                            field_type=ListType(
+                                element_id=4,
+                                element_type=StructType(
+                                    NestedField(field_id=5, name="value", field_type=StringType(), required=False),
+                                    NestedField(field_id=6, name="time", field_type=TimeType(), required=False),
+                                ),
+                                element_required=False,
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+
+    assert applied.as_struct() == expected.as_struct()
+
+
+def test_replace_list_with_primitive() -> None:
+    """Replacing a list column with a primitive of the same name is rejected."""
+    current_schema = Schema(NestedField(field_id=1, name="aCol", field_type=ListType(element_id=2, element_type=StringType())))
+    new_schema = Schema(NestedField(field_id=1, name="aCol", field_type=StringType()))
+
+    with pytest.raises(ValidationError, match="Cannot change column type: list<string> is not a primitive"):
+        _ = UpdateSchema(None, schema=current_schema).union_by_name(new_schema)._apply()
+
+
+def test_mirrored_schemas() -> None:
+    """Field-IDs of the new schema are ignored: a by-name match leaves the existing schema unchanged."""
+    current_schema = Schema(
+        NestedField(9, "struct1", StructType(NestedField(8, "string1", StringType(), required=False)), required=False),
+        NestedField(6, "list1", ListType(element_id=7, element_type=StringType(), element_required=False), required=False),
+        NestedField(5, "string2", StringType(), required=False),
+        NestedField(4, "string3", StringType(), required=False),
+        NestedField(3, "string4", StringType(), required=False),
+        NestedField(2, "string5", StringType(), required=False),
+        NestedField(1, "string6", StringType(), required=False),
+    )
+    mirrored_schema = Schema(
+        NestedField(1, "struct1", StructType(NestedField(2, "string1", StringType(), required=False))),
+        NestedField(3, "list1", ListType(element_id=4, element_type=StringType(), element_required=False), required=False),
+        NestedField(5, "string2", StringType(), required=False),
+        NestedField(6, "string3", StringType(), required=False),
+        NestedField(7, "string4", StringType(), required=False),
+        NestedField(8, "string5", StringType(), required=False),
+        NestedField(9, "string6", StringType(), required=False),
+    )
+
+    applied = UpdateSchema(None, schema=current_schema).union_by_name(mirrored_schema)._apply()
+
+    assert applied.as_struct() == current_schema.as_struct()
+
+
+def test_add_new_top_level_struct() -> None:
+    current_schema = Schema(
+        NestedField(
+            1,
+            "map1",
+            MapType(
+                key_id=2,
+                value_id=3,
+                key_type=StringType(),
+                value_type=ListType(
+                    element_id=4,
+                    element_type=StructType(NestedField(field_id=5, 
name="string", field_type=StringType(), required=False)),
+                ),
+                value_required=False,
+            ),
+        )
+    )
+    observed_schema = Schema(
+        NestedField(
+            1,
+            "map1",
+            MapType(
+                key_id=2,
+                value_id=3,
+                key_type=StringType(),
+                value_type=ListType(
+                    element_id=4,
+                    element_type=StructType(NestedField(field_id=5, 
name="string", field_type=StringType(), required=False)),
+                ),
+                value_required=False,
+            ),
+        ),
+        NestedField(
+            field_id=6,
+            name="struct1",
+            field_type=StructType(
+                NestedField(
+                    field_id=7,
+                    name="d1",
+                    field_type=StructType(NestedField(field_id=8, name="d2", 
field_type=StringType(), required=False)),
+                    required=False,
+                )
+            ),
+            required=False,
+        ),
+    )
+
+    applied = UpdateSchema(None, 
schema=current_schema).union_by_name(observed_schema)._apply()
+
+    assert applied.as_struct() == observed_schema.as_struct()
+
+
+def test_append_nested_struct() -> None:
+    current_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="s1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="s2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="s3",
+                            field_type=StructType(NestedField(field_id=4, 
name="s4", field_type=StringType(), required=False)),
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+        )
+    )
+    observed_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="s1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="s2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="s3",
+                            field_type=StructType(NestedField(field_id=4, 
name="s4", field_type=StringType(), required=False)),
+                            required=False,
+                        ),
+                        NestedField(
+                            field_id=5,
+                            name="repeat",
+                            field_type=StructType(
+                                NestedField(
+                                    field_id=6,
+                                    name="s1",
+                                    field_type=StructType(
+                                        NestedField(
+                                            field_id=7,
+                                            name="s2",
+                                            field_type=StructType(
+                                                NestedField(
+                                                    field_id=8,
+                                                    name="s3",
+                                                    field_type=StructType(
+                                                        
NestedField(field_id=9, name="s4", field_type=StringType())
+                                                    ),
+                                                    required=False,
+                                                )
+                                            ),
+                                            required=False,
+                                        )
+                                    ),
+                                    required=False,
+                                )
+                            ),
+                            required=False,
+                        ),
+                        required=False,
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+
+    applied = UpdateSchema(None, 
schema=current_schema).union_by_name(observed_schema)._apply()
+
+    assert applied.as_struct() == observed_schema.as_struct()
+
+
+def test_append_nested_lists() -> None:
+    current_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="s1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="s2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="s3",
+                            field_type=StructType(
+                                NestedField(
+                                    field_id=4,
+                                    name="list1",
+                                    field_type=ListType(element_id=5, 
element_type=StringType(), element_required=False),
+                                    required=False,
+                                )
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+
+    observed_schema = Schema(
+        NestedField(
+            field_id=1,
+            name="s1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="s2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="s3",
+                            field_type=StructType(
+                                NestedField(
+                                    field_id=4,
+                                    name="list2",
+                                    field_type=ListType(element_id=5, 
element_type=StringType(), element_required=False),
+                                    required=False,
+                                )
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+    union = UpdateSchema(None, 
schema=current_schema).union_by_name(observed_schema)._apply()
+
+    expected = Schema(
+        NestedField(
+            field_id=1,
+            name="s1",
+            field_type=StructType(
+                NestedField(
+                    field_id=2,
+                    name="s2",
+                    field_type=StructType(
+                        NestedField(
+                            field_id=3,
+                            name="s3",
+                            field_type=StructType(
+                                NestedField(
+                                    field_id=4,
+                                    name="list1",
+                                    field_type=ListType(element_id=5, 
element_type=StringType(), element_required=False),
+                                    required=False,
+                                ),
+                                NestedField(
+                                    field_id=6,
+                                    name="list2",
+                                    field_type=ListType(element_id=7, 
element_type=StringType(), element_required=False),
+                                    required=False,
+                                ),
+                            ),
+                            required=False,
+                        )
+                    ),
+                    required=False,
+                )
+            ),
+            required=False,
+        )
+    )
+
+    assert union.as_struct() == expected.as_struct()


Reply via email to