This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 4b911057 Support Time Travel in InspectTable.entries (#599)
4b911057 is described below
commit 4b911057f13491f30f89f133544c063133133fa5
Author: Sung Yun <[email protected]>
AuthorDate: Sat Apr 13 15:07:51 2024 -0400
Support Time Travel in InspectTable.entries (#599)
* time travel in entries table
* undo
* Update pyiceberg/table/__init__.py
Co-authored-by: Fokko Driesprong <[email protected]>
* adopt review feedback
* docs
---------
Co-authored-by: Fokko Driesprong <[email protected]>
---
mkdocs/docs/api.md | 12 +++
pyiceberg/table/__init__.py | 128 ++++++++++++++------------
tests/integration/test_inspect_table.py | 156 ++++++++++++++++----------------
3 files changed, 162 insertions(+), 134 deletions(-)
diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 15931d02..9bdb6dcd 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -342,6 +342,18 @@ table.append(df)
To explore the table metadata, tables can be inspected.
+<!-- prettier-ignore-start -->
+
+!!! tip "Time Travel"
+ To inspect a table's metadata with the time travel feature, call the
inspect table method with the `snapshot_id` argument.
+ Time travel is supported on all metadata tables except `snapshots` and
`refs`.
+
+ ```python
+ table.inspect.entries(snapshot_id=805611270568163028)
+ ```
+
+<!-- prettier-ignore-end -->
+
### Snapshots
Inspect the snapshots of the table:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index ea813176..da4b1465 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -3253,6 +3253,18 @@ class InspectTable:
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For metadata operations PyArrow needs
to be installed") from e
+ def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot:
+ if snapshot_id is not None:
+ if snapshot := self.tbl.metadata.snapshot_by_id(snapshot_id):
+ return snapshot
+ else:
+ raise ValueError(f"Cannot find snapshot with ID {snapshot_id}")
+
+ if snapshot := self.tbl.metadata.current_snapshot():
+ return snapshot
+ else:
+ raise ValueError("Cannot get a snapshot as the table does not have
any.")
+
def snapshots(self) -> "pa.Table":
import pyarrow as pa
@@ -3287,7 +3299,7 @@ class InspectTable:
schema=snapshots_schema,
)
- def entries(self) -> "pa.Table":
+ def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
import pyarrow as pa
from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -3346,64 +3358,64 @@ class InspectTable:
])
entries = []
- if snapshot := self.tbl.metadata.current_snapshot():
- for manifest in snapshot.manifests(self.tbl.io):
- for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
- column_sizes = entry.data_file.column_sizes or {}
- value_counts = entry.data_file.value_counts or {}
- null_value_counts = entry.data_file.null_value_counts or {}
- nan_value_counts = entry.data_file.nan_value_counts or {}
- lower_bounds = entry.data_file.lower_bounds or {}
- upper_bounds = entry.data_file.upper_bounds or {}
- readable_metrics = {
- schema.find_column_name(field.field_id): {
- "column_size": column_sizes.get(field.field_id),
- "value_count": value_counts.get(field.field_id),
- "null_value_count":
null_value_counts.get(field.field_id),
- "nan_value_count":
nan_value_counts.get(field.field_id),
- # Makes them readable
- "lower_bound": from_bytes(field.field_type,
lower_bound)
- if (lower_bound :=
lower_bounds.get(field.field_id))
- else None,
- "upper_bound": from_bytes(field.field_type,
upper_bound)
- if (upper_bound :=
upper_bounds.get(field.field_id))
- else None,
- }
- for field in self.tbl.metadata.schema().fields
- }
-
- partition = entry.data_file.partition
- partition_record_dict = {
- field.name: partition[pos]
- for pos, field in
enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+ snapshot = self._get_snapshot(snapshot_id)
+ for manifest in snapshot.manifests(self.tbl.io):
+ for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
+ column_sizes = entry.data_file.column_sizes or {}
+ value_counts = entry.data_file.value_counts or {}
+ null_value_counts = entry.data_file.null_value_counts or {}
+ nan_value_counts = entry.data_file.nan_value_counts or {}
+ lower_bounds = entry.data_file.lower_bounds or {}
+ upper_bounds = entry.data_file.upper_bounds or {}
+ readable_metrics = {
+ schema.find_column_name(field.field_id): {
+ "column_size": column_sizes.get(field.field_id),
+ "value_count": value_counts.get(field.field_id),
+ "null_value_count":
null_value_counts.get(field.field_id),
+ "nan_value_count":
nan_value_counts.get(field.field_id),
+ # Makes them readable
+ "lower_bound": from_bytes(field.field_type,
lower_bound)
+ if (lower_bound := lower_bounds.get(field.field_id))
+ else None,
+ "upper_bound": from_bytes(field.field_type,
upper_bound)
+ if (upper_bound := upper_bounds.get(field.field_id))
+ else None,
}
-
- entries.append({
- 'status': entry.status.value,
- 'snapshot_id': entry.snapshot_id,
- 'sequence_number': entry.data_sequence_number,
- 'file_sequence_number': entry.file_sequence_number,
- 'data_file': {
- "content": entry.data_file.content,
- "file_path": entry.data_file.file_path,
- "file_format": entry.data_file.file_format,
- "partition": partition_record_dict,
- "record_count": entry.data_file.record_count,
- "file_size_in_bytes":
entry.data_file.file_size_in_bytes,
- "column_sizes": dict(entry.data_file.column_sizes),
- "value_counts": dict(entry.data_file.value_counts),
- "null_value_counts":
dict(entry.data_file.null_value_counts),
- "nan_value_counts":
entry.data_file.nan_value_counts,
- "lower_bounds": entry.data_file.lower_bounds,
- "upper_bounds": entry.data_file.upper_bounds,
- "key_metadata": entry.data_file.key_metadata,
- "split_offsets": entry.data_file.split_offsets,
- "equality_ids": entry.data_file.equality_ids,
- "sort_order_id": entry.data_file.sort_order_id,
- "spec_id": entry.data_file.spec_id,
- },
- 'readable_metrics': readable_metrics,
- })
+ for field in self.tbl.metadata.schema().fields
+ }
+
+ partition = entry.data_file.partition
+ partition_record_dict = {
+ field.name: partition[pos]
+ for pos, field in
enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
+ }
+
+ entries.append({
+ 'status': entry.status.value,
+ 'snapshot_id': entry.snapshot_id,
+ 'sequence_number': entry.data_sequence_number,
+ 'file_sequence_number': entry.file_sequence_number,
+ 'data_file': {
+ "content": entry.data_file.content,
+ "file_path": entry.data_file.file_path,
+ "file_format": entry.data_file.file_format,
+ "partition": partition_record_dict,
+ "record_count": entry.data_file.record_count,
+ "file_size_in_bytes":
entry.data_file.file_size_in_bytes,
+ "column_sizes": dict(entry.data_file.column_sizes),
+ "value_counts": dict(entry.data_file.value_counts),
+ "null_value_counts":
dict(entry.data_file.null_value_counts),
+ "nan_value_counts": entry.data_file.nan_value_counts,
+ "lower_bounds": entry.data_file.lower_bounds,
+ "upper_bounds": entry.data_file.upper_bounds,
+ "key_metadata": entry.data_file.key_metadata,
+ "split_offsets": entry.data_file.split_offsets,
+ "equality_ids": entry.data_file.equality_ids,
+ "sort_order_id": entry.data_file.sort_order_id,
+ "spec_id": entry.data_file.spec_id,
+ },
+ 'readable_metrics': readable_metrics,
+ })
return pa.Table.from_pylist(
entries,
diff --git a/tests/integration/test_inspect_table.py
b/tests/integration/test_inspect_table.py
index f2515cae..7cbfc6da 100644
--- a/tests/integration/test_inspect_table.py
+++ b/tests/integration/test_inspect_table.py
@@ -22,7 +22,7 @@ from datetime import date, datetime
import pyarrow as pa
import pytest
import pytz
-from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame, SparkSession
from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
@@ -148,81 +148,85 @@ def test_inspect_entries(
# Write some data
tbl.append(arrow_table_with_null)
- df = tbl.inspect.entries()
-
- assert df.column_names == [
- 'status',
- 'snapshot_id',
- 'sequence_number',
- 'file_sequence_number',
- 'data_file',
- 'readable_metrics',
- ]
-
- # Make sure that they are filled properly
- for int_column in ['status', 'snapshot_id', 'sequence_number',
'file_sequence_number']:
- for value in df[int_column]:
- assert isinstance(value.as_py(), int)
-
- for snapshot_id in df['snapshot_id']:
- assert isinstance(snapshot_id.as_py(), int)
-
- lhs = df.to_pandas()
- rhs = spark.table(f"{identifier}.entries").toPandas()
- for column in df.column_names:
- for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
- if column == 'data_file':
- right = right.asDict(recursive=True)
- for df_column in left.keys():
- if df_column == 'partition':
- # Spark leaves out the partition if the table is
unpartitioned
- continue
-
- df_lhs = left[df_column]
- df_rhs = right[df_column]
- if isinstance(df_rhs, dict):
- # Arrow turns dicts into lists of tuple
- df_lhs = dict(df_lhs)
-
- assert df_lhs == df_rhs, f"Difference in data_file column
{df_column}: {df_lhs} != {df_rhs}"
- elif column == 'readable_metrics':
- right = right.asDict(recursive=True)
-
- assert list(left.keys()) == [
- 'bool',
- 'string',
- 'string_long',
- 'int',
- 'long',
- 'float',
- 'double',
- 'timestamp',
- 'timestamptz',
- 'date',
- 'binary',
- 'fixed',
- ]
-
- assert left.keys() == right.keys()
-
- for rm_column in left.keys():
- rm_lhs = left[rm_column]
- rm_rhs = right[rm_column]
-
- assert rm_lhs['column_size'] == rm_rhs['column_size']
- assert rm_lhs['value_count'] == rm_rhs['value_count']
- assert rm_lhs['null_value_count'] ==
rm_rhs['null_value_count']
- assert rm_lhs['nan_value_count'] ==
rm_rhs['nan_value_count']
-
- if rm_column == 'timestamptz':
- # PySpark does not correctly set the timstamptz
- rm_rhs['lower_bound'] =
rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
- rm_rhs['upper_bound'] =
rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
-
- assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
- assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
- else:
- assert left == right, f"Difference in column {column}: {left}
!= {right}"
+ def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame)
-> None:
+ assert df.column_names == [
+ 'status',
+ 'snapshot_id',
+ 'sequence_number',
+ 'file_sequence_number',
+ 'data_file',
+ 'readable_metrics',
+ ]
+
+ # Make sure that they are filled properly
+ for int_column in ['status', 'snapshot_id', 'sequence_number',
'file_sequence_number']:
+ for value in df[int_column]:
+ assert isinstance(value.as_py(), int)
+
+ for snapshot_id in df['snapshot_id']:
+ assert isinstance(snapshot_id.as_py(), int)
+
+ lhs = df.to_pandas()
+ rhs = spark_df.toPandas()
+ for column in df.column_names:
+ for left, right in zip(lhs[column].to_list(),
rhs[column].to_list()):
+ if column == 'data_file':
+ right = right.asDict(recursive=True)
+ for df_column in left.keys():
+ if df_column == 'partition':
+ # Spark leaves out the partition if the table is
unpartitioned
+ continue
+
+ df_lhs = left[df_column]
+ df_rhs = right[df_column]
+ if isinstance(df_rhs, dict):
+ # Arrow turns dicts into lists of tuple
+ df_lhs = dict(df_lhs)
+
+ assert df_lhs == df_rhs, f"Difference in data_file
column {df_column}: {df_lhs} != {df_rhs}"
+ elif column == 'readable_metrics':
+ right = right.asDict(recursive=True)
+
+ assert list(left.keys()) == [
+ 'bool',
+ 'string',
+ 'string_long',
+ 'int',
+ 'long',
+ 'float',
+ 'double',
+ 'timestamp',
+ 'timestamptz',
+ 'date',
+ 'binary',
+ 'fixed',
+ ]
+
+ assert left.keys() == right.keys()
+
+ for rm_column in left.keys():
+ rm_lhs = left[rm_column]
+ rm_rhs = right[rm_column]
+
+ assert rm_lhs['column_size'] == rm_rhs['column_size']
+ assert rm_lhs['value_count'] == rm_rhs['value_count']
+ assert rm_lhs['null_value_count'] ==
rm_rhs['null_value_count']
+ assert rm_lhs['nan_value_count'] ==
rm_rhs['nan_value_count']
+
+ if rm_column == 'timestamptz':
+ # PySpark does not correctly set the timestamptz
+ rm_rhs['lower_bound'] =
rm_rhs['lower_bound'].replace(tzinfo=pytz.utc)
+ rm_rhs['upper_bound'] =
rm_rhs['upper_bound'].replace(tzinfo=pytz.utc)
+
+ assert rm_lhs['lower_bound'] == rm_rhs['lower_bound']
+ assert rm_lhs['upper_bound'] == rm_rhs['upper_bound']
+ else:
+ assert left == right, f"Difference in column {column}:
{left} != {right}"
+
+ for snapshot in tbl.metadata.snapshots:
+ df = tbl.inspect.entries(snapshot_id=snapshot.snapshot_id)
+ spark_df = spark.sql(f"SELECT * FROM {identifier}.entries VERSION AS
OF {snapshot.snapshot_id}")
+ check_pyiceberg_df_equals_spark_df(df, spark_df)
@pytest.mark.integration