This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new d2b7ba9d7b Python: Optimize concurrency for limited queries (#8104)
d2b7ba9d7b is described below

commit d2b7ba9d7b0b0e7ff45f213741ca9dcd81f4b549
Author: Josh Wiley <[email protected]>
AuthorDate: Wed Aug 16 06:17:00 2023 -0700

    Python: Optimize concurrency for limited queries (#8104)
    
    * feat(python): remove explicit row count lock
    
    * feat(python): stop waiting for task when limit reached
    
    * fix(python): remove unnecessary arg splat
    
    * fix(python): cancel all futures once result set acquired
    
    * fix(python): consistent scan ordering when limit applied
    
    * feat(python): reuse executor (wip)
    
    * fix(python): consolidate row count tracking
    
    * feat(python): use sortedcontainers in future ordering
    
    * feat(python): global executor
    
    * fix(python): formatting
    
    * fix(python): support limit = 0
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    
    * fix(python): support limit = 0
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    
    * fix(python): support limit = 0
    
    Co-authored-by: Fokko Driesprong <[email protected]>
    
    * feat(python): lazy re-usable executory
    
    * fix(python): appease rat
    
    * fix(python): use row counts container for mutability
    
    * fix(python): scan future cancelling
    
    * fix(python): data scan row count updates in worker
    
    * feat(python): simplify result agg
    
    * fix(python): reusable executor factory method name
    
    * fix(python): remove custom config error
    
    * fix(python): only slice result when limit provided
    
    ---------
    
    Co-authored-by: Fokko Driesprong <[email protected]>
---
 python/mkdocs/docs/configuration.md   |   4 ++
 python/pyiceberg/io/pyarrow.py        | 102 +++++++++++++++++++---------------
 python/pyiceberg/table/__init__.py    |  52 +++++++++--------
 python/pyiceberg/utils/concurrent.py  |  65 ++++++++--------------
 python/tests/utils/test_concurrent.py |  52 +++++++++++++++++
 5 files changed, 161 insertions(+), 114 deletions(-)

diff --git a/python/mkdocs/docs/configuration.md b/python/mkdocs/docs/configuration.md
index e9f50042f2..a56baff7b5 100644
--- a/python/mkdocs/docs/configuration.md
+++ b/python/mkdocs/docs/configuration.md
@@ -194,3 +194,7 @@ catalog:
     type: dynamodb
     table-name: iceberg
 ```
+
+# Concurrency
+
+PyIceberg uses multiple threads to parallelize operations. The number of workers can be configured by supplying a `max-workers` entry in the configuration file, or by setting the `PYICEBERG_MAX_WORKERS` environment variable. The default value depends on the system hardware and Python version. See [the Python documentation](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor) for more details.
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index eb48b34500..2e33f7174c 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -24,9 +24,10 @@ with the pyarrow library.
 """
 from __future__ import annotations
 
+import concurrent.futures
 import os
 from abc import ABC, abstractmethod
-from concurrent.futures import Executor
+from concurrent.futures import Future
 from functools import lru_cache, singledispatch
 from itertools import chain
 from typing import (
@@ -63,6 +64,7 @@ from pyarrow.fs import (
     PyFileSystem,
     S3FileSystem,
 )
+from sortedcontainers import SortedList
 
 from pyiceberg.avro.resolver import ResolveError
 from pyiceberg.expressions import (
@@ -133,7 +135,7 @@ from pyiceberg.types import (
     TimeType,
     UUIDType,
 )
-from pyiceberg.utils.concurrent import ManagedThreadPoolExecutor, Synchronized
+from pyiceberg.utils.concurrent import ExecutorFactory
 from pyiceberg.utils.datetime import millis_to_datetime
 from pyiceberg.utils.singleton import Singleton
 
@@ -765,10 +767,10 @@ def _task_to_table(
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    rows_counter: Synchronized[int],
+    row_counts: List[int],
     limit: Optional[int] = None,
 ) -> Optional[pa.Table]:
-    if limit and rows_counter.value >= limit:
+    if limit and sum(row_counts) >= limit:
         return None
 
     _, path = PyArrowFileIO.parse_location(task.file.file_path)
@@ -830,23 +832,22 @@ def _task_to_table(
             else:
                 arrow_table = fragment_scanner.to_table()
 
-        if limit:
-            with rows_counter:
-                if rows_counter.value >= limit:
-                    return None
-                rows_counter.value += len(arrow_table)
+        if len(arrow_table) < 1:
+            return None
 
-        # If there is no data, we don't have to go through the schema
-        if len(arrow_table) > 0:
-            return to_requested_schema(projected_schema, file_project_schema, 
arrow_table)
-        else:
+        if limit is not None and sum(row_counts) >= limit:
             return None
 
+        row_counts.append(len(arrow_table))
+
+        return to_requested_schema(projected_schema, file_project_schema, 
arrow_table)
+
 
-def _read_all_delete_files(fs: FileSystem, executor: Executor, tasks: 
Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
+def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> 
Dict[str, List[ChunkedArray]]:
     deletes_per_file: Dict[str, List[ChunkedArray]] = {}
     unique_deletes = set(chain.from_iterable([task.delete_files for task in 
tasks]))
     if len(unique_deletes) > 0:
+        executor = ExecutorFactory.get_or_create()
         deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map(
             lambda args: _read_deletes(*args), [(fs, delete) for delete in 
unique_deletes]
         )
@@ -902,39 +903,50 @@ def project_table(
         id for id in projected_schema.field_ids if not 
isinstance(projected_schema.find_type(id), (MapType, ListType))
     }.union(extract_field_ids(bound_row_filter))
 
-    with ManagedThreadPoolExecutor() as executor:
-        rows_counter = executor.synchronized(0)
-        deletes_per_file = _read_all_delete_files(fs, executor, tasks)
-        tables = [
-            table
-            for table in executor.map(
-                lambda args: _task_to_table(*args),
-                [
-                    (
-                        fs,
-                        task,
-                        bound_row_filter,
-                        projected_schema,
-                        projected_field_ids,
-                        deletes_per_file.get(task.file.file_path),
-                        case_sensitive,
-                        rows_counter,
-                        limit,
-                    )
-                    for task in tasks
-                ],
-            )
-            if table is not None
-        ]
+    row_counts: List[int] = []
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+    executor = ExecutorFactory.get_or_create()
+    futures = [
+        executor.submit(
+            _task_to_table,
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            row_counts,
+            limit,
+        )
+        for task in tasks
+    ]
 
-    if len(tables) > 1:
-        final_table = pa.concat_tables(tables)
-    elif len(tables) == 1:
-        final_table = tables[0]
-    else:
-        final_table = pa.Table.from_batches([], 
schema=schema_to_pyarrow(projected_schema))
+    # for consistent ordering, we need to maintain future order
+    futures_index = {f: i for i, f in enumerate(futures)}
+    completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], 
key=lambda f: futures_index[f])
+    for future in concurrent.futures.as_completed(futures):
+        completed_futures.add(future)
+
+        # stop early if limit is satisfied
+        if limit is not None and sum(row_counts) >= limit:
+            break
+
+    # by now, we've either completed all tasks or satisfied the limit
+    if limit is not None:
+        _ = [f.cancel() for f in futures if not f.done()]
+
+    tables = [f.result() for f in completed_futures if f.result()]
+
+    if len(tables) < 1:
+        return pa.Table.from_batches([], 
schema=schema_to_pyarrow(projected_schema))
+
+    result = pa.concat_tables(tables)
+
+    if limit is not None:
+        return result.slice(0, limit)
 
-    return final_table.slice(0, limit)
+    return result
 
 
 def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: 
pa.Table) -> pa.Table:
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index 5214bed01c..4ede2582fc 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum
 from functools import cached_property
@@ -70,6 +69,7 @@ from pyiceberg.typedef import (
     KeyDefaultDict,
     Properties,
 )
+from pyiceberg.utils.concurrent import ExecutorFactory
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -775,33 +775,31 @@ class DataScan(TableScan):
         data_entries: List[ManifestEntry] = []
         positional_delete_entries = SortedList(key=lambda entry: 
entry.data_sequence_number or INITIAL_SEQUENCE_NUMBER)
 
-        with ThreadPoolExecutor() as executor:
-            for manifest_entry in chain(
-                *executor.map(
-                    lambda args: _open_manifest(*args),
-                    [
-                        (
-                            io,
-                            manifest,
-                            partition_evaluators[manifest.partition_spec_id],
-                            metrics_evaluator,
-                        )
-                        for manifest in manifests
-                        if 
self._check_sequence_number(min_data_sequence_number, manifest)
-                    ],
-                )
-            ):
-                data_file = manifest_entry.data_file
-                if data_file.content == DataFileContent.DATA:
-                    data_entries.append(manifest_entry)
-                elif data_file.content == DataFileContent.POSITION_DELETES:
-                    positional_delete_entries.add(manifest_entry)
-                elif data_file.content == DataFileContent.EQUALITY_DELETES:
-                    raise ValueError(
-                        "PyIceberg does not yet support equality deletes: 
https://github.com/apache/iceberg/issues/6568"
+        executor = ExecutorFactory.get_or_create()
+        for manifest_entry in chain(
+            *executor.map(
+                lambda args: _open_manifest(*args),
+                [
+                    (
+                        io,
+                        manifest,
+                        partition_evaluators[manifest.partition_spec_id],
+                        metrics_evaluator,
                     )
-                else:
-                    raise ValueError(f"Unknown DataFileContent 
({data_file.content}): {manifest_entry}")
+                    for manifest in manifests
+                    if self._check_sequence_number(min_data_sequence_number, 
manifest)
+                ],
+            )
+        ):
+            data_file = manifest_entry.data_file
+            if data_file.content == DataFileContent.DATA:
+                data_entries.append(manifest_entry)
+            elif data_file.content == DataFileContent.POSITION_DELETES:
+                positional_delete_entries.add(manifest_entry)
+            elif data_file.content == DataFileContent.EQUALITY_DELETES:
+                raise ValueError("PyIceberg does not yet support equality 
deletes: https://github.com/apache/iceberg/issues/6568")
+            else:
+                raise ValueError(f"Unknown DataFileContent 
({data_file.content}): {manifest_entry}")
 
         return [
             FileScanTask(
diff --git a/python/pyiceberg/utils/concurrent.py b/python/pyiceberg/utils/concurrent.py
index a71ee3281e..f54a919d7c 100644
--- a/python/pyiceberg/utils/concurrent.py
+++ b/python/pyiceberg/utils/concurrent.py
@@ -14,54 +14,35 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-outer-name,arguments-renamed,fixme
-"""Concurrency concepts that support multi-threading."""
-import threading
+"""Concurrency concepts that support efficient multi-threading."""
 from concurrent.futures import Executor, ThreadPoolExecutor
-from contextlib import AbstractContextManager
-from typing import Any, Generic, TypeVar
+from typing import Optional
 
-from typing_extensions import Self
+from pyiceberg.utils.config import Config
 
-T = TypeVar("T")
 
+class ExecutorFactory:
+    _instance: Optional[Executor] = None
 
-class Synchronized(Generic[T], AbstractContextManager):  # type: ignore
-    """A context manager that provides concurrency-safe access to a value."""
+    @staticmethod
+    def get_or_create() -> Executor:
+        """Returns the same executor in each call."""
+        if ExecutorFactory._instance is None:
+            max_workers = ExecutorFactory.max_workers()
+            ExecutorFactory._instance = 
ThreadPoolExecutor(max_workers=max_workers)
 
-    value: T
-    lock: threading.Lock
+        return ExecutorFactory._instance
 
-    def __init__(self, value: T, lock: threading.Lock):
-        super().__init__()
-        self.value = value
-        self.lock = lock
+    @staticmethod
+    def max_workers() -> Optional[int]:
+        """Returns the max number of workers configured."""
+        config = Config()
+        val = config.config.get("max-workers")
 
-    def __enter__(self) -> T:
-        """Acquires a lock, allowing access to the wrapped value."""
-        self.lock.acquire()
-        return self.value
+        if val is None:
+            return None
 
-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
-        """Releases the lock, allowing other threads to access the value."""
-        self.lock.release()
-
-
-class ManagedExecutor(Executor):
-    """An executor that provides synchronization."""
-
-    def synchronized(self, value: T) -> Synchronized[T]:
-        raise NotImplementedError
-
-
-class ManagedThreadPoolExecutor(ThreadPoolExecutor, ManagedExecutor):
-    """A thread pool executor that provides synchronization."""
-
-    def __enter__(self) -> Self:
-        """Returns the executor itself as a context manager."""
-        super().__enter__()
-        return self
-
-    def synchronized(self, value: T) -> Synchronized[T]:
-        lock = threading.Lock()
-        return Synchronized(value, lock)
+        try:
+            return int(val)  # type: ignore
+        except ValueError as err:
+            raise ValueError(f"Max workers should be an integer or left unset. 
Current value: {val}") from err
diff --git a/python/tests/utils/test_concurrent.py b/python/tests/utils/test_concurrent.py
new file mode 100644
index 0000000000..6d730cbe75
--- /dev/null
+++ b/python/tests/utils/test_concurrent.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, Optional
+from unittest import mock
+
+import pytest
+
+from pyiceberg.utils.concurrent import ExecutorFactory
+
+EMPTY_ENV: Dict[str, Optional[str]] = {}
+VALID_ENV = {"PYICEBERG_MAX_WORKERS": "5"}
+INVALID_ENV = {"PYICEBERG_MAX_WORKERS": "invalid"}
+
+
+def test_create_reused() -> None:
+    first = ExecutorFactory.get_or_create()
+    second = ExecutorFactory.get_or_create()
+    assert isinstance(first, ThreadPoolExecutor)
+    assert first is second
+
+
+@mock.patch.dict(os.environ, EMPTY_ENV)
+def test_max_workers_none() -> None:
+    assert ExecutorFactory.max_workers() is None
+
+
+@mock.patch.dict(os.environ, VALID_ENV)
+def test_max_workers() -> None:
+    assert ExecutorFactory.max_workers() == 5
+
+
+@mock.patch.dict(os.environ, INVALID_ENV)
+def test_max_workers_invalid() -> None:
+    with pytest.raises(ValueError):
+        ExecutorFactory.max_workers()

Reply via email to