This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new eec13a66 Refactor PyArrow DataFiles Projection functions (#1043)
eec13a66 is described below

commit eec13a66570eb03d7704f6ca5df9068583dd5def
Author: Sung Yun <[email protected]>
AuthorDate: Tue Aug 20 04:24:49 2024 -0400

    Refactor PyArrow DataFiles Projection functions (#1043)
    
    * refactoring
    
    * refactor more
    
    * docstring
    
    * #1042
    
    * adopt review feedback
    
    * thanks Kevin!
    
    Co-authored-by: Kevin Liu <[email protected]>
    
    ---------
    
    Co-authored-by: Kevin Liu <[email protected]>
---
 pyiceberg/io/__init__.py    |  13 +++
 pyiceberg/io/pyarrow.py     | 197 +++++++++++++++++++++++++++++++++++++++++++-
 pyiceberg/table/__init__.py |  30 +++----
 3 files changed, 219 insertions(+), 21 deletions(-)

diff --git a/pyiceberg/io/__init__.py b/pyiceberg/io/__init__.py
index d5f26a17..be06bc04 100644
--- a/pyiceberg/io/__init__.py
+++ b/pyiceberg/io/__init__.py
@@ -27,6 +27,7 @@ from __future__ import annotations
 
 import importlib
 import logging
+import os
 import warnings
 from abc import ABC, abstractmethod
 from io import SEEK_SET
@@ -36,6 +37,7 @@ from typing import (
     List,
     Optional,
     Protocol,
+    Tuple,
     Type,
     Union,
     runtime_checkable,
@@ -356,3 +358,14 @@ def load_file_io(properties: Properties = EMPTY_DICT, location: Optional[str] =
         raise ModuleNotFoundError(
             'Could not load a FileIO, please consider installing one: pip3 install "pyiceberg[pyarrow]", for more options refer to the docs.'
         ) from e
+
+
+def _parse_location(location: str) -> Tuple[str, str, str]:
+    """Return the path without the scheme."""
+    uri = urlparse(location)
+    if not uri.scheme:
+        return "file", uri.netloc, os.path.abspath(location)
+    elif uri.scheme in ("hdfs", "viewfs"):
+        return uri.scheme, uri.netloc, uri.path
+    else:
+        return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"
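
For reference, a rough sketch of what the new _parse_location helper returns for a few
representative locations, following the three branches added above (the paths are made up
purely for illustration):

    from pyiceberg.io import _parse_location

    # No scheme: treated as a local file and made absolute
    _parse_location("/tmp/warehouse/data.parquet")
    # -> ("file", "", "/tmp/warehouse/data.parquet")

    # hdfs/viewfs: netloc and path are kept separate
    _parse_location("hdfs://namenode:8020/warehouse/data.parquet")
    # -> ("hdfs", "namenode:8020", "/warehouse/data.parquet")

    # Any other scheme (e.g. s3): the netloc is prepended to the path
    _parse_location("s3://bucket/warehouse/data.parquet")
    # -> ("s3", "bucket", "bucket/warehouse/data.parquet")
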
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index b2cb167a..df9f9098 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -108,6 +108,7 @@ from pyiceberg.io import (
     InputStream,
     OutputFile,
     OutputStream,
+    _parse_location,
 )
 from pyiceberg.manifest import (
     DataFile,
@@ -1195,7 +1196,7 @@ def _task_to_record_batches(
     name_mapping: Optional[NameMapping] = None,
     use_large_types: bool = True,
 ) -> Iterator[pa.RecordBatch]:
-    _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
+    _, _, path = _parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
         fragment = arrow_format.make_fragment(fin)
@@ -1304,6 +1305,195 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dic
     return deletes_per_file
 
 
+def _fs_from_file_path(file_path: str, io: FileIO) -> FileSystem:
+    scheme, netloc, _ = _parse_location(file_path)
+    if isinstance(io, PyArrowFileIO):
+        return io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                return PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
+
+
+class ArrowScan:
+    _table_metadata: TableMetadata
+    _io: FileIO
+    _fs: FileSystem
+    _projected_schema: Schema
+    _bound_row_filter: BooleanExpression
+    _case_sensitive: bool
+    _limit: Optional[int]
+    """Scan the Iceberg Table and create an Arrow construct.
+
+    Attributes:
+        _table_metadata: Current table metadata of the Iceberg table
+        _io: PyIceberg FileIO implementation from which to fetch the io properties
+        _fs: PyArrow FileSystem to use to read the files
+        _projected_schema: Iceberg Schema to project onto the data files
+        _bound_row_filter: Schema bound row expression to filter the data with
+        _case_sensitive: Case sensitivity when looking up column names
+        _limit: Limit the number of records.
+    """
+
+    def __init__(
+        self,
+        table_metadata: TableMetadata,
+        io: FileIO,
+        projected_schema: Schema,
+        row_filter: BooleanExpression,
+        case_sensitive: bool = True,
+        limit: Optional[int] = None,
+    ) -> None:
+        self._table_metadata = table_metadata
+        self._io = io
+        self._fs = _fs_from_file_path(table_metadata.location, io)  # TODO: use different FileSystem per file
+        self._projected_schema = projected_schema
+        self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+        self._case_sensitive = case_sensitive
+        self._limit = limit
+
+    @property
+    def _use_large_types(self) -> bool:
+        """Whether to represent data as large arrow types.
+
+        Defaults to True.
+        """
+        return property_as_bool(self._io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, True)
+
+    @property
+    def _projected_field_ids(self) -> Set[int]:
+        """Set of field IDs that should be projected from the data files."""
+        return {
+            id
+            for id in self._projected_schema.field_ids
+            if not isinstance(self._projected_schema.find_type(id), (MapType, ListType))
+        }.union(extract_field_ids(self._bound_row_filter))
+
+    def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
+        """Scan the Iceberg table and return a pa.Table.
+
+        Returns a pa.Table with data from the Iceberg table by resolving the
+        right columns that match the current table schema. Only data that
+        matches the provided row_filter expression is returned.
+
+        Args:
+            tasks: FileScanTasks representing the data files and delete files to read from.
+
+        Returns:
+            A PyArrow table. Total number of rows will be capped if specified.
+
+        Raises:
+            ResolveError: When a required field cannot be found in the file
+            ValueError: When a field type in the file cannot be projected to the schema type
+        """
+        deletes_per_file = _read_all_delete_files(self._fs, tasks)
+        executor = ExecutorFactory.get_or_create()
+
+        def _table_from_scan_task(task: FileScanTask) -> pa.Table:
+            batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
+            if len(batches) > 0:
+                return pa.Table.from_batches(batches)
+            else:
+                return None
+
+        futures = [
+            executor.submit(
+                _table_from_scan_task,
+                task,
+            )
+            for task in tasks
+        ]
+        total_row_count = 0
+        # for consistent ordering, we need to maintain future order
+        futures_index = {f: i for i, f in enumerate(futures)}
+        completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f])
+        for future in concurrent.futures.as_completed(futures):
+            completed_futures.add(future)
+            if table_result := future.result():
+                total_row_count += len(table_result)
+            # stop early if limit is satisfied
+            if self._limit is not None and total_row_count >= self._limit:
+                break
+
+        # by now, we've either completed all tasks or satisfied the limit
+        if self._limit is not None:
+            _ = [f.cancel() for f in futures if not f.done()]
+
+        tables = [f.result() for f in completed_futures if f.result()]
+
+        if len(tables) < 1:
+            return pa.Table.from_batches([], schema=schema_to_pyarrow(self._projected_schema, include_field_ids=False))
+
+        result = pa.concat_tables(tables, promote_options="permissive")
+
+        if self._limit is not None:
+            return result.slice(0, self._limit)
+
+        return result
+
+    def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
+        """Scan the Iceberg table and return an Iterator[pa.RecordBatch].
+
+        Returns an Iterator of pa.RecordBatch with data from the Iceberg table
+        by resolving the right columns that match the current table schema.
+        Only data that matches the provided row_filter expression is returned.
+
+        Args:
+            tasks: FileScanTasks representing the data files and delete files to read from.
+
+        Returns:
+            An Iterator of PyArrow RecordBatches.
+            Total number of rows will be capped if specified.
+
+        Raises:
+            ResolveError: When a required field cannot be found in the file
+            ValueError: When a field type in the file cannot be projected to the schema type
+        """
+        deletes_per_file = _read_all_delete_files(self._fs, tasks)
+        return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
+
+    def _record_batches_from_scan_tasks_and_deletes(
+        self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
+    ) -> Iterator[pa.RecordBatch]:
+        total_row_count = 0
+        for task in tasks:
+            if self._limit is not None and total_row_count >= self._limit:
+                break
+            batches = _task_to_record_batches(
+                self._fs,
+                task,
+                self._bound_row_filter,
+                self._projected_schema,
+                self._projected_field_ids,
+                deletes_per_file.get(task.file.file_path),
+                self._case_sensitive,
+                self._table_metadata.name_mapping(),
+                self._use_large_types,
+            )
+            for batch in batches:
+                if self._limit is not None:
+                    if total_row_count >= self._limit:
+                        break
+                    elif total_row_count + len(batch) >= self._limit:
+                        batch = batch.slice(0, self._limit - total_row_count)
+                yield batch
+                total_row_count += len(batch)
+
+
+@deprecated(
+    deprecated_in="0.8.0",
+    removed_in="0.9.0",
+    help_message="project_table is deprecated. Use ArrowScan.to_table instead.",
+)
 def project_table(
     tasks: Iterable[FileScanTask],
     table_metadata: TableMetadata,
@@ -1398,6 +1588,11 @@ def project_table(
     return result
 
 
+@deprecated(
+    deprecated_in="0.8.0",
+    removed_in="0.9.0",
+    help_message="project_table is deprecated. Use ArrowScan.to_record_batches instead.",
+)
 def project_batches(
     tasks: Iterable[FileScanTask],
     table_metadata: TableMetadata,
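
For callers migrating off the now-deprecated module-level functions, the replacement is close to
one-to-one. A minimal sketch, assuming table_metadata, io, projected_schema, row_filter and a list
of plan-file tasks are already in hand (all of the variable names here are placeholders):

    from pyiceberg.io.pyarrow import ArrowScan

    scan = ArrowScan(table_metadata, io, projected_schema, row_filter, case_sensitive=True, limit=100)

    # previously project_table(...): materialize the matching rows into a single pa.Table
    arrow_table = scan.to_table(tasks)

    # previously project_batches(...): stream the rows as an iterator of pa.RecordBatch
    batch_iter = scan.to_record_batches(tasks)

Constructing the scan once binds the row filter and resolves the FileSystem up front (see
__init__ above), so both read paths reuse the same state.
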
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 784e3e0a..47944831 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -2031,35 +2031,25 @@ class DataScan(TableScan):
         ]
 
     def to_arrow(self) -> pa.Table:
-        from pyiceberg.io.pyarrow import project_table
+        from pyiceberg.io.pyarrow import ArrowScan
 
-        return project_table(
-            self.plan_files(),
-            self.table_metadata,
-            self.io,
-            self.row_filter,
-            self.projection(),
-            case_sensitive=self.case_sensitive,
-            limit=self.limit,
-        )
+        return ArrowScan(
+            self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+        ).to_table(self.plan_files())
 
     def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
         import pyarrow as pa
 
-        from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow
+        from pyiceberg.io.pyarrow import ArrowScan, schema_to_pyarrow
 
         target_schema = schema_to_pyarrow(self.projection())
+        batches = ArrowScan(
+            self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+        ).to_record_batches(self.plan_files())
+
         return pa.RecordBatchReader.from_batches(
             target_schema,
-            project_batches(
-                self.plan_files(),
-                self.table_metadata,
-                self.io,
-                self.row_filter,
-                self.projection(),
-                case_sensitive=self.case_sensitive,
-                limit=self.limit,
-            ),
+            batches,
         )
 
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
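
End to end, the public DataScan API is unchanged by this refactor; a read such as the one below
still returns a pa.Table, it just routes through ArrowScan.to_table internally now (the catalog
and table names are hypothetical):

    from pyiceberg.catalog import load_catalog

    catalog = load_catalog("default")
    table = catalog.load_table("examples.events")
    arrow_table = table.scan(limit=10).to_arrow()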
