This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new b7fb00735d Python: Support for adding columns (#8174)
b7fb00735d is described below
commit b7fb00735d25e53ce36f07e5af116c8ccd368011
Author: Liwei Li <[email protected]>
AuthorDate: Tue Aug 22 20:05:44 2023 +0800
Python: Support for adding columns (#8174)
* Python: Support add column
* Add integration tests (#264)
* Python: Support add column
* Add the requirement (#265)
* Python: Support add column
---------
Co-authored-by: Fokko Driesprong <[email protected]>
---
python/mkdocs/docs/api.md | 32 ++++
python/pyiceberg/schema.py | 64 ++++----
python/pyiceberg/table/__init__.py | 318 +++++++++++++++++++++++++++++++++++--
python/pyiceberg/table/metadata.py | 11 +-
python/tests/catalog/test_base.py | 173 +++++++++++++++++---
python/tests/cli/test_console.py | 11 +-
python/tests/conftest.py | 73 ++++++++-
python/tests/table/test_init.py | 213 +++++++++++++++++++++++--
python/tests/test_integration.py | 92 ++++++++++-
python/tests/test_schema.py | 72 ---------
10 files changed, 894 insertions(+), 165 deletions(-)
diff --git a/python/mkdocs/docs/api.md b/python/mkdocs/docs/api.md
index d3b8fceee5..f0b2873c03 100644
--- a/python/mkdocs/docs/api.md
+++ b/python/mkdocs/docs/api.md
@@ -146,6 +146,38 @@ catalog.create_table(
)
```
+### Update table schema
+
+Add new columns through either the `Transaction` API or the `UpdateSchema` API.
+
+Using the `Transaction` API:
+
+```python
+with table.transaction() as transaction:
+    transaction.update_schema().add_column("x", IntegerType(), "doc").commit()
+```
+
+Or, without a context manager:
+
+```python
+transaction = table.transaction()
+transaction.update_schema().add_column("x", IntegerType(), "doc").commit()
+transaction.commit_transaction()
+```
+
+Or, use the `UpdateSchema` API directly:
+
+```python
+with table.update_schema() as update:
+    update.add_column("x", IntegerType(), "doc")
+```
+
+Or, without a context manager:
+
+```python
+table.update_schema().add_column("x", IntegerType(), "doc").commit()
+```
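+
+The `add_column` call also takes an optional `parent` argument for adding a field to an existing struct column. A minimal sketch, assuming the table already has a struct column named `details`:
+
+```python
+with table.update_schema() as update:
+    # "details" is a hypothetical struct column; non-struct parents raise a ValueError
+    update.add_column("confirmed_by", StringType(), parent="details")
+```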
+
### Update table properties
Set and remove properties through the `Transaction` API:
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 74232d0b7b..5064d07174 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=W0511
+from __future__ import annotations
+
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -145,7 +147,7 @@ class Schema(IcebergBaseModel):
return index_name_by_id(self)
@cached_property
- def _lazy_id_to_accessor(self) -> Dict[int, "Accessor"]:
+ def _lazy_id_to_accessor(self) -> Dict[int, Accessor]:
"""Returns an index of field ID to accessor.
This is calculated once when called for the first time. Subsequent
calls to this method will use a cached index.
@@ -201,7 +203,7 @@ class Schema(IcebergBaseModel):
@property
def highest_field_id(self) -> int:
- return visit(self.as_struct(), _FindLastFieldId())
+ return max(self._lazy_id_to_name.keys(), default=0)
def find_column_name(self, column_id: int) -> Optional[str]:
"""Find a column name given a column ID.
@@ -226,7 +228,7 @@ class Schema(IcebergBaseModel):
"""
return list(self._lazy_id_to_name.values())
- def accessor_for_field(self, field_id: int) -> "Accessor":
+ def accessor_for_field(self, field_id: int) -> Accessor:
"""Find a schema position accessor given a field ID.
Args:
@@ -243,7 +245,7 @@ class Schema(IcebergBaseModel):
return self._lazy_id_to_accessor[field_id]
- def select(self, *names: str, case_sensitive: bool = True) -> "Schema":
+ def select(self, *names: str, case_sensitive: bool = True) -> Schema:
"""Return a new schema instance pruned to a subset of columns.
Args:
@@ -682,7 +684,7 @@ class Accessor:
"""An accessor for a specific position in a container that implements the
StructProtocol."""
position: int
- inner: Optional["Accessor"] = None
+ inner: Optional[Accessor] = None
def __str__(self) -> str:
"""Returns the string representation of the Accessor class."""
@@ -766,7 +768,7 @@ def _(obj: MapType, visitor: SchemaVisitor[T]) -> T:
visitor.before_map_value(obj.value_field)
value_result = visit(obj.value_type, visitor)
- visitor.after_list_element(obj.value_field)
+ visitor.after_map_value(obj.value_field)
return visitor.map(obj, key_result, value_result)
@@ -890,6 +892,22 @@ class _IndexByName(SchemaVisitor[Dict[str, int]]):
self._field_names: List[str] = []
self._short_field_names: List[str] = []
+ def before_map_key(self, key: NestedField) -> None:
+ self.before_field(key)
+
+ def after_map_key(self, key: NestedField) -> None:
+ self.after_field(key)
+
+ def before_map_value(self, value: NestedField) -> None:
+ if not isinstance(value.field_type, StructType):
+ self._short_field_names.append(value.name)
+ self._field_names.append(value.name)
+
+ def after_map_value(self, value: NestedField) -> None:
+ if not isinstance(value.field_type, StructType):
+ self._short_field_names.pop()
+ self._field_names.pop()
+
def before_list_element(self, element: NestedField) -> None:
"""Short field names omit element when the element is a StructType."""
if not isinstance(element.field_type, StructType):
@@ -1082,45 +1100,23 @@ def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict
return visit(schema_or_type, _BuildPositionAccessors())
-class _FindLastFieldId(SchemaVisitor[int]):
- """Traverses the schema to get the highest field-id."""
-
- def schema(self, schema: Schema, struct_result: int) -> int:
- return struct_result
-
- def struct(self, struct: StructType, field_results: List[int]) -> int:
- return max(field_results)
-
- def field(self, field: NestedField, field_result: int) -> int:
- return max(field.field_id, field_result)
-
- def list(self, list_type: ListType, element_result: int) -> int:
- return element_result
-
- def map(self, map_type: MapType, key_result: int, value_result: int) ->
int:
- return max(key_result, value_result)
-
- def primitive(self, primitive: PrimitiveType) -> int:
- return 0
-
-
-def assign_fresh_schema_ids(schema: Schema) -> Schema:
+def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id: Optional[Callable[[], int]] = None) -> Schema:
"""Traverses the schema, and sets new IDs."""
- return pre_order_visit(schema, _SetFreshIDs())
+ return pre_order_visit(schema_or_type, _SetFreshIDs(next_id_func=next_id))
class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
"""Traverses the schema and assigns monotonically increasing ids."""
- counter: itertools.count # type: ignore
reserved_ids: Dict[int, int]
- def __init__(self, start: int = 1) -> None:
- self.counter = itertools.count(start)
+ def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None:
self.reserved_ids = {}
+ counter = itertools.count(1)
+ self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter)
def _get_and_increment(self) -> int:
- return next(self.counter)
+ return self.next_id_func()
def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema:
# First we keep the original identifier_field_ids here, we remap afterwards
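The reworked `assign_fresh_schema_ids` above accepts an optional `next_id` callable, so a caller such as `UpdateSchema` can drive ID assignment from its own counter. A minimal sketch of that hook, using only the signatures in this diff (the start value 100 is arbitrary):

```python
import itertools

from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.types import NestedField, StringType

# An external counter standing in for UpdateSchema.assign_new_column_id.
counter = itertools.count(100)

original = Schema(NestedField(field_id=1, name="name", field_type=StringType(), required=False))
fresh = assign_fresh_schema_ids(original, lambda: next(counter))

assert fresh.highest_field_id == 100  # the single field was renumbered 1 -> 100
```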
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index 52479c29ca..3d4e5f7d28 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
@@ -58,7 +59,13 @@ from pyiceberg.manifest import (
ManifestFile,
)
from pyiceberg.partitioning import PartitionSpec
-from pyiceberg.schema import Schema
+from pyiceberg.schema import (
+ Schema,
+ SchemaVisitor,
+ assign_fresh_schema_ids,
+ index_by_name,
+ visit,
+)
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata
from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
from pyiceberg.table.sorting import SortOrder
@@ -69,6 +76,14 @@ from pyiceberg.typedef import (
KeyDefaultDict,
Properties,
)
+from pyiceberg.types import (
+ IcebergType,
+ ListType,
+ MapType,
+ NestedField,
+ PrimitiveType,
+ StructType,
+)
from pyiceberg.utils.concurrent import ExecutorFactory
if TYPE_CHECKING:
@@ -81,6 +96,7 @@ if TYPE_CHECKING:
ALWAYS_TRUE = AlwaysTrue()
+TABLE_ROOT_ID = -1
class Transaction:
@@ -119,7 +135,7 @@ class Transaction:
ValueError: When the type of update is not unique.
Returns:
- A new AlterTable object with the new updates appended.
+ Transaction object with the new updates appended.
"""
for new_update in new_updates:
type_new_update = type(new_update)
@@ -128,6 +144,25 @@ class Transaction:
self._updates = self._updates + new_updates
return self
+ def _append_requirements(self, *new_requirements: TableRequirement) ->
Transaction:
+ """Appends requirements to the set of staged requirements.
+
+ Args:
+ *new_requirements: Any new requirements.
+
+ Raises:
+ ValueError: When the type of requirement is not unique.
+
+ Returns:
+ Transaction object with the new requirements appended.
+ """
+ for requirement in new_requirements:
+ type_new_requirement = type(requirement)
+ if any(type(update) == type_new_requirement for update in self._updates):
+ raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}")
+ self._requirements = self._requirements + new_requirements
+ return self
+
def set_table_version(self, format_version: Literal[1, 2]) -> Transaction:
"""Sets the table to a certain version.
@@ -152,6 +187,14 @@ class Transaction:
"""
return self._append_updates(SetPropertiesUpdate(updates=updates))
+ def update_schema(self) -> UpdateSchema:
+ """Create a new UpdateSchema to alter the columns of this table.
+
+ Returns:
+ A new UpdateSchema.
+ """
+ return UpdateSchema(self._table.schema(), self._table, self)
+
def remove_properties(self, *removals: str) -> Transaction:
"""Removes properties.
@@ -227,6 +270,8 @@ class UpgradeFormatVersionUpdate(TableUpdate):
class AddSchemaUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.add_schema
schema_: Schema = Field(alias="schema")
+ # This field is required: https://github.com/apache/iceberg/pull/7445
+ last_column_id: int = Field(alias="last-column-id")
class SetCurrentSchemaUpdate(TableUpdate):
@@ -307,13 +352,13 @@ class TableRequirement(IcebergBaseModel):
class AssertCreate(TableRequirement):
"""The table must not already exist; used for create transactions."""
- type: Literal["assert-create"]
+ type: Literal["assert-create"] = Field(default="assert-create")
class AssertTableUUID(TableRequirement):
"""The table UUID must match the requirement's `uuid`."""
- type: Literal["assert-table-uuid"]
+ type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
uuid: str
@@ -323,7 +368,7 @@ class AssertRefSnapshotId(TableRequirement):
if `snapshot-id` is `null` or missing, the ref must not already exist.
"""
- type: Literal["assert-ref-snapshot-id"]
+ type: Literal["assert-ref-snapshot-id"] =
Field(default="assert-ref-snapshot-id")
ref: str
snapshot_id: int = Field(..., alias="snapshot-id")
@@ -331,35 +376,35 @@ class AssertRefSnapshotId(TableRequirement):
class AssertLastAssignedFieldId(TableRequirement):
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
- type: Literal["assert-last-assigned-field-id"]
+ type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
last_assigned_field_id: int = Field(..., alias="last-assigned-field-id")
class AssertCurrentSchemaId(TableRequirement):
"""The table's current schema id must match the requirement's `current-schema-id`."""
- type: Literal["assert-current-schema-id"]
+ type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
current_schema_id: int = Field(..., alias="current-schema-id")
class AssertLastAssignedPartitionId(TableRequirement):
"""The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""
- type: Literal["assert-last-assigned-partition-id"]
+ type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id")
class AssertDefaultSpecId(TableRequirement):
"""The table's default spec id must match the requirement's `default-spec-id`."""
- type: Literal["assert-default-spec-id"]
+ type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
default_spec_id: int = Field(..., alias="default-spec-id")
class AssertDefaultSortOrderId(TableRequirement):
"""The table's default sort order id must match the requirement's `default-sort-order-id`."""
- type: Literal["assert-default-sort-order-id"]
+ type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
default_sort_order_id: int = Field(..., alias="default-sort-order-id")
@@ -482,6 +527,9 @@ class Table:
"""Get the snapshot history of this table."""
return self.metadata.snapshot_log
+ def update_schema(self) -> UpdateSchema:
+ return UpdateSchema(self.schema(), self)
+
def __eq__(self, other: Any) -> bool:
"""Returns the equality of two instances of the Table class."""
return (
@@ -839,3 +887,253 @@ class DataScan(TableScan):
import ray
return ray.data.from_arrow(self.to_arrow())
+
+
+class UpdateSchema:
+ _table: Table
+ _schema: Schema
+ _last_column_id: itertools.count[int]
+ _identifier_field_names: List[str]
+ _adds: Dict[int, List[NestedField]]
+ _added_name_to_id: Dict[str, int]
+ _id_to_parent: Dict[int, str]
+ _allow_incompatible_changes: bool
+ _case_sensitive: bool
+ _transaction: Optional[Transaction]
+
+ def __init__(self, schema: Schema, table: Table, transaction: Optional[Transaction] = None):
+ self._table = table
+ self._schema = schema
+ self._last_column_id = itertools.count(schema.highest_field_id + 1)
+ self._identifier_field_names = schema.column_names
+ self._adds = {}
+ self._added_name_to_id = {}
+ self._id_to_parent = {}
+ self._allow_incompatible_changes = False
+ self._case_sensitive = True
+ self._transaction = transaction
+
+ def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+ """Closes and commits the change."""
+ return self.commit()
+
+ def __enter__(self) -> UpdateSchema:
+ """Update the table."""
+ return self
+
+ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
+ """Determines if the case of schema needs to be considered when
comparing column names.
+
+ Args:
+ case_sensitive: When false case is not considered in column name
comparisons.
+
+ Returns:
+ This for method chaining
+ """
+ self._case_sensitive = case_sensitive
+ return self
+
+ def add_column(
+ self, name: str, type_var: IcebergType, doc: Optional[str] = None, parent: Optional[str] = None, required: bool = False
+ ) -> UpdateSchema:
+ """Add a new column to a nested struct, or add a new top-level column.
+
+ Args:
+ name: Name for the new column.
+ type_var: Type for the new column.
+ doc: Documentation string for the new column.
+ parent: Name of the parent struct the column will be added to.
+ required: Whether the new column is required.
+
+ Returns:
+ This for method chaining
+ """
+ if "." in name:
+ raise ValueError(f"Cannot add column with ambiguous name: {name}")
+
+ if required and not self._allow_incompatible_changes:
+ # Table format versions 1 and 2 cannot add a required column because there is no initial value
+ raise ValueError(f"Incompatible change: cannot add required column: {name}")
+
+ self._internal_add_column(parent, name, not required, type_var, doc)
+ return self
+
+ def allow_incompatible_changes(self) -> UpdateSchema:
+ """Allow incompatible changes to the schema.
+
+ Returns:
+ This for method chaining
+ """
+ self._allow_incompatible_changes = True
+ return self
+
+ def commit(self) -> None:
+ """Apply the pending changes and commit."""
+ new_schema = self._apply()
+ updates = [
+ AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id),
+ SetCurrentSchemaUpdate(schema_id=-1),
+ ]
+ requirements = [AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)]
+
+ if self._transaction is not None:
+ self._transaction._append_updates(*updates) # pylint: disable=W0212
+ self._transaction._append_requirements(*requirements) # pylint: disable=W0212
+ else:
+ table_update_response = self._table.catalog._commit_table( # pylint: disable=W0212
+ CommitTableRequest(identifier=self._table.identifier[1:], updates=updates, requirements=requirements)
+ )
+ self._table.metadata = table_update_response.metadata
+ self._table.metadata_location = table_update_response.metadata_location
+
+ def _apply(self) -> Schema:
+ """Apply the pending changes to the original schema and returns the
result.
+
+ Returns:
+ the result Schema when all pending updates are applied
+ """
+ return _apply_changes(self._schema, self._adds, self._identifier_field_names)
+
+ def _internal_add_column(
+ self, parent: Optional[str], name: str, is_optional: bool, type_var: IcebergType, doc: Optional[str]
+ ) -> None:
+ full_name: str = name
+ parent_id: int = TABLE_ROOT_ID
+
+ exist_field: Optional[NestedField] = None
+ if parent:
+ parent_field = self._schema.find_field(parent, self._case_sensitive)
+ parent_type = parent_field.field_type
+ if isinstance(parent_type, MapType):
+ parent_field = parent_type.value_field
+ elif isinstance(parent_type, ListType):
+ parent_field = parent_type.element_field
+
+ if not parent_field.field_type.is_struct:
+ raise ValueError(f"Cannot add column to non-struct type:
{parent}")
+
+ parent_id = parent_field.field_id
+
+ try:
+ exist_field = self._schema.find_field(parent + "." + name, self._case_sensitive)
+ except ValueError:
+ pass
+
+ if exist_field:
+ raise ValueError(f"Cannot add column, name already exists:
{parent}.{name}")
+
+ full_name = parent_field.name + "." + name
+
+ else:
+ try:
+ exist_field = self._schema.find_field(name, self._case_sensitive)
+ except ValueError:
+ pass
+
+ if exist_field:
+ raise ValueError(f"Cannot add column, name already exists:
{name}")
+
+ # assign new IDs in order
+ new_id = self.assign_new_column_id()
+
+ # update tracking for moves
+ self._added_name_to_id[full_name] = new_id
+
+ new_type = assign_fresh_schema_ids(type_var, self.assign_new_column_id)
+ field = NestedField(new_id, name, new_type, not is_optional, doc)
+
+ self._adds.setdefault(parent_id, []).append(field)
+
+ def assign_new_column_id(self) -> int:
+ return next(self._last_column_id)
+
+
+def _apply_changes(schema_: Schema, adds: Dict[int, List[NestedField]], identifier_field_names: List[str]) -> Schema:
+ struct = visit(schema_, _ApplyChanges(adds))
+ name_to_id: Dict[str, int] = index_by_name(struct)
+ for name in identifier_field_names:
+ if name not in name_to_id:
+ raise ValueError(f"Cannot add field {name} as an identifier field:
not found in current schema or added columns")
+
+ return Schema(*struct.fields)
+
+
+class _ApplyChanges(SchemaVisitor[IcebergType]):
+ def __init__(self, adds: Dict[int, List[NestedField]]):
+ self.adds = adds
+
+ def schema(self, schema: Schema, struct_result: IcebergType) ->
IcebergType:
+ fields = _ApplyChanges.add_fields(schema.as_struct().fields,
self.adds.get(TABLE_ROOT_ID))
+ if len(fields) > 0:
+ return StructType(*fields)
+
+ return struct_result
+
+ def struct(self, struct: StructType, field_results: List[IcebergType]) ->
IcebergType:
+ has_change = False
+ new_fields: List[NestedField] = []
+ for i in range(len(field_results)):
+ type_: Optional[IcebergType] = field_results[i]
+ if type_ is None:
+ has_change = True
+ continue
+
+ field: NestedField = struct.fields[i]
+ new_fields.append(field)
+
+ if has_change:
+ return StructType(*new_fields)
+
+ return struct
+
+ def field(self, field: NestedField, field_result: IcebergType) ->
IcebergType:
+ field_id: int = field.field_id
+ if field_id in self.adds:
+ new_fields = self.adds[field_id]
+ if len(new_fields) > 0:
+ fields = _ApplyChanges.add_fields(field_result.fields, new_fields)
+ if len(fields) > 0:
+ return StructType(*fields)
+
+ return field_result
+
+ def list(self, list_type: ListType, element_result: IcebergType) ->
IcebergType:
+ element_field: NestedField = list_type.element_field
+ element_type = self.field(element_field, element_result)
+ if element_type is None:
+ raise ValueError(f"Cannot delete element type from list:
{element_field}")
+
+ is_element_optional = not list_type.element_required
+
+ if is_element_optional == element_field.required and list_type.element_type == element_type:
+ return list_type
+
+ return ListType(list_type.element_id, element_type, is_element_optional)
+
+ def map(self, map_type: MapType, key_result: IcebergType, value_result: IcebergType) -> IcebergType:
+ key_id: int = map_type.key_field.field_id
+ if key_id in self.adds:
+ raise ValueError(f"Cannot add fields to map keys: {map_type}")
+
+ value_field: NestedField = map_type.value_field
+ value_type = self.field(value_field, value_result)
+ if value_type is None:
+ raise ValueError(f"Cannot delete value type from map:
{value_field}")
+
+ is_value_optional = not map_type.value_required
+
+ if is_value_optional != value_field.required and map_type.value_type == value_type:
+ return map_type
+
+ return MapType(map_type.key_id, map_type.key_field, map_type.value_id, value_type, not is_value_optional)
+
+ def primitive(self, primitive: PrimitiveType) -> IcebergType:
+ return primitive
+
+ @staticmethod
+ def add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> List[NestedField]:
+ new_fields: List[NestedField] = []
+ new_fields.extend(fields)
+ if adds:
+ new_fields.extend(adds)
+ return new_fields
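Since every builder method on the `UpdateSchema` class above returns `self`, the calls chain before the final `commit()`. A sketch of the fluent surface, assuming an existing `table` handle:

```python
from pyiceberg.types import IntegerType

(
    table.update_schema()
    .case_sensitive(False)           # compare column names case-insensitively
    .allow_incompatible_changes()    # required columns are otherwise rejected
    .add_column("points", IntegerType(), doc="accumulated points", required=True)
    .commit()                        # stages AddSchemaUpdate + SetCurrentSchemaUpdate and commits
)
```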
diff --git a/python/pyiceberg/table/metadata.py b/python/pyiceberg/table/metadata.py
index e6a3e6f16e..690f5d4d59 100644
--- a/python/pyiceberg/table/metadata.py
+++ b/python/pyiceberg/table/metadata.py
@@ -388,12 +388,20 @@ TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discrim
def new_table_metadata(
- schema: Schema, partition_spec: PartitionSpec, sort_order: SortOrder, location: str, properties: Properties = EMPTY_DICT
+ schema: Schema,
+ partition_spec: PartitionSpec,
+ sort_order: SortOrder,
+ location: str,
+ properties: Properties = EMPTY_DICT,
+ table_uuid: Optional[uuid.UUID] = None,
) -> TableMetadata:
fresh_schema = assign_fresh_schema_ids(schema)
fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec, schema, fresh_schema)
fresh_sort_order = assign_fresh_sort_order_ids(sort_order, schema, fresh_schema)
+ if table_uuid is None:
+ table_uuid = uuid.uuid4()
+
return TableMetadataV2(
location=location,
schemas=[fresh_schema],
@@ -405,6 +413,7 @@ def new_table_metadata(
default_sort_order_id=fresh_sort_order.order_id,
properties=properties,
last_partition_id=fresh_partition_spec.last_assigned_field_id,
+ table_uuid=table_uuid,
)
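The new `table_uuid` parameter lets a catalog preserve a table's identity when it rebuilds metadata for a schema commit (as `InMemoryCatalog._commit_table` does below); when omitted, it still defaults to a fresh `uuid.uuid4()`. A minimal sketch, assuming only the signatures in this patch:

```python
import uuid

from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.types import LongType, NestedField

# Reusing an existing UUID keeps the rebuilt metadata pointing at the same table.
existing_uuid = uuid.UUID("d20125c8-7284-442c-9aea-15fee620737c")

metadata = new_table_metadata(
    schema=Schema(NestedField(field_id=1, name="id", field_type=LongType(), required=True)),
    partition_spec=UNPARTITIONED_PARTITION_SPEC,
    sort_order=UNSORTED_SORT_ORDER,
    location="s3://warehouse/db/tbl",
    table_uuid=existing_uuid,
)

assert metadata.table_uuid == existing_uuid
```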
diff --git a/python/tests/catalog/test_base.py b/python/tests/catalog/test_base.py
index b47aa5f5f7..29e93d0c9d 100644
--- a/python/tests/catalog/test_base.py
+++ b/python/tests/catalog/test_base.py
@@ -42,11 +42,18 @@ from pyiceberg.exceptions import (
from pyiceberg.io import load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
-from pyiceberg.table.metadata import TableMetadataV1
+from pyiceberg.table import (
+ AddSchemaUpdate,
+ CommitTableRequest,
+ CommitTableResponse,
+ SetCurrentSchemaUpdate,
+ Table,
+)
+from pyiceberg.table.metadata import TableMetadata, TableMetadataV1, new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import EMPTY_DICT
+from pyiceberg.types import IntegerType, LongType, NestedField
class InMemoryCatalog(Catalog):
@@ -78,29 +85,24 @@ class InMemoryCatalog(Catalog):
if namespace not in self.__namespaces:
self.__namespaces[namespace] = {}
+ new_location = location or f's3://warehouse/{"/".join(identifier)}/data'
+ metadata = TableMetadataV1(
+ **{
+ "format-version": 1,
+ "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
+ "location": new_location,
+ "last-updated-ms": 1602638573874,
+ "last-column-id": schema.highest_field_id,
+ "schema": schema.model_dump(),
+ "partition-spec": partition_spec.model_dump()["fields"],
+ "properties": properties,
+ "current-snapshot-id": -1,
+ "snapshots": [{"snapshot-id": 1925, "timestamp-ms":
1602638573822}],
+ }
+ )
table = Table(
identifier=identifier,
- metadata=TableMetadataV1(
- **{
- "format-version": 1,
- "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
- "location": "s3://bucket/test/location",
- "last-updated-ms": 1602638573874,
- "last-column-id": 3,
- "schema": {
- "type": "struct",
- "fields": [
- {"id": 1, "name": "x", "required": True,
"type": "long"},
- {"id": 2, "name": "y", "required": True,
"type": "long", "doc": "comment"},
- {"id": 3, "name": "z", "required": True,
"type": "long"},
- ],
- },
- "partition-spec": [{"name": "x", "transform":
"identity", "source-id": 1, "field-id": 1000}],
- "properties": properties,
- "current-snapshot-id": -1,
- "snapshots": [{"snapshot-id": 1925, "timestamp-ms":
1602638573822}],
- }
- ),
+ metadata=metadata,
metadata_location=f's3://warehouse/{"/".join(identifier)}/metadata/metadata.json',
io=load_file_io(),
catalog=self,
@@ -109,7 +111,37 @@ class InMemoryCatalog(Catalog):
return table
def _commit_table(self, table_request: CommitTableRequest) ->
CommitTableResponse:
- raise NotImplementedError
+ new_metadata: Optional[TableMetadata] = None
+ metadata_location = ""
+ for update in table_request.updates:
+ if isinstance(update, AddSchemaUpdate):
+ add_schema_update: AddSchemaUpdate = update
+ identifier = Catalog.identifier_to_tuple(table_request.identifier)
+ table = self.__tables[("com", *identifier)]
+ new_metadata = new_table_metadata(
+ add_schema_update.schema_,
+ table.metadata.partition_specs[0],
+ table.sort_order(),
+ table.location(),
+ table.properties,
+ table.metadata.table_uuid,
+ )
+
+ table = Table(
+ identifier=identifier,
+ metadata=new_metadata,
+ metadata_location=f's3://warehouse/{"/".join(identifier)}/metadata/metadata.json',
+ io=load_file_io(),
+ catalog=self,
+ )
+
+ self.__tables[identifier] = table
+ metadata_location = f's3://warehouse/{"/".join(identifier)}/metadata/metadata.json'
+
+ return CommitTableResponse(
+ metadata=new_metadata.model_dump() if new_metadata else {},
+ metadata_location=metadata_location if metadata_location else "",
+ )
def load_table(self, identifier: Union[str, Identifier]) -> Table:
identifier = Catalog.identifier_to_tuple(identifier)
@@ -223,7 +255,11 @@ def catalog() -> InMemoryCatalog:
TEST_TABLE_IDENTIFIER = ("com", "organization", "department", "my_table")
TEST_TABLE_NAMESPACE = ("com", "organization", "department")
TEST_TABLE_NAME = "my_table"
-TEST_TABLE_SCHEMA = Schema(schema_id=1)
+TEST_TABLE_SCHEMA = Schema(
+ NestedField(1, "x", LongType()),
+ NestedField(2, "y", LongType(), doc="comment"),
+ NestedField(3, "z", LongType()),
+)
TEST_TABLE_LOCATION = "protocol://some/location"
TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x",
transform=IdentityTransform(), source_id=1, field_id=1000))
TEST_TABLE_PROPERTIES = {"key1": "value1", "key2": "value2"}
@@ -239,7 +275,7 @@ def given_catalog_has_a_table(catalog: InMemoryCatalog) ->
Table:
identifier=TEST_TABLE_IDENTIFIER,
schema=TEST_TABLE_SCHEMA,
location=TEST_TABLE_LOCATION,
- partition_spec=UNPARTITIONED_PARTITION_SPEC,
+ partition_spec=TEST_TABLE_PARTITION_SPEC,
properties=TEST_TABLE_PROPERTIES,
)
@@ -474,3 +510,88 @@ def test_update_namespace_metadata_removals(catalog: InMemoryCatalog) -> None:
def test_update_namespace_metadata_raises_error_when_namespace_does_not_exist(catalog: InMemoryCatalog) -> None:
with pytest.raises(NoSuchNamespaceError, match=NO_SUCH_NAMESPACE_ERROR):
catalog.update_namespace_properties(TEST_TABLE_NAMESPACE, updates=TEST_TABLE_PROPERTIES)
+
+
+def test_commit_table(catalog: InMemoryCatalog) -> None:
+ # Given
+ given_table = given_catalog_has_a_table(catalog)
+ new_schema = Schema(
+ NestedField(1, "x", LongType()),
+ NestedField(2, "y", LongType(), doc="comment"),
+ NestedField(3, "z", LongType()),
+ NestedField(4, "add", LongType()),
+ )
+
+ # When
+ response = given_table.catalog._commit_table( # pylint: disable=W0212
+ CommitTableRequest(
+ identifier=given_table.identifier[1:],
+ updates=[
+ AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id),
+ SetCurrentSchemaUpdate(schema_id=-1),
+ ],
+ )
+ )
+
+ # Then
+ assert response.metadata.table_uuid == given_table.metadata.table_uuid
+ assert len(response.metadata.schemas) == 1
+ assert response.metadata.schemas[0] == new_schema
+
+
+def test_add_column(catalog: InMemoryCatalog) -> None:
+ given_table = given_catalog_has_a_table(catalog)
+
+ given_table.update_schema().add_column(name="new_column1",
type_var=IntegerType()).commit()
+
+ assert given_table.schema() == Schema(
+ NestedField(field_id=1, name="x", field_type=LongType(),
required=True),
+ NestedField(field_id=2, name="y", field_type=LongType(),
required=True, doc="comment"),
+ NestedField(field_id=3, name="z", field_type=LongType(),
required=True),
+ NestedField(field_id=4, name="new_column1", field_type=IntegerType(),
required=False),
+ schema_id=0,
+ identifier_field_ids=[],
+ )
+
+ transaction = given_table.transaction()
+ transaction.update_schema().add_column(name="new_column2",
type_var=IntegerType(), doc="doc").commit()
+ transaction.commit_transaction()
+
+ assert given_table.schema() == Schema(
+ NestedField(field_id=1, name="x", field_type=LongType(),
required=True),
+ NestedField(field_id=2, name="y", field_type=LongType(),
required=True, doc="comment"),
+ NestedField(field_id=3, name="z", field_type=LongType(),
required=True),
+ NestedField(field_id=4, name="new_column1", field_type=IntegerType(),
required=False),
+ NestedField(field_id=5, name="new_column2", field_type=IntegerType(),
required=False, doc="doc"),
+ schema_id=0,
+ identifier_field_ids=[],
+ )
+
+
+def test_add_column_with_statement(catalog: InMemoryCatalog) -> None:
+ given_table = given_catalog_has_a_table(catalog)
+
+ with given_table.update_schema() as tx:
+ tx.add_column(name="new_column1", type_var=IntegerType())
+
+ assert given_table.schema() == Schema(
+ NestedField(field_id=1, name="x", field_type=LongType(),
required=True),
+ NestedField(field_id=2, name="y", field_type=LongType(),
required=True, doc="comment"),
+ NestedField(field_id=3, name="z", field_type=LongType(),
required=True),
+ NestedField(field_id=4, name="new_column1", field_type=IntegerType(),
required=False),
+ schema_id=0,
+ identifier_field_ids=[],
+ )
+
+ with given_table.transaction() as tx:
+ tx.update_schema().add_column(name="new_column2",
type_var=IntegerType(), doc="doc").commit()
+
+ assert given_table.schema() == Schema(
+ NestedField(field_id=1, name="x", field_type=LongType(),
required=True),
+ NestedField(field_id=2, name="y", field_type=LongType(),
required=True, doc="comment"),
+ NestedField(field_id=3, name="z", field_type=LongType(),
required=True),
+ NestedField(field_id=4, name="new_column1", field_type=IntegerType(),
required=False),
+ NestedField(field_id=5, name="new_column2", field_type=IntegerType(),
required=False, doc="doc"),
+ schema_id=0,
+ identifier_field_ids=[],
+ )
diff --git a/python/tests/cli/test_console.py b/python/tests/cli/test_console.py
index 12c82c2cde..45eb4dd1be 100644
--- a/python/tests/cli/test_console.py
+++ b/python/tests/cli/test_console.py
@@ -25,6 +25,7 @@ from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import Properties
+from pyiceberg.types import LongType, NestedField
from pyiceberg.utils.config import Config
from tests.catalog.test_base import InMemoryCatalog
@@ -62,8 +63,12 @@ TEST_TABLE_IDENTIFIER = ("default", "my_table")
TEST_TABLE_NAMESPACE = "default"
TEST_NAMESPACE_PROPERTIES = {"location": "s3://warehouse/database/location"}
TEST_TABLE_NAME = "my_table"
-TEST_TABLE_SCHEMA = Schema(schema_id=0)
-TEST_TABLE_LOCATION = "protocol://some/location"
+TEST_TABLE_SCHEMA = Schema(
+ NestedField(1, "x", LongType()),
+ NestedField(2, "y", LongType(), doc="comment"),
+ NestedField(3, "z", LongType()),
+)
+TEST_TABLE_LOCATION = "s3://bucket/test/location"
TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x",
transform=IdentityTransform(), source_id=1, field_id=1000))
TEST_TABLE_PROPERTIES = {"read.split.target.size": "134217728"}
MOCK_ENVIRONMENT = {"PYICEBERG_CATALOG__PRODUCTION__URI":
"test://doesnotexist"}
@@ -558,7 +563,7 @@ def test_json_describe_table(catalog: InMemoryCatalog) ->
None:
assert result.exit_code == 0
assert (
result.output
- == """{"identifier":["default","my_table"],"metadata_location":"s3://warehouse/default/my_table/metadata/metadata.json","metadata":{"location":"s3://bucket/test/location","table-uuid":"d20125c8-7284-442c-9aea-15fee620737c","last-updated-ms":1602638573874,"last-column-id":3,"schemas":[{"type":"struct","fields":[{"id":1,"name":"x","type":"long","required":true},{"id":2,"name":"y","type":"long","required":true,"doc":"comment"},{"id":3,"name":"z","type":"long","required":true}],"sche [...]
+ == """{"identifier":["default","my_table"],"metadata_location":"s3://warehouse/default/my_table/metadata/metadata.json","metadata":{"location":"s3://bucket/test/location","table-uuid":"d20125c8-7284-442c-9aea-15fee620737c","last-updated-ms":1602638573874,"last-column-id":3,"schemas":[{"type":"struct","fields":[{"id":1,"name":"x","type":"long","required":true},{"id":2,"name":"y","type":"long","required":true,"doc":"comment"},{"id":3,"name":"z","type":"long","required":true}],"sche [...]
)
diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index 9a560284ea..67fc992780 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -57,6 +57,7 @@ from pyarrow import parquet as pq
from pyiceberg import schema
from pyiceberg.catalog import Catalog
+from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.io import (
GCS_ENDPOINT,
GCS_PROJECT_ID,
@@ -65,13 +66,14 @@ from pyiceberg.io import (
OutputFile,
OutputStream,
fsspec,
+ load_file_io,
)
from pyiceberg.io.fsspec import FsspecFileIO
from pyiceberg.io.pyarrow import PyArrowFile, PyArrowFileIO
from pyiceberg.manifest import DataFile, FileFormat
from pyiceberg.schema import Schema
from pyiceberg.serializers import ToOutputFile
-from pyiceberg.table import FileScanTask
+from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.types import (
BinaryType,
@@ -194,6 +196,63 @@ def table_schema_nested() -> Schema:
)
[email protected](scope="session")
+def table_schema_nested_with_struct_key_map() -> Schema:
+ return schema.Schema(
+ NestedField(field_id=1, name="foo", field_type=StringType(),
required=False),
+ NestedField(field_id=2, name="bar", field_type=IntegerType(),
required=True),
+ NestedField(field_id=3, name="baz", field_type=BooleanType(),
required=False),
+ NestedField(
+ field_id=4,
+ name="qux",
+ field_type=ListType(element_id=5, element_type=StringType(), element_required=True),
+ required=True,
+ ),
+ NestedField(
+ field_id=6,
+ name="quux",
+ field_type=MapType(
+ key_id=7,
+ key_type=StringType(),
+ value_id=8,
+ value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=11,
+ name="location",
+ field_type=MapType(
+ key_id=18,
+ value_id=19,
+ key_type=StructType(
+ NestedField(field_id=21, name="address",
field_type=StringType(), required=False),
+ NestedField(field_id=22, name="city",
field_type=StringType(), required=False),
+ NestedField(field_id=23, name="zip",
field_type=IntegerType(), required=False),
+ ),
+ value_type=StructType(
+ NestedField(field_id=13, name="latitude",
field_type=FloatType(), required=False),
+ NestedField(field_id=14, name="longitude",
field_type=FloatType(), required=False),
+ ),
+ value_required=True,
+ ),
+ required=True,
+ ),
+ NestedField(
+ field_id=15,
+ name="person",
+ field_type=StructType(
+ NestedField(field_id=16, name="name", field_type=StringType(),
required=False),
+ NestedField(field_id=17, name="age", field_type=IntegerType(),
required=True),
+ ),
+ required=False,
+ ),
+ schema_id=1,
+ identifier_field_ids=[1],
+ )
+
+
@pytest.fixture(scope="session")
def all_avro_types() -> Dict[str, Any]:
return {
@@ -1561,3 +1620,15 @@ def example_task(data_file: str) -> FileScanTask:
return FileScanTask(
data_file=DataFile(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925),
)
+
+
+@pytest.fixture
+def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
+ table_metadata = TableMetadataV2(**example_table_metadata_v2)
+ return Table(
+ identifier=("database", "table"),
+ metadata=table_metadata,
+ metadata_location=f"{table_metadata.location}/uuid.metadata.json",
+ io=load_file_io(),
+ catalog=NoopCatalog("NoopCatalog"),
+ )
diff --git a/python/tests/table/test_init.py b/python/tests/table/test_init.py
index 2587fb76d9..b25e445032 100644
--- a/python/tests/table/test_init.py
+++ b/python/tests/table/test_init.py
@@ -15,19 +15,18 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
-from typing import Any, Dict
+from typing import Dict
import pytest
from sortedcontainers import SortedList
-from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.expressions import (
AlwaysTrue,
And,
EqualTo,
In,
)
-from pyiceberg.io import PY_IO_IMPL, load_file_io
+from pyiceberg.io import PY_IO_IMPL
from pyiceberg.manifest import (
DataFile,
DataFileContent,
@@ -41,9 +40,10 @@ from pyiceberg.table import (
SetPropertiesUpdate,
StaticTable,
Table,
+ UpdateSchema,
_match_deletes_to_datafile,
)
-from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataV2
+from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER
from pyiceberg.table.snapshots import (
Operation,
Snapshot,
@@ -57,19 +57,25 @@ from pyiceberg.table.sorting import (
SortOrder,
)
from pyiceberg.transforms import BucketTransform, IdentityTransform
-from pyiceberg.types import LongType, NestedField
-
-
-@pytest.fixture
-def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
- table_metadata = TableMetadataV2(**example_table_metadata_v2)
- return Table(
- identifier=("database", "table"),
- metadata=table_metadata,
- metadata_location=f"{table_metadata.location}/uuid.metadata.json",
- io=load_file_io(),
- catalog=NoopCatalog("NoopCatalog"),
- )
+from pyiceberg.types import (
+ BinaryType,
+ BooleanType,
+ DateType,
+ DoubleType,
+ FloatType,
+ IntegerType,
+ ListType,
+ LongType,
+ MapType,
+ NestedField,
+ PrimitiveType,
+ StringType,
+ StructType,
+ TimestampType,
+ TimestamptzType,
+ TimeType,
+ UUIDType,
+)
def test_schema(table: Table) -> None:
@@ -388,3 +394,176 @@ def test_match_deletes_to_datafile_duplicate_number() ->
None:
def test_serialize_set_properties_updates() -> None:
assert SetPropertiesUpdate(updates={"abc": "🤪"}).model_dump_json() ==
"""{"action":"set-properties","updates":{"abc":"🤪"}}"""
+
+
+def test_add_column(table_schema_simple: Schema, table: Table) -> None:
+ update = UpdateSchema(table_schema_simple, table)
+ update.add_column(name="b", type_var=IntegerType())
+ apply_schema: Schema = update._apply() # pylint: disable=W0212
+ assert len(apply_schema.fields) == 4
+
+ assert apply_schema == Schema(
+ NestedField(field_id=1, name="foo", field_type=StringType(),
required=False),
+ NestedField(field_id=2, name="bar", field_type=IntegerType(),
required=True),
+ NestedField(field_id=3, name="baz", field_type=BooleanType(),
required=False),
+ NestedField(field_id=4, name="b", field_type=IntegerType(),
required=False),
+ )
+ assert apply_schema.schema_id == 0
+ assert apply_schema.highest_field_id == 4
+
+
+def test_add_primitive_type_column(table_schema_simple: Schema, table: Table) -> None:
+ primitive_type: Dict[str, PrimitiveType] = {
+ "boolean": BooleanType(),
+ "int": IntegerType(),
+ "long": LongType(),
+ "float": FloatType(),
+ "double": DoubleType(),
+ "date": DateType(),
+ "time": TimeType(),
+ "timestamp": TimestampType(),
+ "timestamptz": TimestamptzType(),
+ "string": StringType(),
+ "uuid": UUIDType(),
+ "binary": BinaryType(),
+ }
+
+ for name, type_ in primitive_type.items():
+ field_name = f"new_column_{name}"
+ update = UpdateSchema(table_schema_simple, table)
+ update.add_column(parent=None, name=field_name, type_var=type_, doc=f"new_column_{name}")
+ new_schema = update._apply() # pylint: disable=W0212
+
+ field: NestedField = new_schema.find_field(field_name)
+ assert field.field_type == type_
+ assert field.doc == f"new_column_{name}"
+
+
+def test_add_nested_type_column(table_schema_simple: Schema, table: Table) ->
None:
+ # add struct type column
+ field_name = "new_column_struct"
+ update = UpdateSchema(table_schema_simple, table)
+ struct_ = StructType(
+ NestedField(1, "lat", DoubleType()),
+ NestedField(2, "long", DoubleType()),
+ )
+ update.add_column(parent=None, name=field_name, type_var=struct_)
+ schema_ = update._apply() # pylint: disable=W0212
+ field: NestedField = schema_.find_field(field_name)
+ assert field.field_type == StructType(
+ NestedField(5, "lat", DoubleType()),
+ NestedField(6, "long", DoubleType()),
+ )
+ assert schema_.highest_field_id == 6
+
+
+def test_add_nested_map_type_column(table_schema_simple: Schema, table: Table) -> None:
+ # add map type column
+ field_name = "new_column_map"
+ update = UpdateSchema(table_schema_simple, table)
+ map_ = MapType(1, StringType(), 2, IntegerType(), False)
+ update.add_column(parent=None, name=field_name, type_var=map_)
+ new_schema = update._apply() # pylint: disable=W0212
+ field: NestedField = new_schema.find_field(field_name)
+ assert field.field_type == MapType(5, StringType(), 6, IntegerType(), False)
+ assert new_schema.highest_field_id == 6
+
+
+def test_add_nested_list_type_column(table_schema_simple: Schema, table: Table) -> None:
+ # add list type column
+ field_name = "new_column_list"
+ update = UpdateSchema(table_schema_simple, table)
+ list_ = ListType(
+ element_id=101,
+ element_type=StructType(
+ NestedField(102, "lat", DoubleType()),
+ NestedField(103, "long", DoubleType()),
+ ),
+ element_required=False,
+ )
+ update.add_column(parent=None, name=field_name, type_var=list_)
+ new_schema = update._apply() # pylint: disable=W0212
+ field: NestedField = new_schema.find_field(field_name)
+ assert field.field_type == ListType(
+ element_id=5,
+ element_type=StructType(
+ NestedField(6, "lat", DoubleType()),
+ NestedField(7, "long", DoubleType()),
+ ),
+ element_required=False,
+ )
+ assert new_schema.highest_field_id == 7
+
+
+def test_add_field_to_map_key(table_schema_nested_with_struct_key_map: Schema, table: Table) -> None:
+ with pytest.raises(ValueError) as exc_info:
+ update = UpdateSchema(table_schema_nested_with_struct_key_map, table)
+ update.add_column(name="b", type_var=IntegerType(),
parent="location.key")._apply() # pylint: disable=W0212
+ assert "Cannot add fields to map keys" in str(exc_info.value)
+
+
+def test_add_already_exists(table_schema_nested: Schema, table: Table) -> None:
+ with pytest.raises(ValueError) as exc_info:
+ update = UpdateSchema(table_schema_nested, table)
+ update.add_column("foo", IntegerType())
+ assert "already exists: foo" in str(exc_info.value)
+
+ with pytest.raises(ValueError) as exc_info:
+ update = UpdateSchema(table_schema_nested, table)
+ update.add_column(name="latitude", type_var=IntegerType(),
parent="location")
+ assert "already exists: location.lat" in str(exc_info.value)
+
+
+def test_add_to_non_struct_type(table_schema_simple: Schema, table: Table) ->
None:
+ with pytest.raises(ValueError) as exc_info:
+ update = UpdateSchema(table_schema_simple, table)
+ update.add_column(name="lat", type_var=IntegerType(), parent="foo")
+ assert "Cannot add column to non-struct type" in str(exc_info.value)
+
+
+def test_add_required_column(table: Table) -> None:
+ schema_ = Schema(
+ NestedField(field_id=1, name="a", field_type=BooleanType(),
required=False), schema_id=1, identifier_field_ids=[]
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ update = UpdateSchema(schema_, table)
+ update.add_column(name="data", type_var=IntegerType(), required=True)
+ assert "Incompatible change: cannot add required column: data" in
str(exc_info.value)
+
+ new_schema = (
+ UpdateSchema(schema_, table) # pylint: disable=W0212
+ .allow_incompatible_changes()
+ .add_column(name="data", type_var=IntegerType(), required=True)
+ ._apply()
+ )
+ assert new_schema == Schema(
+ NestedField(field_id=1, name="a", field_type=BooleanType(),
required=False),
+ NestedField(field_id=2, name="data", field_type=IntegerType(),
required=True),
+ schema_id=0,
+ identifier_field_ids=[],
+ )
+
+
+def test_add_required_column_case_insensitive(table: Table) -> None:
+ schema_ = Schema(
+ NestedField(field_id=1, name="id", field_type=BooleanType(),
required=False), schema_id=1, identifier_field_ids=[]
+ )
+
+ with pytest.raises(ValueError) as exc_info:
+ update = UpdateSchema(schema_, table)
+ update.allow_incompatible_changes().case_sensitive(False).add_column(name="ID", type_var=IntegerType(), required=True)
+ assert "already exists: ID" in str(exc_info.value)
+
+ new_schema = (
+ UpdateSchema(schema_, table) # pylint: disable=W0212
+ .allow_incompatible_changes()
+ .add_column(name="ID", type_var=IntegerType(), required=True)
+ ._apply()
+ )
+ assert new_schema == Schema(
+ NestedField(field_id=1, name="id", field_type=BooleanType(),
required=False),
+ NestedField(field_id=2, name="ID", field_type=IntegerType(),
required=True),
+ schema_id=0,
+ identifier_field_ids=[],
+ )
diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py
index a63436bdae..acd6946774 100644
--- a/python/tests/test_integration.py
+++ b/python/tests/test_integration.py
@@ -25,7 +25,7 @@ import pytest
from pyarrow.fs import S3FileSystem
from pyiceberg.catalog import Catalog, load_catalog
-from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
from pyiceberg.expressions import (
And,
EqualTo,
@@ -40,10 +40,14 @@ from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.types import (
BooleanType,
+ DoubleType,
+ FixedType,
IntegerType,
+ LongType,
NestedField,
StringType,
TimestampType,
+ UUIDType,
)
@@ -352,3 +356,89 @@ def test_unpartitioned_fixed_table(catalog: Catalog) ->
None:
b"12345678901234567ass12345",
b"qweeqwwqq1231231231231111",
]
+
+
+@pytest.mark.integration
+def test_schema_evolution(catalog: Catalog) -> None:
+ try:
+ catalog.drop_table("default.test_schema_evolution")
+ except NoSuchTableError:
+ pass
+
+ schema = Schema(
+ NestedField(field_id=1, name="col_uuid", field_type=UUIDType(),
required=False),
+ NestedField(field_id=2, name="col_fixed", field_type=FixedType(25),
required=False),
+ )
+
+ t = catalog.create_table(identifier="default.test_schema_evolution",
schema=schema)
+
+ assert t.schema() == schema
+
+ with t.update_schema() as tx:
+ tx.add_column("col_string", StringType())
+
+ assert t.schema() == Schema(
+ NestedField(field_id=1, name="col_uuid", field_type=UUIDType(),
required=False),
+ NestedField(field_id=2, name="col_fixed", field_type=FixedType(25),
required=False),
+ NestedField(field_id=3, name="col_string", field_type=StringType(),
required=False),
+ schema_id=1,
+ )
+
+
+@pytest.mark.integration
+def test_schema_evolution_via_transaction(catalog: Catalog) -> None:
+ try:
+ catalog.drop_table("default.test_schema_evolution")
+ except NoSuchTableError:
+ pass
+
+ schema = Schema(
+ NestedField(field_id=1, name="col_uuid", field_type=UUIDType(),
required=False),
+ NestedField(field_id=2, name="col_fixed", field_type=FixedType(25),
required=False),
+ )
+
+ tbl = catalog.create_table(identifier="default.test_schema_evolution",
schema=schema)
+
+ assert tbl.schema() == schema
+
+ with tbl.transaction() as tx:
+ tx.update_schema().add_column("col_string", StringType()).commit()
+
+ assert tbl.schema() == Schema(
+ NestedField(field_id=1, name="col_uuid", field_type=UUIDType(),
required=False),
+ NestedField(field_id=2, name="col_fixed", field_type=FixedType(25),
required=False),
+ NestedField(field_id=3, name="col_string", field_type=StringType(),
required=False),
+ schema_id=1,
+ )
+
+ tbl.update_schema().add_column("col_integer", IntegerType()).commit()
+
+ assert tbl.schema() == Schema(
+ NestedField(field_id=1, name="col_uuid", field_type=UUIDType(),
required=False),
+ NestedField(field_id=2, name="col_fixed", field_type=FixedType(25),
required=False),
+ NestedField(field_id=3, name="col_string", field_type=StringType(),
required=False),
+ NestedField(field_id=4, name="col_integer", field_type=IntegerType(),
required=False),
+ schema_id=1,
+ )
+
+ with pytest.raises(CommitFailedException) as exc_info:
+ with tbl.transaction() as tx:
+ # Start a new update
+ schema_update = tx.update_schema()
+
+ # Do a concurrent update
+ tbl.update_schema().add_column("col_long", LongType()).commit()
+
+ # stage another update in the transaction
+ schema_update.add_column("col_double", DoubleType()).commit()
+
+ assert "Requirement failed: current schema changed: expected id 2 != 3" in
str(exc_info.value)
+
+ assert tbl.schema() == Schema(
+ NestedField(field_id=1, name="col_uuid", field_type=UUIDType(),
required=False),
+ NestedField(field_id=2, name="col_fixed", field_type=FixedType(25),
required=False),
+ NestedField(field_id=3, name="col_string", field_type=StringType(),
required=False),
+ NestedField(field_id=4, name="col_integer", field_type=IntegerType(),
required=False),
+ NestedField(field_id=5, name="col_long", field_type=LongType(),
required=False),
+ schema_id=1,
+ )
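The concurrency failure exercised above comes from the `AssertCurrentSchemaId` requirement that `UpdateSchema.commit()` now attaches to every schema change; the catalog rejects the commit when the table's current schema id no longer matches. A small sketch of the requirement object itself, relying only on the defaults added in this commit:

```python
from pyiceberg.table import AssertCurrentSchemaId

# The `type` discriminator now has a default, so only the schema id is needed.
requirement = AssertCurrentSchemaId(current_schema_id=2)
assert requirement.type == "assert-current-schema-id"
```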
diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py
index 57f1947346..50d788b953 100644
--- a/python/tests/test_schema.py
+++ b/python/tests/test_schema.py
@@ -334,78 +334,6 @@ def test_schema_find_field_type_by_id(table_schema_simple: Schema) -> None:
assert index[3] == NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)
-def test_index_by_id_schema_visitor(table_schema_nested: Schema) -> None:
- """Test the index_by_id function that uses the IndexById schema visitor"""
- assert schema.index_by_id(table_schema_nested) == {
- 1: NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
- 2: NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
- 3: NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
- 4: NestedField(
- field_id=4,
- name="qux",
- field_type=ListType(element_id=5, element_type=StringType(), element_required=True),
- required=True,
- ),
- 5: NestedField(field_id=5, name="element", field_type=StringType(), required=True),
- 6: NestedField(
- field_id=6,
- name="quux",
- field_type=MapType(
- key_id=7,
- key_type=StringType(),
- value_id=8,
- value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True),
- value_required=True,
- ),
- required=True,
- ),
- 7: NestedField(field_id=7, name="key", field_type=StringType(), required=True),
- 8: NestedField(
- field_id=8,
- name="value",
- field_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True),
- required=True,
- ),
- 9: NestedField(field_id=9, name="key", field_type=StringType(), required=True),
- 10: NestedField(field_id=10, name="value", field_type=IntegerType(), required=True),
- 11: NestedField(
- field_id=11,
- name="location",
- field_type=ListType(
- element_id=12,
- element_type=StructType(
- NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False),
- NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False),
- ),
- element_required=True,
- ),
- required=True,
- ),
- 12: NestedField(
- field_id=12,
- name="element",
- field_type=StructType(
- NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False),
- NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False),
- ),
- required=True,
- ),
- 13: NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False),
- 14: NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False),
- 15: NestedField(
- field_id=15,
- name="person",
- field_type=StructType(
- NestedField(field_id=16, name="name", field_type=StringType(), required=False),
- NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
- ),
- required=False,
- ),
- 16: NestedField(field_id=16, name="name", field_type=StringType(), required=False),
- 17: NestedField(field_id=17, name="age", field_type=IntegerType(), required=True),
- }
-
-
def test_index_by_id_schema_visitor_raise_on_unregistered_type() -> None:
"""Test raising a NotImplementedError when an invalid type is provided to
the index_by_id function"""
with pytest.raises(NotImplementedError) as exc_info: