This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new f942551f Parallelize `add_files` (#1717)
f942551f is described below

commit f942551fe721b8ce8d195f09bf12f962e9bc36ac
Author: vtk9 <[email protected]>
AuthorDate: Mon Mar 3 03:22:42 2025 -0700

    Parallelize `add_files` (#1717)
    
    - `parquet_files_to_data_files` changed to `parquet_file_to_data_file`,
    which processes a single parquet file and returns a `DataFile`
    - `_parquet_files_to_data_files` uses the internal `ExecutorFactory`
    
    resolves https://github.com/apache/iceberg-python/issues/1335
---
 pyiceberg/io/pyarrow.py             | 59 +++++++++++++++++++++----------------
 pyiceberg/table/__init__.py         |  7 +++--
 tests/integration/test_add_files.py | 53 +++++++++++++++++++++++++++++++++
 3 files changed, 91 insertions(+), 28 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index f7e3c7c0..bf16ec5e 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2466,36 +2466,43 @@ def _check_pyarrow_schema_compatible(
 
 def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
     for file_path in file_paths:
-        input_file = io.new_input(file_path)
-        with input_file.open() as input_stream:
-            parquet_metadata = pq.read_metadata(input_stream)
+        data_file = parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path)
+        yield data_file
 
-        if visit_pyarrow(parquet_metadata.schema.to_arrow_schema(), _HasIds()):
-            raise NotImplementedError(
-                f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
-            )
-        schema = table_metadata.schema()
-        _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
 
-        statistics = data_file_statistics_from_parquet_metadata(
-            parquet_metadata=parquet_metadata,
-            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
-            parquet_column_mapping=parquet_path_to_id_mapping(schema),
-        )
-        data_file = DataFile(
-            content=DataFileContent.DATA,
-            file_path=file_path,
-            file_format=FileFormat.PARQUET,
-            partition=statistics.partition(table_metadata.spec(), table_metadata.schema()),
-            file_size_in_bytes=len(input_file),
-            sort_order_id=None,
-            spec_id=table_metadata.default_spec_id,
-            equality_ids=None,
-            key_metadata=None,
-            **statistics.to_serialized_dict(),
+def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
+    input_file = io.new_input(file_path)
+    with input_file.open() as input_stream:
+        parquet_metadata = pq.read_metadata(input_stream)
+
+    arrow_schema = parquet_metadata.schema.to_arrow_schema()
+    if visit_pyarrow(arrow_schema, _HasIds()):
+        raise NotImplementedError(
+            f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
         )
 
-        yield data_file
+    schema = table_metadata.schema()
+    _check_pyarrow_schema_compatible(schema, arrow_schema)
+
+    statistics = data_file_statistics_from_parquet_metadata(
+        parquet_metadata=parquet_metadata,
+        stats_columns=compute_statistics_plan(schema, table_metadata.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(schema),
+    )
+    data_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path=file_path,
+        file_format=FileFormat.PARQUET,
+        partition=statistics.partition(table_metadata.spec(), table_metadata.schema()),
+        file_size_in_bytes=len(input_file),
+        sort_order_id=None,
+        spec_id=table_metadata.default_spec_id,
+        equality_ids=None,
+        key_metadata=None,
+        **statistics.to_serialized_dict(),
+    )
+
+    return data_file
 
 
 ICEBERG_UNCOMPRESSED_CODEC = "uncompressed"
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index ee50a8b5..a189b07c 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1889,6 +1889,9 @@ def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List
     Returns:
         An iterable that supplies DataFiles that describe the parquet files.
     """
-    from pyiceberg.io.pyarrow import parquet_files_to_data_files
+    from pyiceberg.io.pyarrow import parquet_file_to_data_file
 
-    yield from parquet_files_to_data_files(io=io, table_metadata=table_metadata, file_paths=iter(file_paths))
+    executor = ExecutorFactory.get_or_create()
+    futures = [executor.submit(parquet_file_to_data_file, io, table_metadata, file_path) for file_path in file_paths]
+
+    return [f.result() for f in futures if f.result()]
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 87136152..bfbc8db6 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -16,10 +16,13 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 
+import multiprocessing
 import os
 import re
+import threading
 from datetime import date
 from typing import Iterator
+from unittest import mock
 
 import pyarrow as pa
 import pyarrow.parquet as pq
@@ -31,9 +34,11 @@ from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.io import FileIO
 from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _pyarrow_schema_ensure_large_types
+from pyiceberg.manifest import DataFile
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
+from pyiceberg.table.metadata import TableMetadata
 from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform
 from pyiceberg.types import (
     BooleanType,
@@ -229,6 +234,54 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids(
         tbl.add_files(file_paths=file_paths)
 
 
+@pytest.mark.integration
+def test_add_files_parallelized(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    from pyiceberg.io.pyarrow import parquet_file_to_data_file
+
+    real_parquet_file_to_data_file = parquet_file_to_data_file
+
+    lock = threading.Lock()
+    unique_threads_seen = set()
+    cpu_count = multiprocessing.cpu_count()
+
+    # patch the function _parquet_file_to_data_file to we can track how many unique thread IDs
+    # it was executed from
+    with mock.patch("pyiceberg.io.pyarrow.parquet_file_to_data_file") as patch_func:
+
+        def mock_parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
+            lock.acquire()
+            thread_id = threading.get_ident()  # the current thread ID
+            unique_threads_seen.add(thread_id)
+            lock.release()
+            return real_parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path)
+
+        patch_func.side_effect = mock_parquet_file_to_data_file
+
+        identifier = f"default.unpartitioned_table_schema_updates_v{format_version}"
+        tbl = _create_table(session_catalog, identifier, format_version)
+
+        file_paths = [
+            f"s3://warehouse/default/add_files_parallel/v{format_version}/test-{i}.parquet" for i in range(cpu_count * 2)
+        ]
+        # write parquet files
+        for file_path in file_paths:
+            fo = tbl.io.new_output(file_path)
+            with fo.create(overwrite=True) as fos:
+                with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                    writer.write_table(ARROW_TABLE)
+
+        tbl.add_files(file_paths=file_paths)
+
+    # duration creation of threadpool processor, when max_workers is not
+    # specified, python will add cpu_count + 4 as the number of threads in the
+    # pool in this case
+    # https://github.com/python/cpython/blob/e06bebb87e1b33f7251196e1ddb566f528c3fc98/Lib/concurrent/futures/thread.py#L173-L181
+    # we check that we have at least seen the number of threads. we don't
+    # specify the workers in the thread pool and we can't check without
+    # accessing private attributes of ThreadPoolExecutor
+    assert len(unique_threads_seen) >= cpu_count
+
+
 @pytest.mark.integration
 def test_add_files_to_unpartitioned_table_with_schema_updates(
     spark: SparkSession, session_catalog: Catalog, format_version: int

Reply via email to