This is an automated email from the ASF dual-hosted git repository.

kevinjqliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e4d4248 feat: Add support for rolling back to snapshot (#2878)
8e4d4248 is described below

commit 8e4d42483bbab92fc76538e32574c0336b3dd9d5
Author: geruh <[email protected]>
AuthorDate: Wed Jan 14 18:36:21 2026 -0800

    feat: Add support for rolling back to snapshot (#2878)
    
    # Rationale for this change
    
    This PR adds the ability to roll back a table to an ancestral snapshot.
    Some of this work was also done in #758; this is an in-progress PR to be
    merged after #2871.
    
    Additionally, adding some more tests.
    
    ## Are these changes tested?
    
    Yes
    
    ## Are there any user-facing changes?
    
    New API
---
 pyiceberg/table/update/snapshot.py            |  35 +++++++++
 tests/integration/test_snapshot_operations.py | 108 ++++++++++++++++++++++++++
 2 files changed, 143 insertions(+)

diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py
index bc05aab9..987200bf 100644
--- a/pyiceberg/table/update/snapshot.py
+++ b/pyiceberg/table/update/snapshot.py
@@ -64,6 +64,7 @@ from pyiceberg.table.snapshots import (
     Snapshot,
     SnapshotSummaryCollector,
     Summary,
+    ancestors_of,
     update_snapshot_summaries,
 )
 from pyiceberg.table.update import (
@@ -985,6 +986,40 @@ class ManageSnapshots(UpdateTableMetadata["ManageSnapshots"]):
         self._transaction._stage(update, requirement)
         return self
 
+    def rollback_to_snapshot(self, snapshot_id: int) -> ManageSnapshots:
+        """Rollback the table to the given snapshot id.
+
+        The snapshot needs to be an ancestor of the current table state.
+
+        Args:
+            snapshot_id (int): rollback to this snapshot_id that used to be current.
+
+        Returns:
+            This for method chaining
+
+        Raises:
+            ValueError: If the snapshot does not exist or is not an ancestor of the current table state.
+        """
+        if not self._transaction.table_metadata.snapshot_by_id(snapshot_id):
+            raise ValueError(f"Cannot roll back to unknown snapshot id: {snapshot_id}")
+
+        if not self._is_current_ancestor(snapshot_id):
+            raise ValueError(f"Cannot roll back to snapshot, not an ancestor of the current state: {snapshot_id}")
+
+        return self.set_current_snapshot(snapshot_id=snapshot_id)
+
+    def _is_current_ancestor(self, snapshot_id: int) -> bool:
+        return snapshot_id in self._current_ancestors()
+
+    def _current_ancestors(self) -> set[int]:
+        return {
+            a.snapshot_id
+            for a in ancestors_of(
+                self._transaction.table_metadata.current_snapshot(),
+                self._transaction.table_metadata,
+            )
+        }
+
 
 class ExpireSnapshots(UpdateTableMetadata["ExpireSnapshots"]):
     """Expire snapshots by ID.
diff --git a/tests/integration/test_snapshot_operations.py b/tests/integration/test_snapshot_operations.py
index 2f0447ec..8755e95f 100644
--- a/tests/integration/test_snapshot_operations.py
+++ b/tests/integration/test_snapshot_operations.py
@@ -14,12 +14,44 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import uuid
+from collections.abc import Generator
+
+import pyarrow as pa
 import pytest
 
 from pyiceberg.catalog import Catalog
+from pyiceberg.table import Table
 from pyiceberg.table.refs import SnapshotRef
 
 
+@pytest.fixture
+def table_with_snapshots(session_catalog: Catalog) -> Generator[Table, None, None]:
+    session_catalog.create_namespace_if_not_exists("default")
+    identifier = f"default.test_table_snapshot_ops_{uuid.uuid4().hex[:8]}"
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int64(), nullable=False),
+            pa.field("data", pa.string(), nullable=True),
+        ]
+    )
+
+    tbl = session_catalog.create_table(identifier=identifier, schema=arrow_schema)
+
+    data1 = pa.Table.from_pylist([{"id": 1, "data": "a"}, {"id": 2, "data": "b"}], schema=arrow_schema)
+    tbl.append(data1)
+
+    data2 = pa.Table.from_pylist([{"id": 3, "data": "c"}, {"id": 4, "data": "d"}], schema=arrow_schema)
+    tbl.append(data2)
+
+    tbl = session_catalog.load_table(identifier)
+
+    yield tbl
+
+    session_catalog.drop_table(identifier)
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
 def test_create_tag(catalog: Catalog) -> None:
@@ -160,3 +192,79 @@ def test_set_current_snapshot_chained_with_create_tag(catalog: Catalog) -> None:
     tbl = catalog.load_table(identifier)
     tbl.manage_snapshots().remove_tag(tag_name=tag_name).commit()
     assert tbl.metadata.refs.get(tag_name, None) is None
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot(table_with_snapshots: Table) -> None:
+    history = table_with_snapshots.history()
+    assert len(history) >= 2
+
+    ancestor_snapshot_id = history[-2].snapshot_id
+
+    table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=ancestor_snapshot_id).commit()
+
+    updated = table_with_snapshots.current_snapshot()
+    assert updated is not None
+    assert updated.snapshot_id == ancestor_snapshot_id
+
+
+@pytest.mark.integration
+def test_rollback_to_current_snapshot(table_with_snapshots: Table) -> None:
+    current = table_with_snapshots.current_snapshot()
+    assert current is not None
+
+    table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=current.snapshot_id).commit()
+
+    updated = table_with_snapshots.current_snapshot()
+    assert updated is not None
+    assert updated.snapshot_id == current.snapshot_id
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot_chained_with_tag(table_with_snapshots: Table) -> None:
+    history = table_with_snapshots.history()
+    assert len(history) >= 2
+
+    ancestor_snapshot_id = history[-2].snapshot_id
+    tag_name = "my-tag"
+
+    (
+        table_with_snapshots.manage_snapshots()
+        .create_tag(snapshot_id=ancestor_snapshot_id, tag_name=tag_name)
+        .rollback_to_snapshot(snapshot_id=ancestor_snapshot_id)
+        .commit()
+    )
+
+    updated = table_with_snapshots.current_snapshot()
+    assert updated is not None
+    assert updated.snapshot_id == ancestor_snapshot_id
+    assert table_with_snapshots.metadata.refs[tag_name] == SnapshotRef(snapshot_id=ancestor_snapshot_id, snapshot_ref_type="tag")
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot_not_ancestor(table_with_snapshots: Table) -> None:
+    history = table_with_snapshots.history()
+    assert len(history) >= 2
+
+    snapshot_a = history[-2].snapshot_id
+
+    branch_name = "my-branch"
+    table_with_snapshots.manage_snapshots().create_branch(snapshot_id=snapshot_a, branch_name=branch_name).commit()
+
+    data = pa.Table.from_pylist([{"id": 5, "data": "e"}], schema=table_with_snapshots.schema().as_arrow())
+    table_with_snapshots.append(data, branch=branch_name)
+
+    snapshot_c = table_with_snapshots.metadata.snapshot_by_name(branch_name)
+    assert snapshot_c is not None
+    assert snapshot_c.snapshot_id != snapshot_a
+
+    with pytest.raises(ValueError, match="not an ancestor"):
+        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=snapshot_c.snapshot_id).commit()
+
+
+@pytest.mark.integration
+def test_rollback_to_snapshot_unknown_id(table_with_snapshots: Table) -> None:
+    invalid_snapshot_id = 1234567890000
+
+    with pytest.raises(ValueError, match="Cannot roll back to unknown snapshot id"):
+        table_with_snapshots.manage_snapshots().rollback_to_snapshot(snapshot_id=invalid_snapshot_id).commit()

Reply via email to