This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new cd7fb50 Add UnionByName functionality (#296)
cd7fb50 is described below
commit cd7fb502900a717d6b902a398b267eb10e4faa9b
Author: Fokko Driesprong <[email protected]>
AuthorDate: Fri Jan 26 14:36:20 2024 +0100
Add UnionByName functionality (#296)
* Add UnionByName functionality
* Thanks Honah!
* Add `_id`
Co-authored-by: Honah J. <[email protected]>
* Fix
---------
Co-authored-by: Honah J. <[email protected]>
---
mkdocs/docs/api.md | 41 +++
pyiceberg/table/__init__.py | 188 ++++++++++++-
pyiceberg/types.py | 12 +
tests/test_schema.py | 671 +++++++++++++++++++++++++++++++++++++++++++-
4 files changed, 905 insertions(+), 7 deletions(-)
diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 9d97d4f..6f79835 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -293,6 +293,47 @@ with table.transaction() as transaction:
# ... Update properties etc
```
+### Union by Name
+
+Using `.union_by_name()` you can merge another schema into an existing schema
without having to worry about field-IDs:
+
+```python
+from pyiceberg.catalog import load_catalog
+from pyiceberg.schema import Schema
+from pyiceberg.types import NestedField, StringType, DoubleType, LongType
+
+catalog = load_catalog()
+
+schema = Schema(
+ NestedField(1, "city", StringType(), required=False),
+ NestedField(2, "lat", DoubleType(), required=False),
+ NestedField(3, "long", DoubleType(), required=False),
+)
+
+table = catalog.create_table("default.locations", schema)
+
+new_schema = Schema(
+ NestedField(1, "city", StringType(), required=False),
+ NestedField(2, "lat", DoubleType(), required=False),
+ NestedField(3, "long", DoubleType(), required=False),
+ NestedField(10, "population", LongType(), required=False),
+)
+
+with table.update_schema() as update:
+ update.union_by_name(new_schema)
+```
+
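+Now the table has the union of the two schemas. Fields are matched by name, so `population` receives the next free field ID (`4`) instead of the `10` used in `new_schema`. Inspect the result with `print(table.schema())`: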
+
+```
+table {
+ 1: city: optional string
+ 2: lat: optional double
+ 3: long: optional double
+ 4: population: optional long
+}
+```
+
### Add column
Using `add_column` you can add a column, without having to worry about the
field-id:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 057dd84..221a609 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -69,11 +69,14 @@ from pyiceberg.manifest import (
)
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import (
+ PartnerAccessor,
Schema,
SchemaVisitor,
+ SchemaWithPartnerVisitor,
assign_fresh_schema_ids,
promote,
visit,
+ visit_with_partner,
)
from pyiceberg.table.metadata import (
INITIAL_SEQUENCE_NUMBER,
@@ -1379,7 +1382,7 @@ class Move:
class UpdateSchema:
- _table: Table
+ _table: Optional[Table]
_schema: Schema
_last_column_id: itertools.count[int]
_identifier_field_names: Set[str]
@@ -1398,14 +1401,23 @@ class UpdateSchema:
def __init__(
self,
- table: Table,
+ table: Optional[Table],
transaction: Optional[Transaction] = None,
allow_incompatible_changes: bool = False,
case_sensitive: bool = True,
+ schema: Optional[Schema] = None,
) -> None:
self._table = table
- self._schema = table.schema()
- self._last_column_id = itertools.count(table.metadata.last_column_id +
1)
+
+ if isinstance(schema, Schema):
+ self._schema = schema
+ self._last_column_id = itertools.count(1 + schema.highest_field_id)
+ elif table is not None:
+ self._schema = table.schema()
+ self._last_column_id = itertools.count(1 +
table.metadata.last_column_id)
+ else:
+ raise ValueError("Either provide a table or a schema")
+
self._identifier_field_names = self._schema.identifier_field_names()
self._adds = {}
@@ -1449,6 +1461,15 @@ class UpdateSchema:
self._case_sensitive = case_sensitive
return self
+ def union_by_name(self, new_schema: Schema) -> UpdateSchema:
+ visit_with_partner(
+ new_schema,
+ -1,
+ UnionByNameVisitor(update_schema=self,
existing_schema=self._schema, case_sensitive=self._case_sensitive), # type:
ignore
+ PartnerIdByNameAccessor(partner_schema=self._schema,
case_sensitive=self._case_sensitive),
+ )
+ return self
+
def add_column(
self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc:
Optional[str] = None, required: bool = False
) -> UpdateSchema:
@@ -1816,6 +1837,9 @@ class UpdateSchema:
def commit(self) -> None:
"""Apply the pending changes and commit."""
+ if self._table is None:
+ raise ValueError("Requires a table to commit to")
+
new_schema = self._apply()
existing_schema_id = next((schema.schema_id for schema in
self._table.metadata.schemas if schema == new_schema), None)
@@ -1862,7 +1886,8 @@ class UpdateSchema:
field_ids.add(field.field_id)
- return Schema(*struct.fields, schema_id=1 +
max(self._table.schemas().keys()), identifier_field_ids=field_ids)
+ next_schema_id = 1 + (max(self._table.schemas().keys()) if self._table
is not None else self._schema.schema_id)
+ return Schema(*struct.fields, schema_id=next_schema_id,
identifier_field_ids=field_ids)
def assign_new_column_id(self) -> int:
"""Return the next id from the ``_last_column_id`` counter (seeded one past the highest known id) and advance it."""
return next(self._last_column_id)
@@ -1995,6 +2020,159 @@ class
_ApplyChanges(SchemaVisitor[Optional[IcebergType]]):
return primitive
+class UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
+ update_schema: UpdateSchema
+ existing_schema: Schema
+ case_sensitive: bool
+
+ def __init__(self, update_schema: UpdateSchema, existing_schema: Schema,
case_sensitive: bool) -> None:
+ self.update_schema = update_schema
+ self.existing_schema = existing_schema
+ self.case_sensitive = case_sensitive
+
+ def schema(self, schema: Schema, partner_id: Optional[int], struct_result:
bool) -> bool:
+ return struct_result
+
+ def struct(self, struct: StructType, partner_id: Optional[int],
missing_positions: List[bool]) -> bool:
+ if partner_id is None:
+ return True
+
+ fields = struct.fields
+ partner_struct = self._find_field_type(partner_id)
+
+ if not partner_struct.is_struct:
+ raise ValueError(f"Expected a struct, got: {partner_struct}")
+
+ for pos, missing in enumerate(missing_positions):
+ if missing:
+ self._add_column(partner_id, fields[pos])
+ else:
+ field = fields[pos]
+ if nested_field := partner_struct.field_by_name(field.name,
case_sensitive=self.case_sensitive):
+ self._update_column(field, nested_field)
+
+ return False
+
+ def _add_column(self, parent_id: int, field: NestedField) -> None:
+ if parent_name := self.existing_schema.find_column_name(parent_id):
+ path: Tuple[str, ...] = (parent_name, field.name)
+ else:
+ path = (field.name,)
+
+ self.update_schema.add_column(path=path, field_type=field.field_type,
required=field.required, doc=field.doc)
+
+ def _update_column(self, field: NestedField, existing_field: NestedField)
-> None:
+ full_name =
self.existing_schema.find_column_name(existing_field.field_id)
+
+ if full_name is None:
+ raise ValueError(f"Could not find field: {existing_field}")
+
+ if field.optional and existing_field.required:
+ self.update_schema.make_column_optional(full_name)
+
+ if field.field_type.is_primitive and field.field_type !=
existing_field.field_type:
+ self.update_schema.update_column(full_name,
field_type=field.field_type)
+
+ if field.doc is not None and not field.doc != existing_field.doc:
+ self.update_schema.update_column(full_name, doc=field.doc)
+
+ def _find_field_type(self, field_id: int) -> IcebergType:
+ if field_id == -1:
+ return self.existing_schema.as_struct()
+ else:
+ return self.existing_schema.find_field(field_id).field_type
+
+ def field(self, field: NestedField, partner_id: Optional[int],
field_result: bool) -> bool:
+ return partner_id is None
+
+ def list(self, list_type: ListType, list_partner_id: Optional[int],
element_missing: bool) -> bool:
+ if list_partner_id is None:
+ return True
+
+ if element_missing:
+ raise ValueError("Error traversing schemas: element is missing,
but list is present")
+
+ partner_list_type = self._find_field_type(list_partner_id)
+ if not isinstance(partner_list_type, ListType):
+ raise ValueError(f"Expected list-type, got: {partner_list_type}")
+
+ self._update_column(list_type.element_field,
partner_list_type.element_field)
+
+ return False
+
+ def map(self, map_type: MapType, map_partner_id: Optional[int],
key_missing: bool, value_missing: bool) -> bool:
+ if map_partner_id is None:
+ return True
+
+ if key_missing:
+ raise ValueError("Error traversing schemas: key is missing, but
map is present")
+
+ if value_missing:
+ raise ValueError("Error traversing schemas: value is missing, but
map is present")
+
+ partner_map_type = self._find_field_type(map_partner_id)
+ if not isinstance(partner_map_type, MapType):
+ raise ValueError(f"Expected map-type, got: {partner_map_type}")
+
+ self._update_column(map_type.key_field, partner_map_type.key_field)
+ self._update_column(map_type.value_field, partner_map_type.value_field)
+
+ return False
+
+ def primitive(self, primitive: PrimitiveType, primitive_partner_id:
Optional[int]) -> bool:
+ return primitive_partner_id is None
+
+
+class PartnerIdByNameAccessor(PartnerAccessor[int]):
+ partner_schema: Schema
+ case_sensitive: bool
+
+ def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
+ self.partner_schema = partner_schema
+ self.case_sensitive = case_sensitive
+
+ def schema_partner(self, partner: Optional[int]) -> Optional[int]:
+ return -1
+
+ def field_partner(self, partner_field_id: Optional[int], field_id: int,
field_name: str) -> Optional[int]:
+ if partner_field_id is not None:
+ if partner_field_id == -1:
+ struct = self.partner_schema.as_struct()
+ else:
+ struct =
self.partner_schema.find_field(partner_field_id).field_type
+ if not struct.is_struct:
+ raise ValueError(f"Expected StructType: {struct}")
+
+ if field := struct.field_by_name(name=field_name,
case_sensitive=self.case_sensitive):
+ return field.field_id
+
+ return None
+
+ def list_element_partner(self, partner_list_id: Optional[int]) ->
Optional[int]:
+ if partner_list_id is not None and (field :=
self.partner_schema.find_field(partner_list_id)):
+ if not isinstance(field.field_type, ListType):
+ raise ValueError(f"Expected ListType: {field}")
+ return field.field_type.element_field.field_id
+ else:
+ return None
+
+ def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
+ if partner_map_id is not None and (field :=
self.partner_schema.find_field(partner_map_id)):
+ if not isinstance(field.field_type, MapType):
+ raise ValueError(f"Expected MapType: {field}")
+ return field.field_type.key_field.field_id
+ else:
+ return None
+
+ def map_value_partner(self, partner_map_id: Optional[int]) ->
Optional[int]:
+ if partner_map_id is not None and (field :=
self.partner_schema.find_field(partner_map_id)):
+ if not isinstance(field.field_type, MapType):
+ raise ValueError(f"Expected MapType: {field}")
+ return field.field_type.value_field.field_id
+ else:
+ return None
+
+
def _add_fields(fields: Tuple[NestedField, ...], adds:
Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
adds = adds or []
return fields + tuple(adds)
diff --git a/pyiceberg/types.py b/pyiceberg/types.py
index 5e7c193..eb21512 100644
--- a/pyiceberg/types.py
+++ b/pyiceberg/types.py
@@ -350,6 +350,18 @@ class StructType(IcebergType):
return field
return None
+ def field_by_name(self, name: str, case_sensitive: bool = True) ->
Optional[NestedField]:
+ if case_sensitive:
+ name_lower = name.lower()
+ for field in self.fields:
+ if field.name.lower() == name_lower:
+ return field
+ else:
+ for field in self.fields:
+ if field.name == name:
+ return field
+ return None
+
def __str__(self) -> str:
"""Return the string representation of the StructType class."""
return f"struct<{', '.join(map(str, self.fields))}>"
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 8e34423..a5487b7 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -16,12 +16,12 @@
# under the License.
from textwrap import dedent
-from typing import Any, Dict
+from typing import Any, Dict, List
import pytest
from pyiceberg import schema
-from pyiceberg.exceptions import ResolveError
+from pyiceberg.exceptions import ResolveError, ValidationError
from pyiceberg.expressions import Accessor
from pyiceberg.schema import (
Schema,
@@ -30,6 +30,7 @@ from pyiceberg.schema import (
prune_columns,
sanitize_column_names,
)
+from pyiceberg.table import UpdateSchema
from pyiceberg.typedef import EMPTY_DICT, StructProtocol
from pyiceberg.types import (
BinaryType,
@@ -45,6 +46,7 @@ from pyiceberg.types import (
LongType,
MapType,
NestedField,
+ PrimitiveType,
StringType,
StructType,
TimestampType,
@@ -912,3 +914,668 @@ def test_promotion(file_type: IcebergType, read_type:
IcebergType) -> None:
else:
with pytest.raises(ResolveError):
promote(file_type, read_type)
+
+
[email protected]()
+def primitive_fields() -> List[NestedField]:
+ return [
+ NestedField(field_id=1, name=str(primitive_type),
field_type=primitive_type, required=False)
+ for primitive_type in TEST_PRIMITIVE_TYPES
+ ]
+
+
+def test_add_top_level_primitives(primitive_fields: NestedField) -> None:
+ for primitive_field in primitive_fields:
+ new_schema = Schema(primitive_field)
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied == new_schema
+
+
+def test_add_top_level_list_of_primitives(primitive_fields: NestedField) ->
None:
+ for primitive_type in TEST_PRIMITIVE_TYPES:
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aList",
+ field_type=ListType(element_id=2, element_type=primitive_type,
element_required=False),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_top_level_map_of_primitives(primitive_fields: NestedField) ->
None:
+ for primitive_type in TEST_PRIMITIVE_TYPES:
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aMap",
+ field_type=MapType(
+ key_id=2, key_type=primitive_type, value_id=3,
value_type=primitive_type, value_required=False
+ ),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_top_struct_of_primitives(primitive_fields: NestedField) -> None:
+ for primitive_type in TEST_PRIMITIVE_TYPES:
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aStruct",
+ field_type=StructType(NestedField(field_id=2,
name="primitive", field_type=primitive_type, required=False)),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_primitive(primitive_fields: NestedField) -> None:
+ for primitive_type in TEST_PRIMITIVE_TYPES:
+ current_schema = Schema(NestedField(field_id=1, name="aStruct",
field_type=StructType(), required=False))
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aStruct",
+ field_type=StructType(NestedField(field_id=2,
name="primitive", field_type=primitive_type, required=False)),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def _primitive_fields(types: List[PrimitiveType], start_id: int = 0) ->
List[NestedField]:
+ fields = []
+ for iceberg_type in types:
+ fields.append(NestedField(field_id=start_id, name=str(iceberg_type),
field_type=iceberg_type, required=False))
+ start_id = start_id + 1
+
+ return fields
+
+
+def test_add_nested_primitives(primitive_fields: NestedField) -> None:
+ current_schema = Schema(NestedField(field_id=1, name="aStruct",
field_type=StructType(), required=False))
+ new_schema = Schema(
+ NestedField(
+ field_id=1, name="aStruct",
field_type=StructType(*_primitive_fields(TEST_PRIMITIVE_TYPES, 2)),
required=False
+ )
+ )
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_lists(primitive_fields: List[NestedField]) -> None:
+ """Eight levels of nested lists (decimal innermost) merged into an empty schema must come out unchanged."""
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aList",
+ type=ListType(
+ element_id=2,
+ element_type=ListType(
+ element_id=3,
+ element_type=ListType(
+ element_id=4,
+ element_type=ListType(
+ element_id=5,
+ element_type=ListType(
+ element_id=6,
+ element_type=ListType(
+ element_id=7,
+ element_type=ListType(
+ element_id=8,
+ element_type=ListType(element_id=9,
element_type=DecimalType(precision=11, scale=20)),
+ element_required=False,
+ ),
+ element_required=False,
+ ),
+ element_required=False,
+ ),
+ element_required=False,
+ ),
+ element_required=False,
+ ),
+ element_required=False,
+ ),
+ element_required=False,
+ ),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_struct(primitive_fields: List[NestedField]) -> None:
+ """Six levels of nested optional structs ending in a string field, merged into an empty schema, come out unchanged."""
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="struct1",
+ type=StructType(
+ NestedField(
+ field_id=2,
+ name="struct2",
+ type=StructType(
+ NestedField(
+ field_id=3,
+ name="struct3",
+ type=StructType(
+ NestedField(
+ field_id=4,
+ name="struct4",
+ type=StructType(
+ NestedField(
+ field_id=5,
+ name="struct5",
+ type=StructType(
+ NestedField(
+ field_id=6,
+ name="struct6",
+ type=StructType(
+
NestedField(field_id=7, name="aString", field_type=StringType())
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_add_nested_maps(primitive_fields: List[NestedField]) -> None:
+ """Six levels of nested string-keyed maps merged into an empty schema must come out unchanged."""
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="struct",
+ field_type=MapType(
+ key_id=2,
+ value_id=3,
+ key_type=StringType(),
+ value_type=MapType(
+ key_id=4,
+ value_id=5,
+ key_type=StringType(),
+ value_type=MapType(
+ key_id=6,
+ value_id=7,
+ key_type=StringType(),
+ value_type=MapType(
+ key_id=8,
+ value_id=9,
+ key_type=StringType(),
+ value_type=MapType(
+ key_id=10,
+ value_id=11,
+ key_type=StringType(),
+ value_type=MapType(key_id=12, value_id=13,
key_type=StringType(), value_type=StringType()),
+ value_required=False,
+ ),
+ value_required=False,
+ ),
+ value_required=False,
+ ),
+ value_required=False,
+ ),
+ value_required=False,
+ ),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=Schema()).union_by_name(new_schema)._apply()
+ assert applied.as_struct() == new_schema.as_struct()
+
+
+def test_detect_invalid_top_level_list() -> None:
+ current_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aList",
+ field_type=ListType(element_id=2, element_type=StringType(),
element_required=False),
+ required=False,
+ )
+ )
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aList",
+ field_type=ListType(element_id=2, element_type=DoubleType(),
element_required=False),
+ required=False,
+ )
+ )
+
+ with pytest.raises(ValidationError, match="Cannot change column type:
aList.element: string -> double"):
+ _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+
+
+def test_detect_invalid_top_level_maps() -> None:
+ current_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aMap",
+ field_type=MapType(key_id=2, value_id=3, key_type=StringType(),
value_type=StringType(), value_required=False),
+ required=False,
+ )
+ )
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="aMap",
+ field_type=MapType(key_id=2, value_id=3, key_type=UUIDType(),
value_type=StringType(), value_required=False),
+ required=False,
+ )
+ )
+
+ with pytest.raises(ValidationError, match="Cannot change column type:
aMap.key: string -> uuid"):
+ _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+
+
+def test_promote_float_to_double() -> None:
+ current_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=FloatType(), required=False))
+ new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DoubleType(), required=False))
+
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+
+ assert applied.as_struct() == new_schema.as_struct()
+ assert len(applied.fields) == 1
+ assert isinstance(applied.fields[0].field_type, DoubleType)
+
+
+def test_detect_invalid_promotion_double_to_float() -> None:
+ current_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DoubleType(), required=False))
+ new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=FloatType(), required=False))
+
+ with pytest.raises(ValidationError, match="Cannot change column type:
aCol: double -> float"):
+ _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+
+
+# decimal(P,S) Fixed-point decimal; precision P, scale S -> Scale is fixed [1],
+# precision must be 38 or less
+def test_type_promote_decimal_to_fixed_scale_with_wider_precision() -> None:
+ current_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DecimalType(precision=20, scale=1), required=False))
+ new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=DecimalType(precision=22, scale=1), required=False))
+
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+
+ assert applied.as_struct() == new_schema.as_struct()
+ assert len(applied.fields) == 1
+ field = applied.fields[0]
+ decimal_type = field.field_type
+ assert isinstance(decimal_type, DecimalType)
+ assert decimal_type.precision == 22
+ assert decimal_type.scale == 1
+
+
+def test_add_nested_structs(primitive_fields: List[NestedField]) -> None:
+ """A new `time` field in the struct element of a doubly-nested list is appended after `value` with the next id (6)."""
+ schema = Schema(
+ NestedField(
+ field_id=1,
+ name="struct1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="struct2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="list",
+ field_type=ListType(
+ element_id=4,
+ element_type=StructType(
+ NestedField(field_id=5, name="value",
field_type=StringType(), required=False)
+ ),
+ element_required=False,
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+ new_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="struct1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="struct2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="list",
+ field_type=ListType(
+ element_id=4,
+ element_type=StructType(
+ NestedField(field_id=5, name="time",
field_type=TimeType(), required=False)
+ ),
+ element_required=False,
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+ applied = UpdateSchema(None,
schema=schema).union_by_name(new_schema)._apply()
+
+ # `value` is kept, `time` is matched by name (not by its id 5) and so receives the fresh id 6.
+ expected = Schema(
+ NestedField(
+ field_id=1,
+ name="struct1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="struct2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="list",
+ field_type=ListType(
+ element_id=4,
+ element_type=StructType(
+ NestedField(field_id=5, name="value",
field_type=StringType(), required=False),
+ NestedField(field_id=6, name="time",
field_type=TimeType(), required=False),
+ ),
+ element_required=False,
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+
+ assert applied.as_struct() == expected.as_struct()
+
+
+def test_replace_list_with_primitive() -> None:
+ current_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=ListType(element_id=2, element_type=StringType())))
+ new_schema = Schema(NestedField(field_id=1, name="aCol",
field_type=StringType()))
+
+ with pytest.raises(ValidationError, match="Cannot change column type:
list<string> is not a primitive"):
+ _ = UpdateSchema(None,
schema=current_schema).union_by_name(new_schema)._apply()
+
+
+def test_mirrored_schemas() -> None:
+ """Same field names with mirrored field ids: matching by name must leave the current schema unchanged."""
+ current_schema = Schema(
+ NestedField(9, "struct1", StructType(NestedField(8, "string1",
StringType(), required=False)), required=False),
+ NestedField(6, "list1", ListType(element_id=7,
element_type=StringType(), element_required=False), required=False),
+ NestedField(5, "string2", StringType(), required=False),
+ NestedField(4, "string3", StringType(), required=False),
+ NestedField(3, "string4", StringType(), required=False),
+ NestedField(2, "string5", StringType(), required=False),
+ NestedField(1, "string6", StringType(), required=False),
+ )
+ mirrored_schema = Schema(
+ NestedField(1, "struct1", StructType(NestedField(2, "string1",
StringType(), required=False))),
+ NestedField(3, "list1", ListType(element_id=4,
element_type=StringType(), element_required=False), required=False),
+ NestedField(5, "string2", StringType(), required=False),
+ NestedField(6, "string3", StringType(), required=False),
+ NestedField(7, "string4", StringType(), required=False),
+ NestedField(8, "string5", StringType(), required=False),
+ NestedField(9, "string6", StringType(), required=False),
+ )
+
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(mirrored_schema)._apply()
+
+ assert applied.as_struct() == current_schema.as_struct()
+
+
+def test_add_new_top_level_struct() -> None:
+ """A new top-level struct with a nested struct is appended next to an existing map of lists of structs."""
+ current_schema = Schema(
+ NestedField(
+ 1,
+ "map1",
+ MapType(
+ key_id=2,
+ value_id=3,
+ key_type=StringType(),
+ value_type=ListType(
+ element_id=4,
+ element_type=StructType(NestedField(field_id=5,
name="string", field_type=StringType(), required=False)),
+ ),
+ value_required=False,
+ ),
+ )
+ )
+ observed_schema = Schema(
+ NestedField(
+ 1,
+ "map1",
+ MapType(
+ key_id=2,
+ value_id=3,
+ key_type=StringType(),
+ value_type=ListType(
+ element_id=4,
+ element_type=StructType(NestedField(field_id=5,
name="string", field_type=StringType(), required=False)),
+ ),
+ value_required=False,
+ ),
+ ),
+ NestedField(
+ field_id=6,
+ name="struct1",
+ field_type=StructType(
+ NestedField(
+ field_id=7,
+ name="d1",
+ field_type=StructType(NestedField(field_id=8, name="d2",
field_type=StringType(), required=False)),
+ required=False,
+ )
+ ),
+ required=False,
+ ),
+ )
+
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(observed_schema)._apply()
+
+ assert applied.as_struct() == observed_schema.as_struct()
+
+
+def test_append_nested_struct() -> None:
+ """A new struct `repeat` (reusing the names s1..s4 below it) is appended inside s2; the result must equal the observed schema."""
+ current_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="s1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="s2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="s3",
+ field_type=StructType(NestedField(field_id=4,
name="s4", field_type=StringType(), required=False)),
+ )
+ ),
+ required=False,
+ )
+ ),
+ )
+ )
+ observed_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="s1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="s2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="s3",
+ field_type=StructType(NestedField(field_id=4,
name="s4", field_type=StringType(), required=False)),
+ required=False,
+ ),
+ NestedField(
+ field_id=5,
+ name="repeat",
+ field_type=StructType(
+ NestedField(
+ field_id=6,
+ name="s1",
+ field_type=StructType(
+ NestedField(
+ field_id=7,
+ name="s2",
+ field_type=StructType(
+ NestedField(
+ field_id=8,
+ name="s3",
+ field_type=StructType(
+
NestedField(field_id=9, name="s4", field_type=StringType())
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ ),
+ required=False,
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+
+ applied = UpdateSchema(None,
schema=current_schema).union_by_name(observed_schema)._apply()
+
+ assert applied.as_struct() == observed_schema.as_struct()
+
+
+def test_append_nested_lists() -> None:
+ """`list2` is matched by name, not id, so it is added next to `list1` with fresh ids (6 and 7) instead of replacing it."""
+ current_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="s1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="s2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="s3",
+ field_type=StructType(
+ NestedField(
+ field_id=4,
+ name="list1",
+ field_type=ListType(element_id=5,
element_type=StringType(), element_required=False),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+
+ observed_schema = Schema(
+ NestedField(
+ field_id=1,
+ name="s1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="s2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="s3",
+ field_type=StructType(
+ NestedField(
+ field_id=4,
+ name="list2",
+ field_type=ListType(element_id=5,
element_type=StringType(), element_required=False),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+ union = UpdateSchema(None,
schema=current_schema).union_by_name(observed_schema)._apply()
+
+ expected = Schema(
+ NestedField(
+ field_id=1,
+ name="s1",
+ field_type=StructType(
+ NestedField(
+ field_id=2,
+ name="s2",
+ field_type=StructType(
+ NestedField(
+ field_id=3,
+ name="s3",
+ field_type=StructType(
+ NestedField(
+ field_id=4,
+ name="list1",
+ field_type=ListType(element_id=5,
element_type=StringType(), element_required=False),
+ required=False,
+ ),
+ NestedField(
+ field_id=6,
+ name="list2",
+ field_type=ListType(element_id=7,
element_type=StringType(), element_required=False),
+ required=False,
+ ),
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ ),
+ required=False,
+ )
+ )
+
+ assert union.as_struct() == expected.as_struct()