This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new d69407ce Move writes to the transaction class (#571)
d69407ce is described below

commit d69407ceed545e72c24643e2abcd2d6ec335c26c
Author: Sung Yun <[email protected]>
AuthorDate: Thu Apr 4 03:18:05 2024 -0400

    Move writes to the transaction class (#571)
---
 pyiceberg/table/__init__.py      | 160 ++++++++++++++++++++++++---------------
 tests/integration/test_writes.py |  28 +++++++
 2 files changed, 127 insertions(+), 61 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 4e968eb6..0f113f3b 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -356,6 +356,100 @@ class Transaction:
         """
         return UpdateSnapshot(self, io=self._table.io, 
snapshot_properties=snapshot_properties)
 
+    def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
+        """
+        Shorthand API for appending a PyArrow table to a table transaction.
+
+        Args:
+            df: The Arrow dataframe that will be appended to overwrite the 
table
+            snapshot_properties: Custom properties to be added to the snapshot 
summary
+        """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
+        if len(self._table.spec().fields) > 0:
+            raise ValueError("Cannot write to partitioned tables")
+
+        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        table_arrow_schema = self._table.schema().as_arrow()
+        if table_arrow_schema != df.schema:
+            df = df.cast(table_arrow_schema)
+
+        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(
+                    table_metadata=self._table.metadata, 
write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+                )
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)
+
+    def overwrite(
+        self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT
+    ) -> None:
+        """
+        Shorthand for adding a table overwrite with a PyArrow table to the 
transaction.
+
+        Args:
+            df: The Arrow dataframe that will be used to overwrite the table
+            overwrite_filter: ALWAYS_TRUE when you overwrite all the data,
+                              or a boolean expression in case of a partial 
overwrite
+            snapshot_properties: Custom properties to be added to the snapshot 
summary
+        """
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e
+
+        if not isinstance(df, pa.Table):
+            raise ValueError(f"Expected PyArrow table, got: {df}")
+
+        if overwrite_filter != AlwaysTrue():
+            raise NotImplementedError("Cannot overwrite a subset of a table")
+
+        if len(self._table.spec().fields) > 0:
+            raise ValueError("Cannot write to partitioned tables")
+
+        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        table_arrow_schema = self._table.schema().as_arrow()
+        if table_arrow_schema != df.schema:
+            df = df.cast(table_arrow_schema)
+
+        with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(
+                    table_metadata=self._table.metadata, 
write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+                )
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)
+
+    def add_files(self, file_paths: List[str]) -> None:
+        """
+        Shorthand API for adding files as data files to the table transaction.
+
+        Args:
+            file_paths: The list of full file paths to be added as data files 
to the table
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        if self._table.name_mapping() is None:
+            self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
+        with self.update_snapshot().fast_append() as update_snapshot:
+            data_files = _parquet_files_to_data_files(
+                table_metadata=self._table.metadata, file_paths=file_paths, 
io=self._table.io
+            )
+            for data_file in data_files:
+                update_snapshot.append_data_file(data_file)
+
     def update_spec(self) -> UpdateSpec:
         """Create a new UpdateSpec to update the partitioning of the table.
 
@@ -1219,32 +1313,8 @@ class Table:
             df: The Arrow dataframe that will be appended to overwrite the 
table
             snapshot_properties: Custom properties to be added to the snapshot 
summary
         """
-        try:
-            import pyarrow as pa
-        except ModuleNotFoundError as e:
-            raise ModuleNotFoundError("For writes PyArrow needs to be 
installed") from e
-
-        if not isinstance(df, pa.Table):
-            raise ValueError(f"Expected PyArrow table, got: {df}")
-
-        if len(self.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
-
-        _check_schema_compatible(self.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
-
-        with self.transaction() as txn:
-            with 
txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as 
update_snapshot:
-                # skip writing data files if the dataframe is empty
-                if df.shape[0] > 0:
-                    data_files = _dataframe_to_data_files(
-                        table_metadata=self.metadata, 
write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
-                    )
-                    for data_file in data_files:
-                        update_snapshot.append_data_file(data_file)
+        with self.transaction() as tx:
+            tx.append(df=df, snapshot_properties=snapshot_properties)
 
     def overwrite(
         self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE, 
snapshot_properties: Dict[str, str] = EMPTY_DICT
@@ -1258,35 +1328,8 @@ class Table:
                               or a boolean expression in case of a partial 
overwrite
             snapshot_properties: Custom properties to be added to the snapshot 
summary
         """
-        try:
-            import pyarrow as pa
-        except ModuleNotFoundError as e:
-            raise ModuleNotFoundError("For writes PyArrow needs to be 
installed") from e
-
-        if not isinstance(df, pa.Table):
-            raise ValueError(f"Expected PyArrow table, got: {df}")
-
-        if overwrite_filter != AlwaysTrue():
-            raise NotImplementedError("Cannot overwrite a subset of a table")
-
-        if len(self.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
-
-        _check_schema_compatible(self.schema(), other_schema=df.schema)
-        # cast if the two schemas are compatible but not equal
-        table_arrow_schema = self.schema().as_arrow()
-        if table_arrow_schema != df.schema:
-            df = df.cast(table_arrow_schema)
-
-        with self.transaction() as txn:
-            with 
txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as 
update_snapshot:
-                # skip writing data files if the dataframe is empty
-                if df.shape[0] > 0:
-                    data_files = _dataframe_to_data_files(
-                        table_metadata=self.metadata, 
write_uuid=update_snapshot.commit_uuid, df=df, io=self.io
-                    )
-                    for data_file in data_files:
-                        update_snapshot.append_data_file(data_file)
+        with self.transaction() as tx:
+            tx.overwrite(df=df, overwrite_filter=overwrite_filter, 
snapshot_properties=snapshot_properties)
 
     def add_files(self, file_paths: List[str]) -> None:
         """
@@ -1299,12 +1342,7 @@ class Table:
             FileNotFoundError: If the file does not exist.
         """
         with self.transaction() as tx:
-            if self.name_mapping() is None:
-                tx.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: 
self.schema().name_mapping.model_dump_json()})
-            with tx.update_snapshot().fast_append() as update_snapshot:
-                data_files = 
_parquet_files_to_data_files(table_metadata=self.metadata, 
file_paths=file_paths, io=self.io)
-                for data_file in data_files:
-                    update_snapshot.append_data_file(data_file)
+            tx.add_files(file_paths=file_paths)
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
         return UpdateSpec(Transaction(self, autocommit=True), 
case_sensitive=case_sensitive)
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index e8ad6b08..77567023 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -832,3 +832,31 @@ def test_inspect_snapshots(
                 continue
 
             assert left == right, f"Difference in column {column}: {left} != {right}"
+
+
[email protected]
+def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
+    identifier = "default.write_in_open_transaction"
+    tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [])
+
+    def get_metadata_entries_count(identifier: str) -> int:
+        return spark.sql(
+            f"""
+            SELECT *
+            FROM {identifier}.metadata_log_entries
+        """
+        ).count()
+
+    # one metadata entry from table creation
+    assert get_metadata_entries_count(identifier) == 1
+
+    # one more metadata entry from transaction
+    with tbl.transaction() as tx:
+        tx.set_properties({"test": "1"})
+        tx.append(arrow_table_with_null)
+    assert get_metadata_entries_count(identifier) == 2
+
+    # two more metadata entries added from two separate transactions
+    tbl.transaction().set_properties({"test": "2"}).commit_transaction()
+    tbl.append(arrow_table_with_null)
+    assert get_metadata_entries_count(identifier) == 4

Reply via email to