This is an automated email from the ASF dual-hosted git repository.

kevinjqliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new b0880c85 feat: Add Set Current Snapshot to ManageSnapshots API (#2871)
b0880c85 is described below

commit b0880c855b8dfd5c03019afcc5dd67a26432ce23
Author: geruh <[email protected]>
AuthorDate: Mon Jan 12 09:01:32 2026 -0800

    feat: Add Set Current Snapshot to ManageSnapshots API (#2871)
    
    # Rationale for this change
    
    This PR adds the ability to set the current snapshot of a
    table. The bulk of this work was done in #758, but we have broken
    it out to focus on the set-snapshot logic first. Additionally, I added a
    few more tests, following the existing expire-snapshots behavior.
    
    
    ## Are these changes tested?
    
    Yes, added tests
    
    ## Are there any user-facing changes?
    
    New API :)
    
    ```python
    table.manage_snapshots().set_current_snapshot(snapshot_id=123456789).commit()
    
    
    table.manage_snapshots().set_current_snapshot(ref_name="my-tag").commit()
    
    # chaining
    table.manage_snapshots() \
          .create_tag(snapshot_id=older_id, tag_name="my-tag") \
          .set_current_snapshot(ref_name="my-tag") \
          .commit()
    
    ```
    
    ---------
    
    Co-authored-by: Chinmay Bhat <[email protected]>
---
 pyiceberg/table/__init__.py                   |  26 +++-
 pyiceberg/table/update/snapshot.py            |  44 +++++++
 tests/integration/test_snapshot_operations.py |  88 +++++++++++++
 tests/table/test_manage_snapshots.py          | 179 ++++++++++++++++++++++++++
 4 files changed, 335 insertions(+), 2 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 88a7bd00..ae5eb400 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -275,8 +275,20 @@ class Transaction:
         if exctype is None and excinst is None and exctb is None:
             self.commit_transaction()
 
-    def _apply(self, updates: tuple[TableUpdate, ...], requirements: 
tuple[TableRequirement, ...] = ()) -> Transaction:
-        """Check if the requirements are met, and applies the updates to the 
metadata."""
+    def _stage(
+        self,
+        updates: tuple[TableUpdate, ...],
+        requirements: tuple[TableRequirement, ...] = (),
+    ) -> Transaction:
+        """Stage updates to the transaction state without committing to the 
catalog.
+
+        Args:
+            updates: The updates to stage.
+            requirements: The requirements that must be met.
+
+        Returns:
+            This transaction for method chaining.
+        """
         for requirement in requirements:
             requirement.validate(self.table_metadata)
 
@@ -289,6 +301,16 @@ class Transaction:
             if type(new_requirement) not in existing_requirements:
                 self._requirements = self._requirements + (new_requirement,)
 
+        return self
+
+    def _apply(
+        self,
+        updates: tuple[TableUpdate, ...],
+        requirements: tuple[TableRequirement, ...] = (),
+    ) -> Transaction:
+        """Check if the requirements are met, and applies the updates to the 
metadata."""
+        self._stage(updates, requirements)
+
         if self._autocommit:
             self.commit_transaction()
 
diff --git a/pyiceberg/table/update/snapshot.py 
b/pyiceberg/table/update/snapshot.py
index 84298e08..bc05aab9 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -843,6 +843,13 @@ class 
ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
         """Apply the pending changes and commit."""
         return self._updates, self._requirements
 
+    def _commit_if_ref_updates_exist(self) -> None:
+        """Stage any pending ref updates to the transaction state."""
+        if self._updates:
+            self._transaction._stage(*self._commit())
+            self._updates = ()
+            self._requirements = ()
+
     def _remove_ref_snapshot(self, ref_name: str) -> ManageSnapshots:
         """Remove a snapshot ref.
 
@@ -941,6 +948,43 @@ class 
ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
         """
         return self._remove_ref_snapshot(ref_name=branch_name)
 
+    def set_current_snapshot(self, snapshot_id: int | None = None, ref_name: 
str | None = None) -> ManageSnapshots:
+        """Set the current snapshot to a specific snapshot ID or ref.
+
+        Args:
+            snapshot_id: The ID of the snapshot to set as current.
+            ref_name: The snapshot reference (branch or tag) to set as current.
+
+        Returns:
+            This for method chaining.
+
+        Raises:
+            ValueError: If neither or both arguments are provided, or if the 
snapshot/ref does not exist.
+        """
+        self._commit_if_ref_updates_exist()
+
+        if (snapshot_id is None) == (ref_name is None):
+            raise ValueError("Either snapshot_id or ref_name must be provided, 
not both")
+
+        target_snapshot_id: int
+        if snapshot_id is not None:
+            target_snapshot_id = snapshot_id
+        else:
+            if ref_name not in self._transaction.table_metadata.refs:
+                raise ValueError(f"Cannot find matching snapshot ID for ref: 
{ref_name}")
+            target_snapshot_id = 
self._transaction.table_metadata.refs[ref_name].snapshot_id
+
+        if self._transaction.table_metadata.snapshot_by_id(target_snapshot_id) 
is None:
+            raise ValueError(f"Cannot set current snapshot to unknown snapshot 
id: {target_snapshot_id}")
+
+        update, requirement = self._transaction._set_ref_snapshot(
+            snapshot_id=target_snapshot_id,
+            ref_name=MAIN_BRANCH,
+            type=SnapshotRefType.BRANCH,
+        )
+        self._transaction._stage(update, requirement)
+        return self
+
 
 class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
     """Expire snapshots by ID.
diff --git a/tests/integration/test_snapshot_operations.py 
b/tests/integration/test_snapshot_operations.py
index 1b7f2d3a..2f0447ec 100644
--- a/tests/integration/test_snapshot_operations.py
+++ b/tests/integration/test_snapshot_operations.py
@@ -72,3 +72,91 @@ def test_remove_branch(catalog: Catalog) -> None:
     # now, remove the branch
     tbl.manage_snapshots().remove_branch(branch_name=branch_name).commit()
     assert tbl.metadata.refs.get(branch_name, None) is None
+
+
[email protected]
[email protected]("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
+def test_set_current_snapshot(catalog: Catalog) -> None:
+    identifier = "default.test_table_snapshot_operations"
+    tbl = catalog.load_table(identifier)
+    assert len(tbl.history()) > 2
+
+    # first get the current snapshot and an older one
+    current_snapshot_id = tbl.history()[-1].snapshot_id
+    older_snapshot_id = tbl.history()[-2].snapshot_id
+
+    # set the current snapshot to the older one
+    
tbl.manage_snapshots().set_current_snapshot(snapshot_id=older_snapshot_id).commit()
+
+    tbl = catalog.load_table(identifier)
+    updated_snapshot = tbl.current_snapshot()
+    assert updated_snapshot and updated_snapshot.snapshot_id == 
older_snapshot_id
+
+    # restore table
+    
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+    tbl = catalog.load_table(identifier)
+    restored_snapshot = tbl.current_snapshot()
+    assert restored_snapshot and restored_snapshot.snapshot_id == 
current_snapshot_id
+
+
[email protected]
[email protected]("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
+def test_set_current_snapshot_by_ref(catalog: Catalog) -> None:
+    identifier = "default.test_table_snapshot_operations"
+    tbl = catalog.load_table(identifier)
+    assert len(tbl.history()) > 2
+
+    # first get the current snapshot and an older one
+    current_snapshot_id = tbl.history()[-1].snapshot_id
+    older_snapshot_id = tbl.history()[-2].snapshot_id
+    assert older_snapshot_id != current_snapshot_id
+
+    # create a tag pointing to the older snapshot
+    tag_name = "my-tag"
+    tbl.manage_snapshots().create_tag(snapshot_id=older_snapshot_id, 
tag_name=tag_name).commit()
+
+    # set current snapshot using the tag name
+    tbl = catalog.load_table(identifier)
+    tbl.manage_snapshots().set_current_snapshot(ref_name=tag_name).commit()
+
+    tbl = catalog.load_table(identifier)
+    updated_snapshot = tbl.current_snapshot()
+    assert updated_snapshot and updated_snapshot.snapshot_id == 
older_snapshot_id
+
+    # restore table
+    
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+    tbl = catalog.load_table(identifier)
+    tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
+    assert tbl.metadata.refs.get(tag_name, None) is None
+
+
[email protected]
[email protected]("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
+def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> 
None:
+    identifier = "default.test_table_snapshot_operations"
+    tbl = catalog.load_table(identifier)
+    assert len(tbl.history()) > 2
+
+    current_snapshot_id = tbl.history()[-1].snapshot_id
+    older_snapshot_id = tbl.history()[-2].snapshot_id
+    assert older_snapshot_id != current_snapshot_id
+
+    # create a tag and use it to set current snapshot
+    tag_name = "my-tag"
+    (
+        tbl.manage_snapshots()
+        .create_tag(snapshot_id=older_snapshot_id, tag_name=tag_name)
+        .set_current_snapshot(ref_name=tag_name)
+        .commit()
+    )
+
+    tbl = catalog.load_table(identifier)
+    updated_snapshot = tbl.current_snapshot()
+    assert updated_snapshot
+    assert updated_snapshot.snapshot_id == older_snapshot_id
+
+    # restore table
+    
tbl.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot_id).commit()
+    tbl = catalog.load_table(identifier)
+    tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
+    assert tbl.metadata.refs.get(tag_name, None) is None
diff --git a/tests/table/test_manage_snapshots.py 
b/tests/table/test_manage_snapshots.py
new file mode 100644
index 00000000..93301a01
--- /dev/null
+++ b/tests/table/test_manage_snapshots.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest.mock import MagicMock
+from uuid import uuid4
+
+import pytest
+
+from pyiceberg.table import CommitTableResponse, Table
+from pyiceberg.table.update import SetSnapshotRefUpdate, TableUpdate
+
+
+def _mock_commit_response(table: Table) -> CommitTableResponse:
+    return CommitTableResponse(
+        metadata=table.metadata,
+        metadata_location="s3://bucket/tbl",
+        uuid=uuid4(),
+    )
+
+
+def _get_updates(mock_catalog: MagicMock) -> tuple[TableUpdate, ...]:
+    args, _ = mock_catalog.commit_table.call_args
+    return args[2]
+
+
+def test_set_current_snapshot_basic(table_v2: Table) -> None:
+    snapshot_one = 3051729675574597004
+
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = 
_mock_commit_response(table_v2)
+
+    
table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).commit()
+
+    table_v2.catalog.commit_table.assert_called_once()
+
+    updates = _get_updates(table_v2.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, 
SetSnapshotRefUpdate)]
+
+    assert len(set_ref_updates) == 1
+    update = set_ref_updates[0]
+    assert update.snapshot_id == snapshot_one
+    assert update.ref_name == "main"
+    assert update.type == "branch"
+
+
+def test_set_current_snapshot_unknown_id(table_v2: Table) -> None:
+    invalid_snapshot_id = 1234567890000
+    table_v2.catalog = MagicMock()
+
+    with pytest.raises(ValueError, match="Cannot set current snapshot to 
unknown snapshot id"):
+        
table_v2.manage_snapshots().set_current_snapshot(snapshot_id=invalid_snapshot_id).commit()
+
+    table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_set_current_snapshot_to_current(table_v2: Table) -> None:
+    current_snapshot = table_v2.current_snapshot()
+    assert current_snapshot is not None
+
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = 
_mock_commit_response(table_v2)
+
+    
table_v2.manage_snapshots().set_current_snapshot(snapshot_id=current_snapshot.snapshot_id).commit()
+
+    table_v2.catalog.commit_table.assert_called_once()
+
+
+def test_set_current_snapshot_chained_with_tag(table_v2: Table) -> None:
+    snapshot_one = 3051729675574597004
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = 
_mock_commit_response(table_v2)
+
+    
(table_v2.manage_snapshots().set_current_snapshot(snapshot_id=snapshot_one).create_tag(snapshot_one,
 "my-tag").commit())
+
+    table_v2.catalog.commit_table.assert_called_once()
+
+    updates = _get_updates(table_v2.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, 
SetSnapshotRefUpdate)]
+
+    assert len(set_ref_updates) == 2
+    assert {u.ref_name for u in set_ref_updates} == {"main", "my-tag"}
+
+
+def 
test_set_current_snapshot_with_extensive_snapshots(table_v2_with_extensive_snapshots:
 Table) -> None:
+    snapshots = table_v2_with_extensive_snapshots.metadata.snapshots
+    assert len(snapshots) > 100
+
+    target_snapshot = snapshots[50].snapshot_id
+
+    table_v2_with_extensive_snapshots.catalog = MagicMock()
+    table_v2_with_extensive_snapshots.catalog.commit_table.return_value = 
_mock_commit_response(table_v2_with_extensive_snapshots)
+
+    
table_v2_with_extensive_snapshots.manage_snapshots().set_current_snapshot(snapshot_id=target_snapshot).commit()
+
+    table_v2_with_extensive_snapshots.catalog.commit_table.assert_called_once()
+
+    updates = _get_updates(table_v2_with_extensive_snapshots.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, 
SetSnapshotRefUpdate)]
+
+    assert len(set_ref_updates) == 1
+    assert set_ref_updates[0].snapshot_id == target_snapshot
+
+
+def test_set_current_snapshot_by_ref_name(table_v2: Table) -> None:
+    current_snapshot = table_v2.current_snapshot()
+    assert current_snapshot is not None
+
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = 
_mock_commit_response(table_v2)
+
+    table_v2.manage_snapshots().set_current_snapshot(ref_name="main").commit()
+
+    updates = _get_updates(table_v2.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, 
SetSnapshotRefUpdate)]
+
+    assert len(set_ref_updates) == 1
+    assert set_ref_updates[0].snapshot_id == current_snapshot.snapshot_id
+    assert set_ref_updates[0].ref_name == "main"
+
+
+def test_set_current_snapshot_unknown_ref(table_v2: Table) -> None:
+    table_v2.catalog = MagicMock()
+
+    with pytest.raises(ValueError, match="Cannot find matching snapshot ID for 
ref: nonexistent"):
+        
table_v2.manage_snapshots().set_current_snapshot(ref_name="nonexistent").commit()
+
+    table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_set_current_snapshot_requires_one_argument(table_v2: Table) -> None:
+    table_v2.catalog = MagicMock()
+
+    with pytest.raises(ValueError, match="Either snapshot_id or ref_name must 
be provided, not both"):
+        table_v2.manage_snapshots().set_current_snapshot().commit()
+
+    with pytest.raises(ValueError, match="Either snapshot_id or ref_name must 
be provided, not both"):
+        table_v2.manage_snapshots().set_current_snapshot(snapshot_id=123, 
ref_name="main").commit()
+
+    table_v2.catalog.commit_table.assert_not_called()
+
+
+def test_set_current_snapshot_chained_with_create_tag(table_v2: Table) -> None:
+    snapshot_one = 3051729675574597004
+    table_v2.catalog = MagicMock()
+    table_v2.catalog.commit_table.return_value = 
_mock_commit_response(table_v2)
+
+    # create a tag and immediately use it to set current snapshot
+    (
+        table_v2.manage_snapshots()
+        .create_tag(snapshot_id=snapshot_one, tag_name="new-tag")
+        .set_current_snapshot(ref_name="new-tag")
+        .commit()
+    )
+
+    table_v2.catalog.commit_table.assert_called_once()
+
+    updates = _get_updates(table_v2.catalog)
+    set_ref_updates = [u for u in updates if isinstance(u, 
SetSnapshotRefUpdate)]
+
+    # should have the tag and the main branch update
+    assert len(set_ref_updates) == 2
+    assert {u.ref_name for u in set_ref_updates} == {"new-tag", "main"}
+
+    # The main branch should point to the same snapshot as the tag
+    main_update = next(u for u in set_ref_updates if u.ref_name == "main")
+    assert main_update.snapshot_id == snapshot_one

Reply via email to