This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new a368bd9  Use Pydantic's `model_copy` for model modification (#182)
a368bd9 is described below

commit a368bd9822adb99ba3e5e9bcf49f816b7001cd77
Author: HonahX <[email protected]>
AuthorDate: Wed Dec 6 00:48:15 2023 -0800

    Use Pydantic's `model_copy` for model modification (#182)
    
    * Implement table metadata updater first draft
    
    * fix updater error and add tests
    
    * implement apply_metadata_update which is simpler
    
    * remove old implementation
    
    * re-organize method place
    
    * fix nit
    
    * fix test
    
    * add another test
    
    * clear TODO
    
    * add a combined test
    
    * Fix merge conflict
    
    * remove table requirement validation for PR simplification
    
    * make context private and solve elif issue
    
    * remove private field access
    
    * push snapshot ref validation to its builder using pydantic
    
    * fix comment
    
    * remove unnecessary code for AddSchemaUpdate update
    
    * replace if with elif
    
    * switch to model_copy()
    
    * enhance the set current schema update implementation and some other 
changes
    
    * make apply_table_update private
    
    * fix lint after merge
    
    * add validation
    
    * add test for isolation of illegal updates
    
    * fix nit
    
    * remove unnecessary flag
    
    * change to model_copy(deep=True)
---
 pyiceberg/table/__init__.py | 57 ++++++++++++++++++++++-----------------------
 tests/table/test_init.py    | 51 +++++++++++++++++++++++++++++++++++++++-
 2 files changed, 78 insertions(+), 30 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 9aa6c1c..436266f 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -417,12 +417,13 @@ def _(update: AddSchemaUpdate, base_metadata: 
TableMetadata, context: _TableMeta
     if update.last_column_id < base_metadata.last_column_id:
         raise ValueError(f"Invalid last column id {update.last_column_id}, 
must be >= {base_metadata.last_column_id}")
 
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["last-column-id"] = update.last_column_id
-    updated_metadata_data["schemas"].append(update.schema_.model_dump())
-
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(updated_metadata_data)
+    return base_metadata.model_copy(
+        update={
+            "last_column_id": update.last_column_id,
+            "schemas": base_metadata.schemas + [update.schema_],
+        }
+    )
 
 
 @_apply_table_update.register(SetCurrentSchemaUpdate)
@@ -441,11 +442,8 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: 
TableMetadata, context: _Ta
     if schema is None:
         raise ValueError(f"Schema with id {new_schema_id} does not exist")
 
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["current-schema-id"] = new_schema_id
-
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(updated_metadata_data)
+    return base_metadata.model_copy(update={"current_schema_id": 
new_schema_id})
 
 
 @_apply_table_update.register(AddSnapshotUpdate)
@@ -469,12 +467,14 @@ def _(update: AddSnapshotUpdate, base_metadata: 
TableMetadata, context: _TableMe
             f"older than last sequence number 
{base_metadata.last_sequence_number}"
         )
 
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms
-    updated_metadata_data["last-sequence-number"] = 
update.snapshot.sequence_number
-    updated_metadata_data["snapshots"].append(update.snapshot.model_dump())
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(updated_metadata_data)
+    return base_metadata.model_copy(
+        update={
+            "last_updated_ms": update.snapshot.timestamp_ms,
+            "last_sequence_number": update.snapshot.sequence_number,
+            "snapshots": base_metadata.snapshots + [update.snapshot],
+        }
+    )
 
 
 @_apply_table_update.register(SetSnapshotRefUpdate)
@@ -493,28 +493,27 @@ def _(update: SetSnapshotRefUpdate, base_metadata: 
TableMetadata, context: _Tabl
 
     snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
     if snapshot is None:
-        raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown 
snapshot {snapshot_ref.snapshot_id}")
+        raise ValueError(f"Cannot set {update.ref_name} to unknown snapshot 
{snapshot_ref.snapshot_id}")
 
-    update_metadata_data = copy(base_metadata.model_dump())
-    update_last_updated_ms = True
+    metadata_updates: Dict[str, Any] = {}
     if context.is_added_snapshot(snapshot_ref.snapshot_id):
-        update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms
-        update_last_updated_ms = False
+        metadata_updates["last_updated_ms"] = snapshot.timestamp_ms
 
     if update.ref_name == MAIN_BRANCH:
-        update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id
-        if update_last_updated_ms:
-            update_metadata_data["last-updated-ms"] = 
datetime_to_millis(datetime.datetime.now().astimezone())
-        update_metadata_data["snapshot-log"].append(
+        metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
+        if "last_updated_ms" not in metadata_updates:
+            metadata_updates["last_updated_ms"] = 
datetime_to_millis(datetime.datetime.now().astimezone())
+
+        metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
             SnapshotLogEntry(
                 snapshot_id=snapshot_ref.snapshot_id,
-                timestamp_ms=update_metadata_data["last-updated-ms"],
-            ).model_dump()
-        )
+                timestamp_ms=metadata_updates["last_updated_ms"],
+            )
+        ]
 
-    update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump()
+    metadata_updates["refs"] = {**base_metadata.refs, update.ref_name: 
snapshot_ref}
     context.add_update(update)
-    return TableMetadataUtil.parse_obj(update_metadata_data)
+    return base_metadata.model_copy(update=metadata_updates)
 
 
 def update_table_metadata(base_metadata: TableMetadata, updates: 
Tuple[TableUpdate, ...]) -> TableMetadata:
@@ -533,7 +532,7 @@ def update_table_metadata(base_metadata: TableMetadata, 
updates: Tuple[TableUpda
     for update in updates:
         new_metadata = _apply_table_update(update, new_metadata, context)
 
-    return new_metadata
+    return new_metadata.model_copy(deep=True)
 
 
 class TableRequirement(IcebergBaseModel):
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 6d188be..8d13a82 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint:disable=redefined-outer-name
+from copy import copy
 from typing import Dict
 
 import pytest
@@ -50,7 +51,7 @@ from pyiceberg.table import (
     _TableMetadataUpdateContext,
     update_table_metadata,
 )
-from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER
+from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, 
TableMetadataUtil, TableMetadataV2
 from pyiceberg.table.snapshots import (
     Operation,
     Snapshot,
@@ -640,9 +641,12 @@ def test_update_metadata_with_multiple_updates(table_v1: 
Table) -> None:
     )
 
     new_metadata = update_table_metadata(base_metadata, test_updates)
+    # rebuild the metadata to trigger validation
+    new_metadata = TableMetadataUtil.parse_obj(copy(new_metadata.model_dump()))
 
     # UpgradeFormatVersionUpdate
     assert new_metadata.format_version == 2
+    assert isinstance(new_metadata, TableMetadataV2)
 
     # UpdateSchema
     assert len(new_metadata.schemas) == 2
@@ -669,6 +673,51 @@ def test_update_metadata_with_multiple_updates(table_v1: 
Table) -> None:
     )
 
 
+def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None:
+    base_metadata = table_v1.metadata
+    base_metadata_backup = base_metadata.model_copy(deep=True)
+
+    # Apply legal updates on the table metadata
+    transaction = table_v1.transaction()
+    schema_update_1 = transaction.update_schema()
+    schema_update_1.add_column(path="b", field_type=IntegerType())
+    schema_update_1.commit()
+    test_updates = transaction._updates  # pylint: disable=W0212
+    new_snapshot = Snapshot(
+        snapshot_id=25,
+        parent_snapshot_id=19,
+        sequence_number=200,
+        timestamp_ms=1602638573590,
+        manifest_list="s3:/a/b/c.avro",
+        summary=Summary(Operation.APPEND),
+        schema_id=3,
+    )
+    test_updates += (
+        AddSnapshotUpdate(snapshot=new_snapshot),
+        SetSnapshotRefUpdate(
+            ref_name="main",
+            type="branch",
+            snapshot_id=25,
+            max_ref_age_ms=123123123,
+            max_snapshot_age_ms=12312312312,
+            min_snapshots_to_keep=1,
+        ),
+    )
+    new_metadata = update_table_metadata(base_metadata, test_updates)
+
+    # Check that the original metadata is not modified
+    assert base_metadata == base_metadata_backup
+
+    # Perform illegal update on the new metadata:
+    # TableMetadata should be immutable, but the pydantic's frozen config 
cannot prevent
+    # operations such as list append.
+    new_metadata.partition_specs.append(PartitionSpec(spec_id=0))
+    assert len(new_metadata.partition_specs) == 2
+
+    # The original metadata should not be affected by the illegal update on 
the new metadata
+    assert len(base_metadata.partition_specs) == 1
+
+
 def test_generate_snapshot_id(table_v2: Table) -> None:
     assert isinstance(_generate_snapshot_id(), int)
     assert isinstance(table_v2.new_snapshot_id(), int)

Reply via email to the mailing list.