This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 1dde51a0 Support snapshot management operations like creating tags by adding `ManageSnapshots` API (#728)
1dde51a0 is described below

commit 1dde51a09097984d8bf298db08171b9d299ffd59
Author: Chinmay Bhat <[email protected]>
AuthorDate: Sat Jun 15 21:42:58 2024 +0530

    Support snapshot management operations like creating tags by adding `ManageSnapshots` API (#728)
---
 dev/provision.py                              |  47 +++++++
 mkdocs/docs/api.md                            |  22 ++++
 pyiceberg/table/__init__.py                   | 176 ++++++++++++++++++++++++++
 tests/integration/test_snapshot_operations.py |  42 ++++++
 tests/table/test_init.py                      |  24 ++++
 5 files changed, 311 insertions(+)

diff --git a/dev/provision.py b/dev/provision.py
index 44086caf..6c8fe366 100644
--- a/dev/provision.py
+++ b/dev/provision.py
@@ -342,3 +342,50 @@ for catalog_name, catalog in catalogs.items():
            (array(), map(), array(struct(1)))
     """
     )
+
+    spark.sql(
+        f"""
+        CREATE OR REPLACE TABLE {catalog_name}.default.test_table_snapshot_operations (
+            number integer
+        )
+        USING iceberg
+        TBLPROPERTIES (
+            'format-version'='2'
+        );
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {catalog_name}.default.test_table_snapshot_operations
+        VALUES (1)
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {catalog_name}.default.test_table_snapshot_operations
+        VALUES (2)
+        """
+    )
+
+    spark.sql(
+        f"""
+        DELETE FROM {catalog_name}.default.test_table_snapshot_operations
+        WHERE number = 2
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {catalog_name}.default.test_table_snapshot_operations
+        VALUES (3)
+        """
+    )
+
+    spark.sql(
+        f"""
+        INSERT INTO {catalog_name}.default.test_table_snapshot_operations
+        VALUES (4)
+        """
+    )
diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 70b5fd62..6bbd9abe 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -913,6 +913,28 @@ tbl.overwrite(df, snapshot_properties={"abc": "def"})
 assert tbl.metadata.snapshots[-1].summary["abc"] == "def"
 ```
 
+## Snapshot Management
+
+Manage snapshots with operations through the `Table` API:
+
+```python
+# To run a specific operation
+table.manage_snapshots().create_tag(snapshot_id, "tag123").commit()
+# To run multiple operations
+(table.manage_snapshots()
+    .create_tag(snapshot_id1, "tag123")
+    .create_tag(snapshot_id2, "tag456")
+    .commit())
+# Operations are applied on commit.
+```
+
+You can also use context managers to make more changes:
+
+```python
+with table.manage_snapshots() as ms:
+    ms.create_branch(snapshot_id1, "Branch_A").create_tag(snapshot_id2, "tag789")
+```
+
 ## Query the data
 
 To query a table, a table scan is needed. A table scan accepts a filter, columns, optionally a limit and a snapshot ID:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 2d7f81a6..9a10fc6b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -138,6 +138,7 @@ from pyiceberg.types import (
 )
 from pyiceberg.utils.concurrent import ExecutorFactory
 from pyiceberg.utils.datetime import datetime_to_millis
+from pyiceberg.utils.deprecated import deprecated
 from pyiceberg.utils.singleton import _convert_to_hashable_type
 
 if TYPE_CHECKING:
@@ -351,6 +352,88 @@ class Transaction:
         updates = properties or kwargs
         return self._apply((SetPropertiesUpdate(updates=updates),))
 
+    @deprecated(
+        deprecated_in="0.7.0",
+        removed_in="0.8.0",
+        help_message="Please use one of the functions in ManageSnapshots instead",
+    )
+    def add_snapshot(self, snapshot: Snapshot) -> Transaction:
+        """Add a new snapshot to the table.
+
+        Returns:
+            The transaction with the add-snapshot staged.
+        """
+        updates = (AddSnapshotUpdate(snapshot=snapshot),)
+
+        return self._apply(updates, ())
+
+    @deprecated(
+        deprecated_in="0.7.0",
+        removed_in="0.8.0",
+        help_message="Please use one of the functions in ManageSnapshots instead",
+    )
+    def set_ref_snapshot(
+        self,
+        snapshot_id: int,
+        parent_snapshot_id: Optional[int],
+        ref_name: str,
+        type: str,
+        max_ref_age_ms: Optional[int] = None,
+        max_snapshot_age_ms: Optional[int] = None,
+        min_snapshots_to_keep: Optional[int] = None,
+    ) -> Transaction:
+        """Update a ref to a snapshot.
+
+        Returns:
+            The transaction with the set-snapshot-ref staged
+        """
+        updates = (
+            SetSnapshotRefUpdate(
+                snapshot_id=snapshot_id,
+                ref_name=ref_name,
+                type=type,
+                max_ref_age_ms=max_ref_age_ms,
+                max_snapshot_age_ms=max_snapshot_age_ms,
+                min_snapshots_to_keep=min_snapshots_to_keep,
+            ),
+        )
+
+        requirements = (AssertRefSnapshotId(snapshot_id=parent_snapshot_id, ref="main"),)
+        return self._apply(updates, requirements)
+
+    def _set_ref_snapshot(
+        self,
+        snapshot_id: int,
+        ref_name: str,
+        type: str,
+        max_ref_age_ms: Optional[int] = None,
+        max_snapshot_age_ms: Optional[int] = None,
+        min_snapshots_to_keep: Optional[int] = None,
+    ) -> UpdatesAndRequirements:
+        """Update a ref to a snapshot.
+
+        Returns:
+            The updates and requirements for the set-snapshot-ref staged
+        """
+        updates = (
+            SetSnapshotRefUpdate(
+                snapshot_id=snapshot_id,
+                ref_name=ref_name,
+                type=type,
+                max_ref_age_ms=max_ref_age_ms,
+                max_snapshot_age_ms=max_snapshot_age_ms,
+                min_snapshots_to_keep=min_snapshots_to_keep,
+            ),
+        )
+        requirements = (
+            AssertRefSnapshotId(
+                snapshot_id=self.table_metadata.refs[ref_name].snapshot_id if ref_name in self.table_metadata.refs else None,
+                ref=ref_name,
+            ),
+        )
+
+        return updates, requirements
+
     def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
         """Create a new UpdateSchema to alter the columns of this table.
 
@@ -1323,6 +1406,21 @@ class Table:
         """Get the snapshot history of this table."""
         return self.metadata.snapshot_log
 
+    def manage_snapshots(self) -> ManageSnapshots:
+        """
+        Shorthand to run snapshot management operations like create branch, create tag, etc.
+
+        Use table.manage_snapshots().<operation>().commit() to run a specific operation.
+        Use table.manage_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
+        Pending changes are applied on commit.
+
+        We can also use context managers to make more changes. For example,
+
+        with table.manage_snapshots() as ms:
+           ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
+        """
+        return ManageSnapshots(transaction=Transaction(self, autocommit=True))
+
     def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
         """Create a new UpdateSchema to alter the columns of this table.
 
@@ -1835,6 +1933,84 @@ class UpdateTableMetadata(ABC, Generic[U]):
         return self  # type: ignore
 
 
+class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
+    """
+    Run snapshot management operations using APIs.
+
+    APIs include create branch, create tag, etc.
+
+    Use table.manage_snapshots().<operation>().commit() to run a specific operation.
+    Use table.manage_snapshots().<operation-one>().<operation-two>().commit() to run multiple operations.
+    Pending changes are applied on commit.
+
+    We can also use context managers to make more changes. For example,
+
+    with table.manage_snapshots() as ms:
+       ms.create_tag(snapshot_id1, "Tag_A").create_tag(snapshot_id2, "Tag_B")
+    """
+
+    _updates: Tuple[TableUpdate, ...] = ()
+    _requirements: Tuple[TableRequirement, ...] = ()
+
+    def _commit(self) -> UpdatesAndRequirements:
+        """Apply the pending changes and commit."""
+        return self._updates, self._requirements
+
+    def create_tag(self, snapshot_id: int, tag_name: str, max_ref_age_ms: Optional[int] = None) -> ManageSnapshots:
+        """
+        Create a new tag pointing to the given snapshot id.
+
+        Args:
+            snapshot_id (int): snapshot id of the existing snapshot to tag
+            tag_name (str): name of the tag
+            max_ref_age_ms (Optional[int]): max ref age in milliseconds
+
+        Returns:
+            This for method chaining
+        """
+        update, requirement = self._transaction._set_ref_snapshot(
+            snapshot_id=snapshot_id,
+            ref_name=tag_name,
+            type="tag",
+            max_ref_age_ms=max_ref_age_ms,
+        )
+        self._updates += update
+        self._requirements += requirement
+        return self
+
+    def create_branch(
+        self,
+        snapshot_id: int,
+        branch_name: str,
+        max_ref_age_ms: Optional[int] = None,
+        max_snapshot_age_ms: Optional[int] = None,
+        min_snapshots_to_keep: Optional[int] = None,
+    ) -> ManageSnapshots:
+        """
+        Create a new branch pointing to the given snapshot id.
+
+        Args:
+            snapshot_id (int): snapshot id of existing snapshot at which the branch is created.
+            branch_name (str): name of the new branch
+            max_ref_age_ms (Optional[int]): max ref age in milliseconds
+            max_snapshot_age_ms (Optional[int]): max age of snapshots to keep in milliseconds
+            min_snapshots_to_keep (Optional[int]): min number of snapshots to keep
+        Returns:
+            This for method chaining
+        """
+        update, requirement = self._transaction._set_ref_snapshot(
+            snapshot_id=snapshot_id,
+            ref_name=branch_name,
+            type="branch",
+            max_ref_age_ms=max_ref_age_ms,
+            max_snapshot_age_ms=max_snapshot_age_ms,
+            min_snapshots_to_keep=min_snapshots_to_keep,
+        )
+        self._updates += update
+        self._requirements += requirement
+        return self
+
+
 class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
     _schema: Schema
     _last_column_id: itertools.count[int]
diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py
new file mode 100644
index 00000000..63919338
--- /dev/null
+++ b/tests/integration/test_snapshot_operations.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.table.refs import SnapshotRef
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_create_tag(catalog: Catalog) -> None:
+    identifier = "default.test_table_snapshot_operations"
+    tbl = catalog.load_table(identifier)
+    assert len(tbl.history()) > 3
+    tag_snapshot_id = tbl.history()[-3].snapshot_id
+    tbl.manage_snapshots().create_tag(snapshot_id=tag_snapshot_id, tag_name="tag123").commit()
+    assert tbl.metadata.refs["tag123"] == SnapshotRef(snapshot_id=tag_snapshot_id, snapshot_ref_type="tag")
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_create_branch(catalog: Catalog) -> None:
+    identifier = "default.test_table_snapshot_operations"
+    tbl = catalog.load_table(identifier)
+    assert len(tbl.history()) > 2
+    branch_snapshot_id = tbl.history()[-2].snapshot_id
+    tbl.manage_snapshots().create_branch(snapshot_id=branch_snapshot_id, branch_name="branch123").commit()
+    assert tbl.metadata.refs["branch123"] == SnapshotRef(snapshot_id=branch_snapshot_id, snapshot_ref_type="branch")
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index 20b77b6a..c97b3a4a 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -689,6 +689,30 @@ def test_update_metadata_add_snapshot(table_v2: Table) -> None:
     assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms
 
 
+def test_update_metadata_set_ref_snapshot(table_v2: Table) -> None:
+    update, _ = table_v2.transaction()._set_ref_snapshot(
+        snapshot_id=3051729675574597004,
+        ref_name="main",
+        type="branch",
+        max_ref_age_ms=123123123,
+        max_snapshot_age_ms=12312312312,
+        min_snapshots_to_keep=1,
+    )
+
+    new_metadata = update_table_metadata(table_v2.metadata, update)
+    assert len(new_metadata.snapshot_log) == 3
+    assert new_metadata.snapshot_log[2].snapshot_id == 3051729675574597004
+    assert new_metadata.current_snapshot_id == 3051729675574597004
+    assert new_metadata.last_updated_ms > table_v2.metadata.last_updated_ms
+    assert new_metadata.refs["main"] == SnapshotRef(
+        snapshot_id=3051729675574597004,
+        snapshot_ref_type="branch",
+        min_snapshots_to_keep=1,
+        max_snapshot_age_ms=12312312312,
+        max_ref_age_ms=123123123,
+    )
+
+
 def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None:
     update = SetSnapshotRefUpdate(
         ref_name="main",

Reply via email to