This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 990ce80e Make `add_files` to support `snapshot_properties` argument
(#695)
990ce80e is described below
commit 990ce80ed937fa1db092e1aac2b0e87aecf34d84
Author: Maksym Shalenyi <[email protected]>
AuthorDate: Tue May 7 09:46:02 2024 -0700
Make `add_files` to support `snapshot_properties` argument (#695)
---
pyiceberg/table/__init__.py | 8 ++++----
tests/integration/test_add_files.py | 40 +++++++++++++++++++++++++++++--------
2 files changed, 36 insertions(+), 12 deletions(-)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 13186c42..5b7d04b5 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -443,7 +443,7 @@ class Transaction:
for data_file in data_files:
update_snapshot.append_data_file(data_file)
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for adding files as data files to the table transaction.
@@ -455,7 +455,7 @@ class Transaction:
"""
if self._table.name_mapping() is None:
self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING:
self._table.schema().name_mapping.model_dump_json()})
-        with self.update_snapshot().fast_append() as update_snapshot:
+        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
data_files = _parquet_files_to_data_files(
table_metadata=self._table.metadata, file_paths=file_paths,
io=self._table.io
)
@@ -1341,7 +1341,7 @@ class Table:
with self.transaction() as tx:
tx.overwrite(df=df, overwrite_filter=overwrite_filter,
snapshot_properties=snapshot_properties)
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
"""
Shorthand API for adding files as data files to the table.
@@ -1352,7 +1352,7 @@ class Table:
FileNotFoundError: If the file does not exist.
"""
with self.transaction() as tx:
-            tx.add_files(file_paths=file_paths)
+            tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties)
def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
return UpdateSpec(Transaction(self, autocommit=True),
case_sensitive=case_sensitive)
diff --git a/tests/integration/test_add_files.py
b/tests/integration/test_add_files.py
index 0de5d5f4..94c73918 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -17,7 +17,7 @@
# pylint:disable=redefined-outer-name
from datetime import date
-from typing import Optional
+from typing import Iterator, Optional
import pyarrow as pa
import pyarrow.parquet as pq
@@ -122,8 +122,13 @@ def _create_table(
return tbl
+@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
+def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
+    """Fixture to run tests with different table format versions."""
+    yield request.param
+
+
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table(spark: SparkSession,
session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_table_v{format_version}"
tbl = _create_table(session_catalog, identifier, format_version)
@@ -163,7 +168,6 @@ def test_add_files_to_unpartitioned_table(spark:
SparkSession, session_catalog:
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table_raises_file_not_found(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -184,7 +188,6 @@ def
test_add_files_to_unpartitioned_table_raises_file_not_found(
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table_raises_has_field_ids(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -205,7 +208,6 @@ def
test_add_files_to_unpartitioned_table_raises_has_field_ids(
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_unpartitioned_table_with_schema_updates(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -263,7 +265,6 @@ def
test_add_files_to_unpartitioned_table_with_schema_updates(
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog:
Catalog, format_version: int) -> None:
identifier = f"default.partitioned_table_v{format_version}"
@@ -335,7 +336,6 @@ def test_add_files_to_partitioned_table(spark:
SparkSession, session_catalog: Ca
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession,
session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.partitioned_table_bucket_fails_v{format_version}"
@@ -378,7 +378,6 @@ def test_add_files_to_bucket_partitioned_table_fails(spark:
SparkSession, sessio
@pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
@@ -424,3 +423,28 @@ def
test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
"Cannot infer partition value from parquet metadata as there are more
than one partition values for Partition Field: baz. lower_value=123,
upper_value=124"
in str(exc_info.value)
)
+
+
+@pytest.mark.integration
+def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+ identifier = f"default.unpartitioned_table_v{format_version}"
+ tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
+ # write parquet files
+ for file_path in file_paths:
+ fo = tbl.io.new_output(file_path)
+ with fo.create(overwrite=True) as fos:
+ with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+ writer.write_table(ARROW_TABLE)
+
+ # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths, snapshot_properties={"snapshot_prop_a": "test_prop_a"})
+
+ # NameMapping must have been set to enable reads
+ assert tbl.name_mapping() is not None
+
+    summary = spark.sql(f"SELECT * FROM {identifier}.snapshots;").collect()[0].summary
+
+ assert "snapshot_prop_a" in summary
+ assert summary["snapshot_prop_a"] == "test_prop_a"