This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new b8c5bb77 Support `Table.to_arrow_batch_reader` (#786)
b8c5bb77 is described below
commit b8c5bb77c5ea436aeced17676aa30d09c1224ed9
Author: Sung Yun <[email protected]>
AuthorDate: Fri Jun 21 09:44:24 2024 -0400
Support `Table.to_arrow_batch_reader` (#786)
* _task_to_table to _task_to_record_batches
* to_arrow_batches
* tests
* fix
* fix
* deletes
* batch reader
* merge main
* adopt review feedback
---
mkdocs/docs/api.md | 9 +++
pyiceberg/io/pyarrow.py | 155 ++++++++++++++++++++++++++++++----------
pyiceberg/table/__init__.py | 18 +++++
tests/integration/test_reads.py | 126 ++++++++++++++++++++++++++++++++
4 files changed, 269 insertions(+), 39 deletions(-)
diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 6bbd9abe..54f4a20c 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -1003,6 +1003,15 @@ tpep_dropoff_datetime: [[2021-04-01
00:47:59.000000,...,2021-05-01 00:14:47.0000
This will only pull in the files that that might contain matching rows.
+If reading one record batch at a time is preferred, one can also return a PyArrow RecordBatchReader:
+
+```python
+table.scan(
+ row_filter=GreaterThanOrEqual("trip_distance", 10.0),
+ selected_fields=("VendorID", "tpep_pickup_datetime",
"tpep_dropoff_datetime"),
+).to_arrow_batch_reader()
+```
+
### Pandas
<!-- prettier-ignore-start -->
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 935b78ce..e6490ae1 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -655,12 +655,12 @@ def _read_deletes(fs: FileSystem, data_file: DataFile) ->
Dict[str, pa.ChunkedAr
}
-def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: int) -> pa.Array:
+def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
+    """Return the non-deleted positions in ``[start_index, end_index)`` as offsets relative to ``start_index``.
+
+    ``positional_deletes`` hold row positions within the data file (the caller passes the running
+    file offset of the current batch as ``start_index``). Subtracting ``start_index`` turns the
+    surviving file positions into indices usable with ``take`` on that batch. Note that the value
+    actually returned is a NumPy array (result of ``np.subtract``), not a ``pa.Array``.
+    """
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
     else:
         all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
-    return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)
+    return np.subtract(np.setdiff1d(np.arange(start_index, end_index), all_chunks, assume_unique=False), start_index)
def pyarrow_to_schema(schema: pa.Schema, name_mapping: Optional[NameMapping] =
None) -> Schema:
@@ -995,7 +995,7 @@ class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
return -1
-def _task_to_table(
+def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
bound_row_filter: BooleanExpression,
@@ -1003,9 +1003,8 @@ def _task_to_table(
projected_field_ids: Set[int],
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
- limit: Optional[int] = None,
name_mapping: Optional[NameMapping] = None,
-) -> Optional[pa.Table]:
+) -> Iterator[pa.RecordBatch]:
_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True,
buffer_size=(ONE_MEGABYTE * 8))
with fs.open_input_file(path) as fin:
@@ -1035,36 +1034,39 @@ def _task_to_table(
columns=[col.name for col in file_project_schema.columns],
)
- if positional_deletes:
- # Create the mask of indices that we're interested in
- indices = _combine_positional_deletes(positional_deletes,
fragment.count_rows())
-
- if limit:
- if pyarrow_filter is not None:
- # In case of the filter, we don't exactly know how many
rows
- # we need to fetch upfront, can be optimized in the future:
- # https://github.com/apache/arrow/issues/35301
- arrow_table = fragment_scanner.take(indices)
- arrow_table = arrow_table.filter(pyarrow_filter)
- arrow_table = arrow_table.slice(0, limit)
- else:
- arrow_table = fragment_scanner.take(indices[0:limit])
- else:
- arrow_table = fragment_scanner.take(indices)
+        current_index = 0
+        batches = fragment_scanner.to_batches()
+        for batch in batches:
+            # Positional deletes are file-relative row positions, so the running offset must
+            # advance by the batch length *as read*, before deletes/filters drop any rows.
+            next_index = current_index + len(batch)
+            if positional_deletes:
+                # Create the mask of indices that we're interested in
+                indices = _combine_positional_deletes(positional_deletes, current_index, next_index)
+                batch = batch.take(indices)
+            # Updated before any `continue` below so a skipped batch cannot desync the offset.
+            current_index = next_index
             # Apply the user filter
             if pyarrow_filter is not None:
+                # we need to switch back and forth between RecordBatch and Table
+                # as Expression filter isn't yet supported in RecordBatch
+                # https://github.com/apache/arrow/issues/39220
+                arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
-            else:
-                # If there are no deletes, we can just take the head
-                # and the user-filter is already applied
-                if limit:
-                    arrow_table = fragment_scanner.head(limit)
-                else:
-                    arrow_table = fragment_scanner.to_table()
+                if len(arrow_table) == 0:
+                    # Nothing survived the filter: skip instead of indexing into a possibly chunk-less table.
+                    continue
+                batch = arrow_table.to_batches()[0]
+            yield to_requested_schema(projected_schema, file_project_schema, batch)
- if len(arrow_table) < 1:
- return None
- return to_requested_schema(projected_schema, file_project_schema,
arrow_table)
+
+def _task_to_table(
+    fs: FileSystem,
+    task: FileScanTask,
+    bound_row_filter: BooleanExpression,
+    projected_schema: Schema,
+    projected_field_ids: Set[int],
+    positional_deletes: Optional[List[ChunkedArray]],
+    case_sensitive: bool,
+    name_mapping: Optional[NameMapping] = None,
+) -> pa.Table:
+    """Read a single scan task eagerly into one table.
+
+    Thin wrapper that collects every batch produced by ``_task_to_record_batches``. The explicit
+    schema lets ``pa.Table.from_batches`` build an (empty) table even when the task yields no
+    batches, since it cannot infer a schema from an empty iterable.
+    """
+    batches = _task_to_record_batches(
+        fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
+    )
+    return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) ->
Dict[str, List[ChunkedArray]]:
@@ -1143,7 +1145,6 @@ def project_table(
projected_field_ids,
deletes_per_file.get(task.file.file_path),
case_sensitive,
- limit,
table_metadata.name_mapping(),
)
for task in tasks
@@ -1177,8 +1178,78 @@ def project_table(
return result
-def to_requested_schema(requested_schema: Schema, file_schema: Schema, table:
pa.Table) -> pa.Table:
- struct_array = visit_with_partner(requested_schema, table,
ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
+def project_batches(
+    tasks: Iterable[FileScanTask],
+    table_metadata: TableMetadata,
+    io: FileIO,
+    row_filter: BooleanExpression,
+    projected_schema: Schema,
+    case_sensitive: bool = True,
+    limit: Optional[int] = None,
+) -> Iterator[pa.RecordBatch]:
+    """Lazily read the scan tasks as a stream of record batches.
+
+    Args:
+        tasks (Iterable[FileScanTask]): The file scan tasks to read.
+        table_metadata (TableMetadata): The table metadata of the table that's being queried
+        io (FileIO): A FileIO to open streams to the object store
+        row_filter (BooleanExpression): The expression for filtering rows.
+        projected_schema (Schema): The output schema.
+        case_sensitive (bool): Case sensitivity when looking up column names.
+        limit (Optional[int]): Limit the number of records across all tasks.
+
+    Yields:
+        pa.RecordBatch: Batches projected to ``projected_schema``; at most ``limit`` rows in total.
+
+    Raises:
+        ResolveError: When an incompatible query is done.
+        ValueError: When ``io`` is neither a PyArrowFileIO nor an FsspecFileIO.
+    """
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
+    else:
+        try:
+            from pyiceberg.io.fsspec import FsspecFileIO
+
+            if isinstance(io, FsspecFileIO):
+                from pyarrow.fs import PyFileSystem
+
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
+            else:
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
+        except ModuleNotFoundError as e:
+            # When FsSpec is not installed
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
+
+    bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+
+    projected_field_ids = {
+        id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+    }.union(extract_field_ids(bound_row_filter))
+
+    # `tasks` is traversed twice (delete files, then data files); materialize it so a
+    # one-shot iterator is not exhausted by the first pass.
+    tasks = list(tasks)
+    deletes_per_file = _read_all_delete_files(fs, tasks)
+    name_mapping = table_metadata.name_mapping()
+
+    total_row_count = 0
+
+    for task in tasks:
+        if limit is not None and total_row_count >= limit:
+            # Limit already satisfied (e.g. limit=0): don't open any further data files.
+            return
+        batches = _task_to_record_batches(
+            fs,
+            task,
+            bound_row_filter,
+            projected_schema,
+            projected_field_ids,
+            deletes_per_file.get(task.file.file_path),
+            case_sensitive,
+            name_mapping,
+        )
+        for batch in batches:
+            if limit is not None and total_row_count + len(batch) >= limit:
+                yield batch.slice(0, limit - total_row_count)
+                # `return`, not `break`: ending the generator also stops the outer loop,
+                # so later tasks cannot yield rows beyond the limit.
+                return
+            yield batch
+            total_row_count += len(batch)
+
+
+def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch:
pa.RecordBatch) -> pa.RecordBatch:
+ struct_array = visit_with_partner(requested_schema, batch,
ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
arrays = []
fields = []
@@ -1186,7 +1257,7 @@ def to_requested_schema(requested_schema: Schema,
file_schema: Schema, table: pa
array = struct_array.field(pos)
arrays.append(array)
fields.append(pa.field(field.name, array.type, field.optional))
- return pa.Table.from_arrays(arrays, schema=pa.schema(fields))
+ return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array,
Optional[pa.Array]]):
@@ -1293,8 +1364,10 @@ class ArrowAccessor(PartnerAccessor[pa.Array]):
if isinstance(partner_struct, pa.StructArray):
return partner_struct.field(name)
- elif isinstance(partner_struct, pa.Table):
- return partner_struct.column(name).combine_chunks()
+ elif isinstance(partner_struct, pa.RecordBatch):
+ return partner_struct.column(name)
+ else:
+ raise ValueError(f"Cannot find {name} in expected
partner_struct type {type(partner_struct)}")
return None
@@ -1831,7 +1904,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata,
tasks: Iterator[WriteT
def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema
- arrow_table = pa.Table.from_batches(task.record_batches)
+
# if schema needs to be transformed, use the transformed schema and
adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) !=
table_schema:
@@ -1839,7 +1912,11 @@ def write_file(io: FileIO, table_metadata:
TableMetadata, tasks: Iterator[WriteT
else:
file_schema = table_schema
- arrow_table = to_requested_schema(requested_schema=file_schema,
file_schema=table_schema, table=arrow_table)
+ batches = [
+ to_requested_schema(requested_schema=file_schema,
file_schema=table_schema, batch=batch)
+ for batch in task.record_batches
+ ]
+ arrow_table = pa.Table.from_batches(batches)
file_path =
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 9a10fc6b..c78e005c 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1878,6 +1878,24 @@ class DataScan(TableScan):
limit=self.limit,
)
+    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+        """Return a PyArrow RecordBatchReader over the rows matched by this scan.
+
+        The batches come from the ``project_batches`` generator, so data files are opened as the
+        reader is consumed rather than all up front. The scan's row filter, projection, case
+        sensitivity and limit are passed through.
+
+        Returns:
+            pa.RecordBatchReader: A reader whose schema is ``schema_to_pyarrow`` of the scan's projection.
+        """
+        import pyarrow as pa
+
+        from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow
+
+        return pa.RecordBatchReader.from_batches(
+            schema_to_pyarrow(self.projection()),
+            project_batches(
+                self.plan_files(),
+                self.table_metadata,
+                self.io,
+                self.row_filter,
+                self.projection(),
+                case_sensitive=self.case_sensitive,
+                limit=self.limit,
+            ),
+        )
+
def to_pandas(self, **kwargs: Any) -> pd.DataFrame:
return self.to_arrow().to_pandas(**kwargs)
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 80a6f186..078abf40 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -21,6 +21,7 @@ import time
import uuid
from urllib.parse import urlparse
+import pyarrow as pa
import pyarrow.parquet as pq
import pytest
from hive_metastore.ttypes import LockRequest, LockResponse, LockState,
UnlockRequest
@@ -174,6 +175,47 @@ def test_pyarrow_not_nan_count(catalog: Catalog) -> None:
assert len(not_nan) == 2
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_nan(catalog: Catalog) -> None:
+    table_test_null_nan = catalog.load_table("default.test_null_nan")
+    arrow_batch_reader = table_test_null_nan.scan(
+        row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric")
+    ).to_arrow_batch_reader()
+    assert isinstance(arrow_batch_reader, pa.RecordBatchReader)
+    arrow_table = arrow_batch_reader.read_all()
+    assert len(arrow_table) == 1
+    assert arrow_table["idx"][0].as_py() == 1
+    assert math.isnan(arrow_table["col_numeric"][0].as_py())
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_nan_rewritten(catalog: Catalog) -> None:
+    table_test_null_nan_rewritten = catalog.load_table("default.test_null_nan_rewritten")
+    arrow_batch_reader = table_test_null_nan_rewritten.scan(
+        row_filter=IsNaN("col_numeric"), selected_fields=("idx", "col_numeric")
+    ).to_arrow_batch_reader()
+    assert isinstance(arrow_batch_reader, pa.RecordBatchReader)
+    arrow_table = arrow_batch_reader.read_all()
+    assert len(arrow_table) == 1
+    assert arrow_table["idx"][0].as_py() == 1
+    assert math.isnan(arrow_table["col_numeric"][0].as_py())
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+@pytest.mark.skip(reason="Fixing issues with NaN's: https://github.com/apache/arrow/issues/34162")
+def test_pyarrow_batches_not_nan_count(catalog: Catalog) -> None:
+    table_test_null_nan = catalog.load_table("default.test_null_nan")
+    arrow_batch_reader = table_test_null_nan.scan(
+        row_filter=NotNaN("col_numeric"), selected_fields=("idx",)
+    ).to_arrow_batch_reader()
+    assert isinstance(arrow_batch_reader, pa.RecordBatchReader)
+    arrow_table = arrow_batch_reader.read_all()
+    assert len(arrow_table) == 2
+
+
@pytest.mark.integration
@pytest.mark.parametrize("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
def test_duckdb_nan(catalog: Catalog) -> None:
@@ -354,6 +396,90 @@ def test_pyarrow_deletes_double(catalog: Catalog) -> None:
assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10]
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_deletes(catalog: Catalog) -> None:
+    # number, letter
+    #  (1, 'a'),
+    #  (2, 'b'),
+    #  (3, 'c'),
+    #  (4, 'd'),
+    #  (5, 'e'),
+    #  (6, 'f'),
+    #  (7, 'g'),
+    #  (8, 'h'),
+    #  (9, 'i'), <- deleted
+    #  (10, 'j'),
+    #  (11, 'k'),
+    #  (12, 'l')
+    test_positional_mor_deletes = catalog.load_table("default.test_positional_mor_deletes")
+    arrow_table = test_positional_mor_deletes.scan().to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]
+
+    # Checking the filter
+    arrow_table = (
+        test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")))
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5, 6, 7, 8, 10]
+
+    # Testing the combination of a filter and a limit
+    arrow_table = (
+        test_positional_mor_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1)
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5]
+
+    # Testing the slicing of indices
+    arrow_table = test_positional_mor_deletes.scan(limit=3).to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3]
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_batches_deletes_double(catalog: Catalog) -> None:
+    # number, letter
+    #  (1, 'a'),
+    #  (2, 'b'),
+    #  (3, 'c'),
+    #  (4, 'd'),
+    #  (5, 'e'),
+    #  (6, 'f'), <- second delete
+    #  (7, 'g'),
+    #  (8, 'h'),
+    #  (9, 'i'), <- first delete
+    #  (10, 'j'),
+    #  (11, 'k'),
+    #  (12, 'l')
+    test_positional_mor_double_deletes = catalog.load_table("default.test_positional_mor_double_deletes")
+    arrow_table = test_positional_mor_double_deletes.scan().to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10, 11, 12]
+
+    # Checking the filter
+    arrow_table = (
+        test_positional_mor_double_deletes.scan(row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")))
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5, 7, 8, 10]
+
+    # Testing the combination of a filter and a limit
+    arrow_table = (
+        test_positional_mor_double_deletes.scan(
+            row_filter=And(GreaterThanOrEqual("letter", "e"), LessThan("letter", "k")), limit=1
+        )
+        .to_arrow_batch_reader()
+        .read_all()
+    )
+    assert arrow_table["number"].to_pylist() == [5]
+
+    # Testing the slicing of indices
+    arrow_table = test_positional_mor_double_deletes.scan(limit=8).to_arrow_batch_reader().read_all()
+    assert arrow_table["number"].to_pylist() == [1, 2, 3, 4, 5, 7, 8, 10]
+
+
@pytest.mark.integration
@pytest.mark.parametrize("catalog",
[pytest.lazy_fixture("session_catalog_hive"),
pytest.lazy_fixture("session_catalog")])
def test_partitioned_tables(catalog: Catalog) -> None: