This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 990ce80e Make `add_files` to support `snapshot_properties` argument 
(#695)
990ce80e is described below

commit 990ce80ed937fa1db092e1aac2b0e87aecf34d84
Author: Maksym Shalenyi <[email protected]>
AuthorDate: Tue May 7 09:46:02 2024 -0700

    Make `add_files` to support `snapshot_properties` argument (#695)
---
 pyiceberg/table/__init__.py         |  8 ++++----
 tests/integration/test_add_files.py | 40 +++++++++++++++++++++++++++++--------
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 13186c42..5b7d04b5 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -443,7 +443,7 @@ class Transaction:
                 for data_file in data_files:
                     update_snapshot.append_data_file(data_file)
 
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, 
str] = EMPTY_DICT) -> None:
         """
         Shorthand API for adding files as data files to the table transaction.
 
@@ -455,7 +455,7 @@ class Transaction:
         """
         if self._table.name_mapping() is None:
             self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: 
self._table.schema().name_mapping.model_dump_json()})
-        with self.update_snapshot().fast_append() as update_snapshot:
+        with 
self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as 
update_snapshot:
             data_files = _parquet_files_to_data_files(
                 table_metadata=self._table.metadata, file_paths=file_paths, 
io=self._table.io
             )
@@ -1341,7 +1341,7 @@ class Table:
         with self.transaction() as tx:
             tx.overwrite(df=df, overwrite_filter=overwrite_filter, 
snapshot_properties=snapshot_properties)
 
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, 
str] = EMPTY_DICT) -> None:
         """
         Shorthand API for adding files as data files to the table.
 
@@ -1352,7 +1352,7 @@ class Table:
             FileNotFoundError: If the file does not exist.
         """
         with self.transaction() as tx:
-            tx.add_files(file_paths=file_paths)
+            tx.add_files(file_paths=file_paths, 
snapshot_properties=snapshot_properties)
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
         return UpdateSpec(Transaction(self, autocommit=True), 
case_sensitive=case_sensitive)
diff --git a/tests/integration/test_add_files.py 
b/tests/integration/test_add_files.py
index 0de5d5f4..94c73918 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -17,7 +17,7 @@
 # pylint:disable=redefined-outer-name
 
 from datetime import date
-from typing import Optional
+from typing import Iterator, Optional
 
 import pyarrow as pa
 import pyarrow.parquet as pq
@@ -122,8 +122,13 @@ def _create_table(
     return tbl
 
 
+@pytest.fixture(name="format_version", params=[pytest.param(1, 
id="format_version=1"), pytest.param(2, id="format_version=2")])
+def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
+    """Fixture to run tests with different table format versions."""
+    yield request.param
+
+
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table(spark: SparkSession, 
session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}"
     tbl = _create_table(session_catalog, identifier, format_version)
@@ -163,7 +168,6 @@ def test_add_files_to_unpartitioned_table(spark: 
SparkSession, session_catalog:
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_raises_file_not_found(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -184,7 +188,6 @@ def 
test_add_files_to_unpartitioned_table_raises_file_not_found(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_raises_has_field_ids(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -205,7 +208,6 @@ def 
test_add_files_to_unpartitioned_table_raises_has_field_ids(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_with_schema_updates(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -263,7 +265,6 @@ def 
test_add_files_to_unpartitioned_table_with_schema_updates(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: 
Catalog, format_version: int) -> None:
     identifier = f"default.partitioned_table_v{format_version}"
 
@@ -335,7 +336,6 @@ def test_add_files_to_partitioned_table(spark: 
SparkSession, session_catalog: Ca
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, 
session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.partitioned_table_bucket_fails_v{format_version}"
 
@@ -378,7 +378,6 @@ def test_add_files_to_bucket_partitioned_table_fails(spark: 
SparkSession, sessio
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -424,3 +423,28 @@ def 
test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
         "Cannot infer partition value from parquet metadata as there are more 
than one partition values for Partition Field: baz. lower_value=123, 
upper_value=124"
         in str(exc_info.value)
     )
+
+
+@pytest.mark.integration
+def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: 
Catalog, format_version: int) -> None:
+    identifier = f"default.unpartitioned_table_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_paths = 
[f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for 
i in range(5)]
+    # write parquet files
+    for file_path in file_paths:
+        fo = tbl.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                writer.write_table(ARROW_TABLE)
+
+    # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths, 
snapshot_properties={"snapshot_prop_a": "test_prop_a"})
+
+    # NameMapping must have been set to enable reads
+    assert tbl.name_mapping() is not None
+
+    summary = spark.sql(f"SELECT * FROM 
{identifier}.snapshots;").collect()[0].summary
+
+    assert "snapshot_prop_a" in summary
+    assert summary["snapshot_prop_a"] == "test_prop_a"

Reply via email to