This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new a077c73e Fix CommitTableRequest serialisation (#525)
a077c73e is described below

commit a077c73eeeb3d2dee87662806def9d65ddc061b8
Author: Kieran <[email protected]>
AuthorDate: Sun Mar 17 19:44:39 2024 +0000

    Fix CommitTableRequest serialisation (#525)
    
    * add failing test
    
    * make requirements a discriminated union
    
    * use discriminated type union
    
    * add return type
    
    * use type annotation
    
    * have requirements inherit from ValidatableTableRequirement
    
    * AddSortOrder filter by type
    
    * lint
    
    ---------
    
    Co-authored-by: Kieran Higgins <[email protected]>
---
 pyiceberg/table/__init__.py | 148 +++++++++++++++++++++++---------------------
 tests/table/test_init.py    |  12 ++++
 2 files changed, 90 insertions(+), 70 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 4fb14e7d..c75a0a59 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -43,7 +43,7 @@ from typing import (
     Union,
 )
 
-from pydantic import Field, SerializeAsAny, field_validator
+from pydantic import Field, field_validator
 from sortedcontainers import SortedList
 from typing_extensions import Annotated
 
@@ -383,77 +383,56 @@ class Transaction:
             return self._table
 
 
-class TableUpdateAction(Enum):
-    upgrade_format_version = "upgrade-format-version"
-    add_schema = "add-schema"
-    set_current_schema = "set-current-schema"
-    add_spec = "add-spec"
-    set_default_spec = "set-default-spec"
-    add_sort_order = "add-sort-order"
-    set_default_sort_order = "set-default-sort-order"
-    add_snapshot = "add-snapshot"
-    set_snapshot_ref = "set-snapshot-ref"
-    remove_snapshots = "remove-snapshots"
-    remove_snapshot_ref = "remove-snapshot-ref"
-    set_location = "set-location"
-    set_properties = "set-properties"
-    remove_properties = "remove-properties"
-
-
-class TableUpdate(IcebergBaseModel):
-    action: TableUpdateAction
-
-
-class UpgradeFormatVersionUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.upgrade_format_version
+class UpgradeFormatVersionUpdate(IcebergBaseModel):
+    action: Literal['upgrade-format-version'] = Field(default="upgrade-format-version")
     format_version: int = Field(alias="format-version")
 
 
-class AddSchemaUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.add_schema
+class AddSchemaUpdate(IcebergBaseModel):
+    action: Literal['add-schema'] = Field(default="add-schema")
     schema_: Schema = Field(alias="schema")
     # This field is required: https://github.com/apache/iceberg/pull/7445
     last_column_id: int = Field(alias="last-column-id")
 
 
-class SetCurrentSchemaUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.set_current_schema
+class SetCurrentSchemaUpdate(IcebergBaseModel):
+    action: Literal['set-current-schema'] = Field(default="set-current-schema")
     schema_id: int = Field(
         alias="schema-id", description="Schema ID to set as current, or -1 to set last added schema", default=-1
     )
 
 
-class AddPartitionSpecUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.add_spec
+class AddPartitionSpecUpdate(IcebergBaseModel):
+    action: Literal['add-spec'] = Field(default="add-spec")
     spec: PartitionSpec
 
 
-class SetDefaultSpecUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.set_default_spec
+class SetDefaultSpecUpdate(IcebergBaseModel):
+    action: Literal['set-default-spec'] = Field(default="set-default-spec")
     spec_id: int = Field(
         alias="spec-id", description="Partition spec ID to set as the default, or -1 to set last added spec", default=-1
     )
 
 
-class AddSortOrderUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.add_sort_order
+class AddSortOrderUpdate(IcebergBaseModel):
+    action: Literal['add-sort-order'] = Field(default="add-sort-order")
     sort_order: SortOrder = Field(alias="sort-order")
 
 
-class SetDefaultSortOrderUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.set_default_sort_order
+class SetDefaultSortOrderUpdate(IcebergBaseModel):
+    action: Literal['set-default-sort-order'] = Field(default="set-default-sort-order")
     sort_order_id: int = Field(
         alias="sort-order-id", description="Sort order ID to set as the default, or -1 to set last added sort order", default=-1
     )
 
 
-class AddSnapshotUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.add_snapshot
+class AddSnapshotUpdate(IcebergBaseModel):
+    action: Literal['add-snapshot'] = Field(default="add-snapshot")
     snapshot: Snapshot
 
 
-class SetSnapshotRefUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.set_snapshot_ref
+class SetSnapshotRefUpdate(IcebergBaseModel):
+    action: Literal['set-snapshot-ref'] = Field(default="set-snapshot-ref")
     ref_name: str = Field(alias="ref-name")
     type: Literal["tag", "branch"]
     snapshot_id: int = Field(alias="snapshot-id")
@@ -462,23 +441,23 @@ class SetSnapshotRefUpdate(TableUpdate):
     min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]
 
 
-class RemoveSnapshotsUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.remove_snapshots
+class RemoveSnapshotsUpdate(IcebergBaseModel):
+    action: Literal['remove-snapshots'] = Field(default="remove-snapshots")
     snapshot_ids: List[int] = Field(alias="snapshot-ids")
 
 
-class RemoveSnapshotRefUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.remove_snapshot_ref
+class RemoveSnapshotRefUpdate(IcebergBaseModel):
+    action: Literal['remove-snapshot-ref'] = Field(default="remove-snapshot-ref")
     ref_name: str = Field(alias="ref-name")
 
 
-class SetLocationUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.set_location
+class SetLocationUpdate(IcebergBaseModel):
+    action: Literal['set-location'] = Field(default="set-location")
     location: str
 
 
-class SetPropertiesUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.set_properties
+class SetPropertiesUpdate(IcebergBaseModel):
+    action: Literal['set-properties'] = Field(default="set-properties")
     updates: Dict[str, str]
 
     @field_validator('updates', mode='before')
@@ -486,11 +465,32 @@ class SetPropertiesUpdate(TableUpdate):
         return transform_dict_value_to_str(properties)
 
 
-class RemovePropertiesUpdate(TableUpdate):
-    action: TableUpdateAction = TableUpdateAction.remove_properties
+class RemovePropertiesUpdate(IcebergBaseModel):
+    action: Literal['remove-properties'] = Field(default="remove-properties")
     removals: List[str]
 
 
+TableUpdate = Annotated[
+    Union[
+        UpgradeFormatVersionUpdate,
+        AddSchemaUpdate,
+        SetCurrentSchemaUpdate,
+        AddPartitionSpecUpdate,
+        SetDefaultSpecUpdate,
+        AddSortOrderUpdate,
+        SetDefaultSortOrderUpdate,
+        AddSnapshotUpdate,
+        SetSnapshotRefUpdate,
+        RemoveSnapshotsUpdate,
+        RemoveSnapshotRefUpdate,
+        SetLocationUpdate,
+        SetPropertiesUpdate,
+        RemovePropertiesUpdate,
+    ],
+    Field(discriminator='action'),
+]
+
+
 class _TableMetadataUpdateContext:
     _updates: List[TableUpdate]
 
@@ -502,21 +502,15 @@ class _TableMetadataUpdateContext:
 
     def is_added_snapshot(self, snapshot_id: int) -> bool:
         return any(
-            update.snapshot.snapshot_id == snapshot_id
-            for update in self._updates
-            if update.action == TableUpdateAction.add_snapshot
+            update.snapshot.snapshot_id == snapshot_id for update in self._updates if isinstance(update, AddSnapshotUpdate)
         )
 
     def is_added_schema(self, schema_id: int) -> bool:
-        return any(
-            update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema
-        )
+        return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate))
 
     def is_added_sort_order(self, sort_order_id: int) -> bool:
         return any(
-            update.sort_order.order_id == sort_order_id
-            for update in self._updates
-            if update.action == TableUpdateAction.add_sort_order
+            update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate)
         )
 
 
@@ -767,7 +761,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
     return new_metadata.model_copy(deep=True)
 
 
-class TableRequirement(IcebergBaseModel):
+class ValidatableTableRequirement(IcebergBaseModel):
     type: str
 
     @abstractmethod
@@ -783,7 +777,7 @@ class TableRequirement(IcebergBaseModel):
         ...
 
 
-class AssertCreate(TableRequirement):
+class AssertCreate(ValidatableTableRequirement):
     """The table must not already exist; used for create transactions."""
 
     type: Literal["assert-create"] = Field(default="assert-create")
@@ -793,7 +787,7 @@ class AssertCreate(TableRequirement):
             raise CommitFailedException("Table already exists")
 
 
-class AssertTableUUID(TableRequirement):
+class AssertTableUUID(ValidatableTableRequirement):
     """The table UUID must match the requirement's `uuid`."""
 
     type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
@@ -806,7 +800,7 @@ class AssertTableUUID(TableRequirement):
             raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")
 
 
-class AssertRefSnapshotId(TableRequirement):
+class AssertRefSnapshotId(ValidatableTableRequirement):
     """The table branch or tag identified by the requirement's `ref` must 
reference the requirement's `snapshot-id`.
 
     if `snapshot-id` is `null` or missing, the ref must not already exist.
@@ -831,7 +825,7 @@ class AssertRefSnapshotId(TableRequirement):
             raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")
 
 
-class AssertLastAssignedFieldId(TableRequirement):
+class AssertLastAssignedFieldId(ValidatableTableRequirement):
     """The table's last assigned column id must match the requirement's 
`last-assigned-field-id`."""
 
     type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
@@ -846,7 +840,7 @@ class AssertLastAssignedFieldId(TableRequirement):
             )
 
 
-class AssertCurrentSchemaId(TableRequirement):
+class AssertCurrentSchemaId(ValidatableTableRequirement):
     """The table's current schema id must match the requirement's 
`current-schema-id`."""
 
     type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
@@ -861,7 +855,7 @@ class AssertCurrentSchemaId(TableRequirement):
             )
 
 
-class AssertLastAssignedPartitionId(TableRequirement):
+class AssertLastAssignedPartitionId(ValidatableTableRequirement):
     """The table's last assigned partition id must match the requirement's 
`last-assigned-partition-id`."""
 
     type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
@@ -876,7 +870,7 @@ class AssertLastAssignedPartitionId(TableRequirement):
             )
 
 
-class AssertDefaultSpecId(TableRequirement):
+class AssertDefaultSpecId(ValidatableTableRequirement):
     """The table's default spec id must match the requirement's 
`default-spec-id`."""
 
     type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
@@ -891,7 +885,7 @@ class AssertDefaultSpecId(TableRequirement):
             )
 
 
-class AssertDefaultSortOrderId(TableRequirement):
+class AssertDefaultSortOrderId(ValidatableTableRequirement):
     """The table's default sort order id must match the requirement's 
`default-sort-order-id`."""
 
     type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
@@ -906,6 +900,20 @@ class AssertDefaultSortOrderId(TableRequirement):
             )
 
 
+TableRequirement = Annotated[
+    Union[
+        AssertCreate,
+        AssertTableUUID,
+        AssertRefSnapshotId,
+        AssertLastAssignedFieldId,
+        AssertCurrentSchemaId,
+        AssertLastAssignedPartitionId,
+        AssertDefaultSpecId,
+        AssertDefaultSortOrderId,
+    ],
+    Field(discriminator='type'),
+]
+
 UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]
 
 
@@ -927,8 +935,8 @@ class TableIdentifier(IcebergBaseModel):
 
 class CommitTableRequest(IcebergBaseModel):
     identifier: TableIdentifier = Field()
-    requirements: Tuple[SerializeAsAny[TableRequirement], ...] = Field(default_factory=tuple)
-    updates: Tuple[SerializeAsAny[TableUpdate], ...] = Field(default_factory=tuple)
+    requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple)
+    updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple)
 
 
 class CommitTableResponse(IcebergBaseModel):
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index f7342115..bb212d69 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -53,12 +53,14 @@ from pyiceberg.table import (
     AssertLastAssignedPartitionId,
     AssertRefSnapshotId,
     AssertTableUUID,
+    CommitTableRequest,
     RemovePropertiesUpdate,
     SetDefaultSortOrderUpdate,
     SetPropertiesUpdate,
     SetSnapshotRefUpdate,
     StaticTable,
     Table,
+    TableIdentifier,
     UpdateSchema,
     _apply_table_update,
     _check_schema,
@@ -1113,3 +1115,13 @@ def test_table_properties_raise_for_none_value(example_table_metadata_v2: Dict[s
     with pytest.raises(ValidationError) as exc_info:
         TableMetadataV2(**example_table_metadata_v2)
     assert "None type is not a supported value in properties: property_name" in str(exc_info.value)
+
+
+def test_serialize_commit_table_request() -> None:
+    request = CommitTableRequest(
+        requirements=(AssertTableUUID(uuid='4bfd18a3-74c6-478e-98b1-71c4c32f4163'),),
+        identifier=TableIdentifier(namespace=['a'], name='b'),
+    )
+
+    deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json())
+    assert request == deserialized_request

Reply via email to