This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new b8c5bb77 Support `Table.to_arrow_batch_reader` (#786)
b8c5bb77 is described below

commit b8c5bb77c5ea436aeced17676aa30d09c1224ed9
Author: Sung Yun <[email protected]>
AuthorDate: Fri Jun 21 09:44:24 2024 -0400

    Support `Table.to_arrow_batch_reader` (#786)
    
    * _task_to_table to _task_to_record_batches
    
    * to_arrow_batches
    
    * tests
    
    * fix
    
    * fix
    
    * deletes
    
    * batch reader
    
    * merge main
    
    * adopt review feedback
---
 mkdocs/docs/api.md              |   9 +++
 pyiceberg/io/pyarrow.py         | 155 ++++++++++++++++++++++++++++++----------
 pyiceberg/table/__init__.py     |  18 +++++
 tests/integration/test_reads.py | 126 ++++++++++++++++++++++++++++++++
 4 files changed, 269 insertions(+), 39 deletions(-)

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 6bbd9abe..54f4a20c 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -1003,6 +1003,15 @@ tpep_dropoff_datetime: [[2021-04-01 
00:47:59.000000,...,2021-05-01 00:14:47.0000
 
 This will only pull in the files that that might contain matching rows.
 
+One can also return a PyArrow RecordBatchReader, if reading one record batch 
at a time is preferred:
+
+```python
+table.scan(
+    row_filter=GreaterThanOrEqual("trip_distance", 10.0),
+    selected_fields=("VendorID", "tpep_pickup_datetime", 
"tpep_dropoff_datetime"),
+).to_arrow_batch_reader()
+```
+
 ### Pandas
 
 <!-- prettier-ignore-start -->
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 935b78ce..e6490ae1 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) -> 
Dict[str, pa.ChunkedAr
     }
 
 
-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], 
rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], 
start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in 
positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), 
all_chunks, assume_unique=False), start_index)
 
 
 def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] = 
None) -> Schema:
@@ -995,7 +995,7 @@ class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
         return -1
 
 
-def _task_to_table(
+def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
     bound_row_filter: BooleanExpression,
@@ -1003,9 +1003,8 @@ def _task_to_table(
     projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
-    limit: Optional[int] = None,
     name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
     _, _, path = PyArrowFileIO.parse_location(task.file.file_path)
     arrow_format = ds.ParquetFileFormat(pre_buffer=True, 
buffer_size=(ONE_MEGABYTE * 8))
     with fs.open_input_file(path) as fin:
@@ -1035,36 +1034,39 @@ def _task_to_table(
             columns=[col.name for col in file_project_schema.columns],
         )
 
-        if positional_deletes:
-            # Create the mask of indices that we're interested in
-            indices = _combine_positional_deletes(positional_deletes, 
fragment.count_rows())
-
-            if limit:
-                if pyarrow_filter is not None:
-                    # In case of the filter, we don't exactly know how many 
rows
-                    # we need to fetch upfront, can be optimized in the future:
-                    # https://github.com/apache/arrow/issues/35301
-                    arrow_table = fragment_scanner.take(indices)
-                    arrow_table = arrow_table.filter(pyarrow_filter)
-                    arrow_table = arrow_table.slice(0, limit)
-                else:
-                    arrow_table = fragment_scanner.take(indices[0:limit])
-            else:
-                arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, 
current_index, current_index + len(batch))
+                batch = batch.take(indices)
                 # Apply the user filter
                 if pyarrow_filter is not None:
+                    # we need to switch back and forth between RecordBatch and 
Table
+                    # as Expression filter isn't yet supported in RecordBatch
+                    # https://github.com/apache/arrow/issues/39220
+                    arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
-        else:
-            # If there are no deletes, we can just take the head
-            # and the user-filter is already applied
-            if limit:
-                arrow_table = fragment_scanner.head(limit)
-            else:
-                arrow_table = fragment_scanner.to_table()
+                    batch = arrow_table.to_batches()[0]
+            yield to_requested_schema(projected_schema, file_project_schema, 
batch)
+            current_index += len(batch)
 
-        if len(arrow_table) < 1:
-            return None
-        return to_requested_schema(projected_schema, file_project_schema, 
arrow_table)
+
+def _task_to_table(
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    name_mapping: Optional[NameMapping] = None,
+) -> pa.Table:
+    batches = _task_to_record_batches(
+        fs, task, bound_row_filter, projected_schema, projected_field_ids, 
positional_deletes, case_sensitive, name_mapping
+    )
+    return pa.Table.from_batches(batches, 
schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
 
 
 def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> 
Dict[str, List[ChunkedArray]]:
@@ -1143,7 +1145,6 @@ def project_table(
             projected_field_ids,
             deletes_per_file.get(task.file.file_path),
             case_sensitive,
-            limit,
             table_metadata.name_mapping(),
         )
         for task in tasks
@@ -1177,8 +1178,78 @@ def project_table(
     return result
 
 
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table: 
pa.Table) -> pa.Table:
-    struct_array = visit_with_partner(requested_schema, table, 
ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def project_batches(
+    tasks: Iterable[FileScanTask],
+    table_metadata: TableMetadata,
+    io: FileIO,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
+) -> Iterator[pa.RecordBatch]:
+    """Resolve the right columns based on the identifier.
+
+    Args:
+        tasks (Iterable[FileScanTask]): A URI or a path to a local file.
+        table_metadata (TableMetadata): The table metadata of the table that's 
being queried
+        io (FileIO): A FileIO to open streams to the object store
+        row_filter (BooleanExpression): The expression for filtering rows.
+        projected_schema (Schema): The output schema.
+        case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records.
+
+    Raises:
+        ResolveError: When an incompatible query is done.
+    """
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, 
got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: 
{io}") from e
+
+    bound_row_filter = bind(table_metadata.schema(), row_filter, 
case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not 
isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+
+    total_row_count = 0
+
+    for task in tasks:
+        batches = _task_to_record_batches(
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            table_metadata.name_mapping(),
+        )
+        for batch in batches:
+            if limit is not None:
+                if total_row_count + len(batch) >= limit:
+                    yield batch.slice(0, limit - total_row_count)
+                    break
+            yield batch
+            total_row_count += len(batch)
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: 
pa.RecordBatch) -> pa.RecordBatch:
+    struct_array = visit_with_partner(requested_schema, batch, 
ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
 
     arrays = []
     fields = []
@@ -1186,7 +1257,7 @@ def to_requested_schema(requested_schema: Schema, 
file_schema: Schema, table: pa
         array = struct_array.field(pos)
         arrays.append(array)
         fields.append(pa.field(field.name, array.type, field.optional))
-    return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+    return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
 
 
 class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, 
Optional[pa.Array]]):
@@ -1293,8 +1364,10 @@ class ArrowAccessor(PartnerAccessor[pa.Array]):
 
             if isinstance(partner_struct, pa.StructArray):
                 return partner_struct.field(name)
-            elif isinstance(partner_struct, pa.Table):
-                return partner_struct.column(name).combine_chunks()
+            elif isinstance(partner_struct, pa.RecordBatch):
+                return partner_struct.column(name)
+            else:
+                raise ValueError(f"Cannot find {name} in expected 
partner_struct type {type(partner_struct)}")
 
         return None
 
@@ -1831,7 +1904,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, 
tasks: Iterator[WriteT
 
     def write_parquet(task: WriteTask) -> DataFile:
         table_schema = task.schema
-        arrow_table = pa.Table.from_batches(task.record_batches)
+
         # if schema needs to be transformed, use the transformed schema and 
adjust the arrow table accordingly
         # otherwise use the original schema
         if (sanitized_schema := sanitize_column_names(table_schema)) != 
table_schema:
@@ -1839,7 +1912,11 @@ def write_file(io: FileIO, table_metadata: 
TableMetadata, tasks: Iterator[WriteT
         else:
             file_schema = table_schema
 
-        arrow_table = to_requested_schema(requested_schema=file_schema, 
file_schema=table_schema, table=arrow_table)
+        batches = [
+            to_requested_schema(requested_schema=file_schema, 
file_schema=table_schema, batch=batch)
+            for batch in task.record_batches
+        ]
+        arrow_table = pa.Table.from_batches(batches)
         file_path = 
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
         fo = io.new_output(file_path)
         with fo.create(overwrite=True) as fos:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 9a10fc6b..c78e005c 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1878,6 +1878,24 @@ class DataScan(TableScan):
             limit=self.limit,
         )
 
+    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+        import pyarrow as pa
+
+        from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow
+
+        return pa.RecordBatchReader.from_batches(
+            schema_to_pyarrow(self.projection()),
+            project_batches(
+                self.plan_files(),
+                self.table_metadata,
+                self.io,
+                self.row_filter,
+                self.projection(),
+                case_sensitive=self.case_sensitive,
+                limit=self.limit,
+            ),
+        )
+
     def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
         return self.to_arrow().to_pandas(**kwargs)
 
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 80a6f186..078abf40 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -21,6 +21,7 @@ import time
 import uuid
 from urllib.parse import urlparse
 
+import pyarrow as pa
 import pyarrow.parquet as pq
 import pytest
 from hive_metastore.ttypes import LockRequest, LockResponse, LockState, 
UnlockRequest
@@ -174,6 +175,47 @@ def test_pyarrow_not_nan_count(catalog: Catalog) -> None:
     assert len(not_nan) == 2
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_nan(catalog: Catalog) -> None:
+    table_test_null_nan = catalog.load_table("default.test_null_nan")
+    arrow_batch_reader = table_test_null_nan.scan(
+        row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric")
+    ).to_arrow_batch_reader()
+    assert isinstance(arrow_batch_reader, pa.RecordBatchReader)
+    arrow_table = arrow_batch_reader.read_all()
+    assert len(arrow_table) == 1
+    assert arrow_table["idx"][0].as_py() == 1
+    assert math.isnan(arrow_table["col_numeric"][0].as_py())
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None:
+    table_test_null_nan_rewritten = 
catalog.load_table("default.test_null_nan_rewritten")
+    arrow_batch_reader = table_test_null_nan_rewritten.scan(
+        row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric")
+    ).to_arrow_batch_reader()
+    assert isinstance(arrow_batch_reader, pa.RecordBatchReader)
+    arrow_table = arrow_batch_reader.read_all()
+    assert len(arrow_table) == 1
+    assert arrow_table["idx"][0].as_py() == 1
+    assert math.isnan(arrow_table["col_numeric"][0].as_py())
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
[email protected](reason="Fixing issues with NaN's: 
https://github.com/apache/arrow/issues/34162")
+def test_pyarrow_batches_not_nan_count(catalog: Catalog) -> None:
+    table_test_null_nan = catalog.load_table("default.test_null_nan")
+    arrow_batch_reader = table_test_null_nan.scan(
+        row_filter=NotNaN("col_numeric"), selected_fields=("idx",)
+    ).to_arrow_batch_reader()
+    assert isinstance(arrow_batch_reader, pa.RecordBatchReader)
+    arrow_table = arrow_batch_reader.read_all()
+    assert len(arrow_table) == 2
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
 def test_duckdb_nan(catalog: Catalog) -> None:
@@ -354,6 +396,90 @@ def test_pyarrow_deletes_double(catalog: Catalog) -> None:
     assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10]
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_deletes(catalog: Catalog) -> None:
+    # number, letter
+    #  (1, 'a'),
+    #  (2, 'b'),
+    #  (3, 'c'),
+    #  (4, 'd'),
+    #  (5, 'e'),
+    #  (6, 'f'),
+    #  (7, 'g'),
+    #  (8, 'h'),
+    #  (9, 'i'), <- deleted
+    #  (10, 'j'),
+    #  (11, 'k'),
+    #  (12, 'l')
+    test_positional_mor_deletes = 
catalog.load_table("default.test_positional_mor_deletes")
+    arrow_table = 
test_positional_mor_deletes.scan().to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 6, 7, 8, 10, 
11, 12]
+
+    # Checking the filter
+    arrow_table = (
+        
test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", 
"e"), LessThan("letter", "k")))
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5, 6, 7, 8, 10]
+
+    # Testing the combination of a filter and a limit
+    arrow_table = (
+        
test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", 
"e"), LessThan("letter", "k")), limit=1)
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5]
+
+    # Testing the slicing of indices
+    arrow_table = 
test_positional_mor_deletes.scan(limit=3).to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_deletes_double(catalog: Catalog) -> None:
+    # number, letter
+    #  (1, 'a'),
+    #  (2, 'b'),
+    #  (3, 'c'),
+    #  (4, 'd'),
+    #  (5, 'e'),
+    #  (6, 'f'), <- second delete
+    #  (7, 'g'),
+    #  (8, 'h'),
+    #  (9, 'i'), <- first delete
+    #  (10, 'j'),
+    #  (11, 'k'),
+    #  (12, 'l')
+    test_positional_mor_double_deletes = 
catalog.load_table("default.test_positional_mor_double_deletes")
+    arrow_table = 
test_positional_mor_double_deletes.scan().to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10, 11, 
12]
+
+    # Checking the filter
+    arrow_table = (
+        
test_positional_mor_double_deletes.scan(row_filter=And(GreaterThanOrEqual("letter",
 "e"), LessThan("letter", "k")))
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5, 7, 8, 10]
+
+    # Testing the combination of a filter and a limit
+    arrow_table = (
+        test_positional_mor_double_deletes.scan(
+            row_filter=And(GreaterThanOrEqual("letter", "e"), 
LessThan("letter", "k")), limit=1
+        )
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5]
+
+    # Testing the slicing of indices
+    arrow_table = 
test_positional_mor_double_deletes.scan(limit=8).to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10]
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", 
[pytest.lazy_fixture("session_catalog_hive"), 
pytest.lazy_fixture("session_catalog")])
 def test_partitioned_tables(catalog: Catalog) -> None:

Reply via email to