This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b911057 Support Time Travel in InspectTable.entries (#599)
4b911057 is described below

commit 4b911057f13491f30f89f133544c063133133fa5
Author: Sung Yun <[email protected]>
AuthorDate: Sat Apr 13 15:07:51 2024 -0400

    Support Time Travel in InspectTable.entries (#599)
    
    * time travel in entries table
    
    * undo
    
    * Update pyiceberg/table/__init__.py
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    
    * adopt review feedback
    
    * docs
    
    ---------
    
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 mkdocs/docs/api.md                      |  12 +++
 pyiceberg/table/__init__.py             | 128 ++++++++++++++------------
 tests/integration/test_inspect_table.py | 156 ++++++++++++++++----------------
 3 files changed, 162 insertions(+), 134 deletions(-)

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 15931d02..9bdb6dcd 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -342,6 +342,18 @@ table.append(df)
 
 To explore the table metadata, tables can be inspected.
 
+<!-- prettier-ignore-start -->
+
+!!! tip "Time Travel"
+    To inspect a table's metadata with the time travel feature, call the 
inspect table method with the `snapshot_id` argument.
+    Time travel is supported on all metadata tables except `snapshots` and 
`refs`.
+
+    ```python
+    table.inspect.entries(snapshot_id=805611270568163028)
+    ```
+
+<!-- prettier-ignore-end -->
+
 ### Snapshots
 
 Inspect the snapshots of the table:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index ea813176..da4b1465 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -3253,6 +3253,18 @@ class InspectTable:
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError("For metadata operations PyArrow needs 
to be installed") from e
 
+    def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
+        if snapshot_id is not None:
+            if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id):
+                return snapshot
+            else:
+                raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
+
+        if snapshot := self.tbl.metadata.current_snapshot():
+            return snapshot
+        else:
+            raise ValueError("Cannot get a snapshot as the table does not have 
any.")
+
     def snapshots(self) -> "pa.Table":
         import pyarrow as pa
 
@@ -3287,7 +3299,7 @@ class InspectTable:
             schema=snapshots_schema,
         )
 
-    def entries(self) -> "pa.Table":
+    def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
         import pyarrow as pa
 
         from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3346,64 +3358,64 @@ class InspectTable:
         ])
 
         entries = []
-        if snapshot := self.tbl.metadata.current_snapshot():
-            for manifest in snapshot.manifests(self.tbl.io):
-                for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
-                    column_sizes = entry.data_file.column_sizes or {}
-                    value_counts = entry.data_file.value_counts or {}
-                    null_value_counts = entry.data_file.null_value_counts or {}
-                    nan_value_counts = entry.data_file.nan_value_counts or {}
-                    lower_bounds = entry.data_file.lower_bounds or {}
-                    upper_bounds = entry.data_file.upper_bounds or {}
-                    readable_metrics = {
-                        schema.find_column_name(field.field_id): {
-                            "column_size": column_sizes.get(field.field_id),
-                            "value_count": value_counts.get(field.field_id),
-                            "null_value_count": 
null_value_counts.get(field.field_id),
-                            "nan_value_count": 
nan_value_counts.get(field.field_id),
-                            # Makes them readable
-                            "lower_bound": from_bytes(field.field_type, 
lower_bound)
-                            if (lower_bound := 
lower_bounds.get(field.field_id))
-                            else None,
-                            "upper_bound": from_bytes(field.field_type, 
upper_bound)
-                            if (upper_bound := 
upper_bounds.get(field.field_id))
-                            else None,
-                        }
-                        for field in self.tbl.metadata.schema().fields
-                    }
-
-                    partition = entry.data_file.partition
-                    partition_record_dict = {
-                        field.name: partition[pos]
-                        for pos, field in 
enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+        snapshot = self._get_snapshot(snapshot_id)
+        for manifest in snapshot.manifests(self.tbl.io):
+            for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+                column_sizes = entry.data_file.column_sizes or {}
+                value_counts = entry.data_file.value_counts or {}
+                null_value_counts = entry.data_file.null_value_counts or {}
+                nan_value_counts = entry.data_file.nan_value_counts or {}
+                lower_bounds = entry.data_file.lower_bounds or {}
+                upper_bounds = entry.data_file.upper_bounds or {}
+                readable_metrics = {
+                    schema.find_column_name(field.field_id): {
+                        "column_size": column_sizes.get(field.field_id),
+                        "value_count": value_counts.get(field.field_id),
+                        "null_value_count": 
null_value_counts.get(field.field_id),
+                        "nan_value_count": 
nan_value_counts.get(field.field_id),
+                        # Makes them readable
+                        "lower_bound": from_bytes(field.field_type, 
lower_bound)
+                        if (lower_bound := lower_bounds.get(field.field_id))
+                        else None,
+                        "upper_bound": from_bytes(field.field_type, 
upper_bound)
+                        if (upper_bound := upper_bounds.get(field.field_id))
+                        else None,
                     }
-
-                    entries.append({
-                        'status': entry.status.value,
-                        'snapshot_id': entry.snapshot_id,
-                        'sequence_number': entry.data_sequence_number,
-                        'file_sequence_number': entry.file_sequence_number,
-                        'data_file': {
-                            "content": entry.data_file.content,
-                            "file_path": entry.data_file.file_path,
-                            "file_format": entry.data_file.file_format,
-                            "partition": partition_record_dict,
-                            "record_count": entry.data_file.record_count,
-                            "file_size_in_bytes": 
entry.data_file.file_size_in_bytes,
-                            "column_sizes": dict(entry.data_file.column_sizes),
-                            "value_counts": dict(entry.data_file.value_counts),
-                            "null_value_counts": 
dict(entry.data_file.null_value_counts),
-                            "nan_value_counts": 
entry.data_file.nan_value_counts,
-                            "lower_bounds": entry.data_file.lower_bounds,
-                            "upper_bounds": entry.data_file.upper_bounds,
-                            "key_metadata": entry.data_file.key_metadata,
-                            "split_offsets": entry.data_file.split_offsets,
-                            "equality_ids": entry.data_file.equality_ids,
-                            "sort_order_id": entry.data_file.sort_order_id,
-                            "spec_id": entry.data_file.spec_id,
-                        },
-                        'readable_metrics': readable_metrics,
-                    })
+                    for field in self.tbl.metadata.schema().fields
+                }
+
+                partition = entry.data_file.partition
+                partition_record_dict = {
+                    field.name: partition[pos]
+                    for pos, field in 
enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+                }
+
+                entries.append({
+                    'status': entry.status.value,
+                    'snapshot_id': entry.snapshot_id,
+                    'sequence_number': entry.data_sequence_number,
+                    'file_sequence_number': entry.file_sequence_number,
+                    'data_file': {
+                        "content": entry.data_file.content,
+                        "file_path": entry.data_file.file_path,
+                        "file_format": entry.data_file.file_format,
+                        "partition": partition_record_dict,
+                        "record_count": entry.data_file.record_count,
+                        "file_size_in_bytes": 
entry.data_file.file_size_in_bytes,
+                        "column_sizes": dict(entry.data_file.column_sizes),
+                        "value_counts": dict(entry.data_file.value_counts),
+                        "null_value_counts": 
dict(entry.data_file.null_value_counts),
+                        "nan_value_counts": entry.data_file.nan_value_counts,
+                        "lower_bounds": entry.data_file.lower_bounds,
+                        "upper_bounds": entry.data_file.upper_bounds,
+                        "key_metadata": entry.data_file.key_metadata,
+                        "split_offsets": entry.data_file.split_offsets,
+                        "equality_ids": entry.data_file.equality_ids,
+                        "sort_order_id": entry.data_file.sort_order_id,
+                        "spec_id": entry.data_file.spec_id,
+                    },
+                    'readable_metrics': readable_metrics,
+                })
 
         return pa.Table.from_pylist(
             entries,
diff --git a/tests/integration/test_inspect_table.py 
b/tests/integration/test_inspect_table.py
index f2515cae..7cbfc6da 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -22,7 +22,7 @@ from datetime import date, datetime
 import pyarrow as pa
 import pytest
 import pytz
-from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame, SparkSession
 
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
@@ -148,81 +148,85 @@ def test_inspect_entries(
     # Write some data
     tbl.append(arrow_table_with_null)
 
-    df = tbl.inspect.entries()
-
-    assert df.column_names == [
-        'status',
-        'snapshot_id',
-        'sequence_number',
-        'file_sequence_number',
-        'data_file',
-        'readable_metrics',
-    ]
-
-    # Make sure that they are filled properly
-    for int_column in ['status', 'snapshot_id', 'sequence_number', 
'file_sequence_number']:
-        for value in df[int_column]:
-            assert isinstance(value.as_py(), int)
-
-    for snapshot_id in df['snapshot_id']:
-        assert isinstance(snapshot_id.as_py(), int)
-
-    lhs = df.to_pandas()
-    rhs = spark.table(f"{identifier}.entries").toPandas()
-    for column in df.column_names:
-        for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
-            if column == 'data_file':
-                right = right.asDict(recursive=True)
-                for df_column in left.keys():
-                    if df_column == 'partition':
-                        # Spark leaves out the partition if the table is 
unpartitioned
-                        continue
-
-                    df_lhs = left[df_column]
-                    df_rhs = right[df_column]
-                    if isinstance(df_rhs, dict):
-                        # Arrow turns dicts into lists of tuple
-                        df_lhs = dict(df_lhs)
-
-                    assert df_lhs == df_rhs, f"Difference in data_file column 
{df_column}: {df_lhs} != {df_rhs}"
-            elif column == 'readable_metrics':
-                right = right.asDict(recursive=True)
-
-                assert list(left.keys()) == [
-                    'bool',
-                    'string',
-                    'string_long',
-                    'int',
-                    'long',
-                    'float',
-                    'double',
-                    'timestamp',
-                    'timestamptz',
-                    'date',
-                    'binary',
-                    'fixed',
-                ]
-
-                assert left.keys() == right.keys()
-
-                for rm_column in left.keys():
-                    rm_lhs = left[rm_column]
-                    rm_rhs = right[rm_column]
-
-                    assert rm_lhs['column_size'] == rm_rhs['column_size']
-                    assert rm_lhs['value_count'] == rm_rhs['value_count']
-                    assert rm_lhs['null_value_count'] == 
rm_rhs['null_value_count']
-                    assert rm_lhs['nan_value_count'] == 
rm_rhs['nan_value_count']
-
-                    if rm_column == 'timestamptz':
-                        # PySpark does not correctly set the timstamptz
-                        rm_rhs['lower_bound'] = 
rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
-                        rm_rhs['upper_bound'] = 
rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
-
-                    assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
-                    assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
-            else:
-                assert left == right, f"Difference in column {column}: {left} 
!= {right}"
+    def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) 
-> None:
+        assert df.column_names == [
+            'status',
+            'snapshot_id',
+            'sequence_number',
+            'file_sequence_number',
+            'data_file',
+            'readable_metrics',
+        ]
+
+        # Make sure that they are filled properly
+        for int_column in ['status', 'snapshot_id', 'sequence_number', 
'file_sequence_number']:
+            for value in df[int_column]:
+                assert isinstance(value.as_py(), int)
+
+        for snapshot_id in df['snapshot_id']:
+            assert isinstance(snapshot_id.as_py(), int)
+
+        lhs = df.to_pandas()
+        rhs = spark_df.toPandas()
+        for column in df.column_names:
+            for left, right in zip(lhs[column].to_list(), 
rhs[column].to_list()):
+                if column == 'data_file':
+                    right = right.asDict(recursive=True)
+                    for df_column in left.keys():
+                        if df_column == 'partition':
+                            # Spark leaves out the partition if the table is 
unpartitioned
+                            continue
+
+                        df_lhs = left[df_column]
+                        df_rhs = right[df_column]
+                        if isinstance(df_rhs, dict):
+                            # Arrow turns dicts into lists of tuple
+                            df_lhs = dict(df_lhs)
+
+                        assert df_lhs == df_rhs, f"Difference in data_file 
column {df_column}: {df_lhs} != {df_rhs}"
+                elif column == 'readable_metrics':
+                    right = right.asDict(recursive=True)
+
+                    assert list(left.keys()) == [
+                        'bool',
+                        'string',
+                        'string_long',
+                        'int',
+                        'long',
+                        'float',
+                        'double',
+                        'timestamp',
+                        'timestamptz',
+                        'date',
+                        'binary',
+                        'fixed',
+                    ]
+
+                    assert left.keys() == right.keys()
+
+                    for rm_column in left.keys():
+                        rm_lhs = left[rm_column]
+                        rm_rhs = right[rm_column]
+
+                        assert rm_lhs['column_size'] == rm_rhs['column_size']
+                        assert rm_lhs['value_count'] == rm_rhs['value_count']
+                        assert rm_lhs['null_value_count'] == 
rm_rhs['null_value_count']
+                        assert rm_lhs['nan_value_count'] == 
rm_rhs['nan_value_count']
+
+                        if rm_column == 'timestamptz':
+                            # PySpark does not correctly set the timestamptz
+                            rm_rhs['lower_bound'] = 
rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
+                            rm_rhs['upper_bound'] = 
rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
+
+                        assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
+                        assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
+                else:
+                    assert left == right, f"Difference in column {column}: 
{left} != {right}"
+
+    for snapshot in tbl.metadata.snapshots:
+        df = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
+        spark_df = spark.sql(f"SELECT * FROM {identifier}.entries VERSION AS 
OF {snapshot.snapshot_id}")
+        check_pyiceberg_df_equals_spark_df(df, spark_df)
 
 
 @pytest.mark.integration

Reply via email to