This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new d69407ce Move writes to the transaction class (#571)
d69407ce is described below
commit d69407ceed545e72c24643e2abcd2d6ec335c26c
Author: Sung Yun <[email protected]>
AuthorDate: Thu Apr 4 03:18:05 2024 -0400
Move writes to the transaction class (#571)
---
pyiceberg/table/__init__.py | 160 ++++++++++++++++++++++++---------------
tests/integration/test_writes.py | 28 +++++++
2 files changed, 127 insertions(+), 61 deletions(-)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 4e968eb6..0f113f3b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -356,6 +356,100 @@ class Transaction:
"""
return UpdateSnapshot(self, io=self._table.io,
snapshot_properties=snapshot_properties)
+ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] =
EMPTY_DICT) -> None:
+ """
+ Shorthand API for appending a PyArrow table to a table transaction.
+
+ Args:
+ df: The Arrow dataframe that will be appended to overwrite the
table
+ snapshot_properties: Custom properties to be added to the snapshot
summary
+ """
+ try:
+ import pyarrow as pa
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("For writes PyArrow needs to be
installed") from e
+
+ if not isinstance(df, pa.Table):
+ raise ValueError(f"Expected PyArrow table, got: {df}")
+
+ if len(self._table.spec().fields) > 0:
+ raise ValueError("Cannot write to partitioned tables")
+
+ _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+ # cast if the two schemas are compatible but not equal
+ table_arrow_schema = self._table.schema().as_arrow()
+ if table_arrow_schema != df.schema:
+ df = df.cast(table_arrow_schema)
+
+ with
self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as
update_snapshot:
+ # skip writing data files if the dataframe is empty
+ if df.shape[0] > 0:
+ data_files = _dataframe_to_data_files(
+ table_metadata=self._table.metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+ )
+ for data_file in data_files:
+ update_snapshot.append_data_file(data_file)
+
+ def overwrite(
+ self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT
+ ) -> None:
+ """
+ Shorthand for adding a table overwrite with a PyArrow table to the
transaction.
+
+ Args:
+ df: The Arrow dataframe that will be used to overwrite the table
+ overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
+ or a boolean expression in case of a partial
overwrite
+ snapshot_properties: Custom properties to be added to the snapshot
summary
+ """
+ try:
+ import pyarrow as pa
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError("For writes PyArrow needs to be
installed") from e
+
+ if not isinstance(df, pa.Table):
+ raise ValueError(f"Expected PyArrow table, got: {df}")
+
+ if overwrite_filter != AlwaysTrue():
+ raise NotImplementedError("Cannot overwrite a subset of a table")
+
+ if len(self._table.spec().fields) > 0:
+ raise ValueError("Cannot write to partitioned tables")
+
+ _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+ # cast if the two schemas are compatible but not equal
+ table_arrow_schema = self._table.schema().as_arrow()
+ if table_arrow_schema != df.schema:
+ df = df.cast(table_arrow_schema)
+
+ with
self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as
update_snapshot:
+ # skip writing data files if the dataframe is empty
+ if df.shape[0] > 0:
+ data_files = _dataframe_to_data_files(
+ table_metadata=self._table.metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+ )
+ for data_file in data_files:
+ update_snapshot.append_data_file(data_file)
+
+ def add_files(self, file_paths: List[str]) -> None:
+ """
+ Shorthand API for adding files as data files to the table transaction.
+
+ Args:
+ file_paths: The list of full file paths to be added as data files
to the table
+
+ Raises:
+ FileNotFoundError: If the file does not exist.
+ """
+ if self._table.name_mapping() is None:
+ self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING:
self._table.schema().name_mapping.model_dump_json()})
+ with self.update_snapshot().fast_append() as update_snapshot:
+ data_files = _parquet_files_to_data_files(
+ table_metadata=self._table.metadata, file_paths=file_paths,
io=self._table.io
+ )
+ for data_file in data_files:
+ update_snapshot.append_data_file(data_file)
+
def update_spec(self) -> UpdateSpec:
"""Create a new UpdateSpec to update the partitioning of the table.
@@ -1219,32 +1313,8 @@ class Table:
df: The Arrow dataframe that will be appended to overwrite the
table
snapshot_properties: Custom properties to be added to the snapshot
summary
"""
- try:
- import pyarrow as pa
- except ModuleNotFoundError as e:
- raise ModuleNotFoundError("For writes PyArrow needs to be
installed") from e
-
- if not isinstance(df, pa.Table):
- raise ValueError(f"Expected PyArrow table, got: {df}")
-
- if len(self.spec().fields) > 0:
- raise ValueError("Cannot write to partitioned tables")
-
- _check_schema_compatible(self.schema(), other_schema=df.schema)
- # cast if the two schemas are compatible but not equal
- table_arrow_schema = self.schema().as_arrow()
- if table_arrow_schema != df.schema:
- df = df.cast(table_arrow_schema)
-
- with self.transaction() as txn:
- with
txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as
update_snapshot:
- # skip writing data files if the dataframe is empty
- if df.shape[0] > 0:
- data_files = _dataframe_to_data_files(
- table_metadata=self.metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
- )
- for data_file in data_files:
- update_snapshot.append_data_file(data_file)
+ with self.transaction() as tx:
+ tx.append(df=df, snapshot_properties=snapshot_properties)
def overwrite(
self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE,
snapshot_properties: Dict[str, str] = EMPTY_DICT
@@ -1258,35 +1328,8 @@ class Table:
or a boolean expression in case of a partial
overwrite
snapshot_properties: Custom properties to be added to the snapshot
summary
"""
- try:
- import pyarrow as pa
- except ModuleNotFoundError as e:
- raise ModuleNotFoundError("For writes PyArrow needs to be
installed") from e
-
- if not isinstance(df, pa.Table):
- raise ValueError(f"Expected PyArrow table, got: {df}")
-
- if overwrite_filter != AlwaysTrue():
- raise NotImplementedError("Cannot overwrite a subset of a table")
-
- if len(self.spec().fields) > 0:
- raise ValueError("Cannot write to partitioned tables")
-
- _check_schema_compatible(self.schema(), other_schema=df.schema)
- # cast if the two schemas are compatible but not equal
- table_arrow_schema = self.schema().as_arrow()
- if table_arrow_schema != df.schema:
- df = df.cast(table_arrow_schema)
-
- with self.transaction() as txn:
- with
txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as
update_snapshot:
- # skip writing data files if the dataframe is empty
- if df.shape[0] > 0:
- data_files = _dataframe_to_data_files(
- table_metadata=self.metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
- )
- for data_file in data_files:
- update_snapshot.append_data_file(data_file)
+ with self.transaction() as tx:
+ tx.overwrite(df=df, overwrite_filter=overwrite_filter,
snapshot_properties=snapshot_properties)
def add_files(self, file_paths: List[str]) -> None:
"""
@@ -1299,12 +1342,7 @@ class Table:
FileNotFoundError: If the file does not exist.
"""
with self.transaction() as tx:
- if self.name_mapping() is None:
- tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING:
self.schema().name_mapping.model_dump_json()})
- with tx.update_snapshot().fast_append() as update_snapshot:
- data_files =
_parquet_files_to_data_files(table_metadata=self.metadata,
file_paths=file_paths, io=self.io)
- for data_file in data_files:
- update_snapshot.append_data_file(data_file)
+ tx.add_files(file_paths=file_paths)
def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
return UpdateSpec(Transaction(self, autocommit=True),
case_sensitive=case_sensitive)
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index e8ad6b08..77567023 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -832,3 +832,31 @@ def test_inspect_snapshots(
continue
assert left == right, f"Difference in column {column}: {left} !=
{right}"
+
+
+@pytest.mark.integration
+def test_write_within_transaction(spark: SparkSession, session_catalog:
Catalog, arrow_table_with_null: pa.Table) -> None:
+ identifier = "default.write_in_open_transaction"
+ tbl = _create_table(session_catalog, identifier, {"format-version": "1"},
[])
+
+ def get_metadata_entries_count(identifier: str) -> int:
+ return spark.sql(
+ f"""
+ SELECT *
+ FROM {identifier}.metadata_log_entries
+ """
+ ).count()
+
+ # one metadata entry from table creation
+ assert get_metadata_entries_count(identifier) == 1
+
+ # one more metadata entry from transaction
+ with tbl.transaction() as tx:
+ tx.set_properties({"test": "1"})
+ tx.append(arrow_table_with_null)
+ assert get_metadata_entries_count(identifier) == 2
+
+ # two more metadata entries added from two separate transactions
+ tbl.transaction().set_properties({"test": "2"}).commit_transaction()
+ tbl.append(arrow_table_with_null)
+ assert get_metadata_entries_count(identifier) == 4