This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new e97ed57b Add Arrow C streaming, DataFrame iteration, and OOM-safe
streaming execution (#1222)
e97ed57b is described below
commit e97ed57b4ca2e28dc292649ab2cda6a5bfc811e0
Author: kosiew <[email protected]>
AuthorDate: Fri Nov 7 05:00:23 2025 +0800
Add Arrow C streaming, DataFrame iteration, and OOM-safe streaming
execution (#1222)
* feat: add streaming utilities, range support, and improve async handling
in DataFrame
- Add `range` method to SessionContext and iterator support to DataFrame
- Introduce `spawn_stream` utility and refactor async execution for
better signal handling
- Add tests for `KeyboardInterrupt` in `__arrow_c_stream__` and
incremental DataFrame streaming
- Improve memory usage tracking in tests with psutil
- Update DataFrame docs with PyArrow streaming section and enhance
`__arrow_c_stream__` documentation
- Replace Tokio runtime creation with `spawn_stream` in PySessionContext
- Bump datafusion packages to 49.0.1 and update dependencies
- Remove unused imports and restore main Cargo.toml
* refactor: improve DataFrame streaming, memory management, and error
handling
- Refactor record batch streaming to use `poll_next_batch` for clearer
error handling
- Improve `spawn_future`/`spawn_stream` functions for better Python
exception integration and code reuse
- Update `datafusion` and `datafusion-ffi` dependencies to 49.0.2
- Fix PyArrow `RecordBatchReader` import to use `_import_from_c_capsule`
for safer memory handling
- Refactor `ArrowArrayStream` handling to use `PyCapsule` with
destructor for improved memory management
- Refactor projection initialization in `PyDataFrame` for clarity
- Move `range` functionality into `_testing.py` helper
- Rename test column in `test_table_from_batches_stream` for accuracy
- Add tests for `RecordBatchReader` and enhance DataFrame stream
handling
* feat: enhance DataFrame streaming and improve robustness, tests, and docs
- Preserve partition order in DataFrame streaming and update related
tests
- Add tests for record batch ordering and DataFrame batch iteration
- Improve `drop_stream` to correctly handle PyArrow ownership transfer
and null pointers
- Replace `assert` with `debug_assert` for safer ArrowArrayStream
validation
- Add documentation for `poll_next_batch` in PyRecordBatchStream
- Refactor tests to use `fail_collect` fixture for DataFrame collect
- Refactor `range_table` return type to `DataFrame` for clearer type
hints
- Minor cleanup in SessionContext (remove extra blank line)
* feat: add testing utilities for DataFrame range generation
* feat: ensure proper resource management in DataFrame streaming
* refactor: replace spawn_stream and spawn_streams with spawn_future for
consistency
* feat: add test for Arrow C stream schema selection in DataFrame
* test: rename and extend test_arrow_c_stream_to_table to include
RecordBatchReader validation
* test: add validation for schema mismatch in Arrow C stream
* fix Ruff errors
* Update docs/source/user-guide/dataframe/index.rst
Co-authored-by: Kyle Barron <[email protected]>
* test: add batch iteration test for DataFrame
* refactor: simplify stream capsule creation in PyDataFrame
* refactor: enhance stream capsule management in PyDataFrame
* refactor: enhance DataFrame and RecordBatchStream iteration support
* refactor: improve docstrings for DataFrame and RecordBatchStream methods
* refactor: add to_record_batch_stream method and improve iteration support
in DataFrame
* test: update test_iter_batches_dataframe to assert RecordBatch type and
conversion
* fix: update table creation from batches to use to_pyarrow conversion
* test: add test_iter_returns_datafusion_recordbatch to verify RecordBatch
type
* docs: clarify RecordBatch reference and add PyArrow conversion example
* test: improve test_iter_batches_dataframe to validate RecordBatch
conversion
* test: enhance test_arrow_c_stream_to_table_and_reader for batch equality
validation
* Shelve unrelated changes
* Fix documentation to reference datafusion.RecordBatch instead of
pyarrow.RecordBatch
* Remove redundant to_record_batch_stream method from DataFrame class
* Refactor Arrow stream creation in PyDataFrame to use PyCapsule directly
* Add `once_cell` dependency and refactor Arrow array stream capsule name
handling
* Add `cstr` dependency and refactor Arrow array stream capsule name
handling
* Refactor test_iter_returns_datafusion_recordbatch to use RecordBatch
directly
* Add streaming execution examples to DataFrame documentation
* Rename `to_record_batch_stream` to `execute_stream` and update references
in the codebase; mark the old method as deprecated.
* Clean up formatting in Cargo.toml for improved readability
* Refactor Cargo.toml for improved formatting and readability
* Update python/tests/test_io.py
Co-authored-by: Kyle Barron <[email protected]>
* Update python/datafusion/dataframe.py
Co-authored-by: Kyle Barron <[email protected]>
* Refactor test_table_from_batches_stream to use pa.table for improved
clarity
* Remove deprecated to_record_batch_stream method; use execute_stream
instead
* Add example for concurrent processing of partitioned streams using asyncio
* Update documentation to reflect changes in execute_stream return type and
usage
* Update PyArrow streaming example to use pa.table for eager collection
* Enhance documentation for DataFrame streaming API, clarifying schema
handling and limitations
* Clarify behavior of __arrow_c_stream__ execution, emphasizing incremental
batch processing and memory efficiency
* Add note on limitations of `arrow::compute::cast` for schema
transformations
* Update python/tests/test_io.py
Co-authored-by: Kyle Barron <[email protected]>
* Rename test function for clarity: update `test_table_from_batches_stream`
to `test_table_from_arrow_c_stream`
* Update python/datafusion/dataframe.py
Co-authored-by: Kyle Barron <[email protected]>
* Add documentation note for Arrow C Data Interface PyCapsule in DataFrame
class
* Enhance documentation on zero-copy streaming to Arrow-based Python
libraries, clarifying the protocol and adding implementation-agnostic notes.
* Fix formatting of section header for zero-copy streaming in DataFrame
documentation
* Refine zero-copy streaming documentation by removing outdated information
about eager conversion, emphasizing on-demand batch processing to prevent
memory exhaustion.
* Add alternative method for creating RecordBatchReader from Arrow C stream
* Refactor tests to use RecordBatchReader.from_stream instead of deprecated
_import_from_c_capsule method
* Replace deprecated _import_from_c_capsule method with from_stream for
RecordBatchReader in test_arrow_c_stream_schema_selection
* Update test description for arrow_c_stream_large_dataset to clarify
streaming method and usage of public API
* Add comments to clarify RSS measurement in
test_arrow_c_stream_large_dataset
* Fix ruff errors
* Update async iterator implementation in DataFrame to ensure compatibility
with Python < 3.10
* Fix async iterator implementation in DataFrame for compatibility with
Python < 3.10
* fix typo
* Fix formatting in DataFrame documentation and add example usage for Arrow
integration
* fix: correct formatting in documentation for RecordBatchStream
* refactor: remove unused import from errors module in dataframe.rs
* Simplified the streaming protocol description by removing the clause
about arbitrarily large results while keeping the paragraph smooth.
* Updated the Arrow streaming documentation to describe incremental
execution, remove the note block, and highlight lazy batch retrieval when using
__arrow_c_stream__
* Replaced the DataFrame.__arrow_c_stream__ docstring example with a link
to the Apache Arrow streaming documentation for practical guidance.
* fix: update user guide links in DataFrame class documentation for clarity
* minor ruff change
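
    A minimal usage sketch of the streaming surface described above, assuming a
    default SessionContext and the SQL range() table function:

        import pyarrow as pa
        from datafusion import SessionContext

        ctx = SessionContext()
        df = ctx.sql("SELECT * FROM range(0, 1000000, 1)")

        # Iterating a DataFrame yields datafusion.RecordBatch objects lazily
        for batch in df:
            _ = batch.to_pyarrow()  # convert each batch on demand

        # Zero-copy hand-off to PyArrow via __arrow_c_stream__
        reader = pa.RecordBatchReader.from_stream(df)
        table = reader.read_all()

        # Explicit stream handles, one per partition when order matters
        stream = df.execute_stream()
        partition_streams = df.execute_stream_partitioned()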
---------
Co-authored-by: Kyle Barron <[email protected]>
Co-authored-by: Tim Saucer <[email protected]>
---
Cargo.lock | 11 ++
Cargo.toml | 41 ++++-
docs/source/user-guide/dataframe/index.rst | 110 +++++++++++-
docs/source/user-guide/io/arrow.rst | 10 +-
python/datafusion/dataframe.py | 56 +++++--
python/datafusion/record_batch.py | 28 +++-
python/tests/conftest.py | 11 +-
python/tests/test_dataframe.py | 259 ++++++++++++++++++++++++++++-
python/tests/test_io.py | 44 +++++
python/tests/utils.py | 62 +++++++
src/context.rs | 12 +-
src/dataframe.rs | 128 ++++++++++----
src/record_batch.rs | 15 +-
src/utils.rs | 33 +++-
14 files changed, 743 insertions(+), 77 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index c7257594..2e345e71 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -833,6 +833,16 @@ dependencies = [
"typenum",
]
+[[package]]
+name = "cstr"
+version = "0.2.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "68523903c8ae5aacfa32a0d9ae60cadeb764e1da14ee0d26b1f3089f13a54636"
+dependencies = [
+ "proc-macro2",
+ "quote",
+]
+
[[package]]
name = "csv"
version = "1.3.1"
@@ -1587,6 +1597,7 @@ version = "50.1.0"
dependencies = [
"arrow",
"async-trait",
+ "cstr",
"datafusion",
"datafusion-ffi",
"datafusion-proto",
diff --git a/Cargo.toml b/Cargo.toml
index 92d531d9..3b7a4caa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,17 +26,34 @@ readme = "README.md"
license = "Apache-2.0"
edition = "2021"
rust-version = "1.78"
-include = ["/src", "/datafusion", "/LICENSE.txt", "build.rs",
"pyproject.toml", "Cargo.toml", "Cargo.lock"]
+include = [
+ "/src",
+ "/datafusion",
+ "/LICENSE.txt",
+ "build.rs",
+ "pyproject.toml",
+ "Cargo.toml",
+ "Cargo.lock",
+]
[features]
default = ["mimalloc"]
-protoc = [ "datafusion-substrait/protoc" ]
+protoc = ["datafusion-substrait/protoc"]
substrait = ["dep:datafusion-substrait"]
[dependencies]
-tokio = { version = "1.47", features = ["macros", "rt", "rt-multi-thread", "sync"] }
-pyo3 = { version = "0.25", features = ["extension-module", "abi3", "abi3-py310"] }
-pyo3-async-runtimes = { version = "0.25", features = ["tokio-runtime"]}
+tokio = { version = "1.47", features = [
+ "macros",
+ "rt",
+ "rt-multi-thread",
+ "sync",
+] }
+pyo3 = { version = "0.25", features = [
+ "extension-module",
+ "abi3",
+ "abi3-py310",
+] }
+pyo3-async-runtimes = { version = "0.25", features = ["tokio-runtime"] }
pyo3-log = "0.12.4"
arrow = { version = "56", features = ["pyarrow"] }
datafusion = { version = "50", features = ["avro", "unicode_expressions"] }
@@ -45,16 +62,24 @@ datafusion-proto = { version = "50" }
datafusion-ffi = { version = "50" }
prost = "0.13.1" # keep in line with `datafusion-substrait`
uuid = { version = "1.18", features = ["v4"] }
-mimalloc = { version = "0.1", optional = true, default-features = false, features = ["local_dynamic_tls"] }
+mimalloc = { version = "0.1", optional = true, default-features = false, features = [
+ "local_dynamic_tls",
+] }
async-trait = "0.1.89"
futures = "0.3"
-object_store = { version = "0.12.4", features = ["aws", "gcp", "azure", "http"] }
+cstr = "0.2"
+object_store = { version = "0.12.4", features = [
+ "aws",
+ "gcp",
+ "azure",
+ "http",
+] }
url = "2"
log = "0.4.27"
parking_lot = "0.12"
[build-dependencies]
-prost-types = "0.13.1" # keep in line with `datafusion-substrait`
+prost-types = "0.13.1" # keep in line with `datafusion-substrait`
pyo3-build-config = "0.25"
[lib]
diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst
index 6d82f707..659589cf 100644
--- a/docs/source/user-guide/dataframe/index.rst
+++ b/docs/source/user-guide/dataframe/index.rst
@@ -196,10 +196,118 @@ To materialize the results of your DataFrame operations:
# Display results
df.show() # Print tabular format to console
-
+
# Count rows
count = df.count()
+Zero-copy streaming to Arrow-based Python libraries
+---------------------------------------------------
+
+DataFusion DataFrames implement the ``__arrow_c_stream__`` protocol, enabling
+zero-copy, lazy streaming into Arrow-based Python libraries. With the streaming
+protocol, batches are produced on demand.
+
+.. note::
+
+ The protocol is implementation-agnostic and works with any Python library
+ that understands the Arrow C streaming interface (for example, PyArrow
+ or other Arrow-compatible implementations). The sections below provide a
+ short PyArrow-specific example and general guidance for other
+ implementations.
+
+PyArrow
+-------
+
+.. code-block:: python
+
+ import pyarrow as pa
+
+ # Create a PyArrow RecordBatchReader without materializing all batches
+ reader = pa.RecordBatchReader.from_stream(df)
+ for batch in reader:
+ ... # process each batch as it is produced
+
+DataFrames are also iterable, yielding :class:`datafusion.RecordBatch`
+objects lazily so you can loop over results directly without importing
+PyArrow:
+
+.. code-block:: python
+
+ for batch in df:
+ ... # each batch is a ``datafusion.RecordBatch``
+
+Each batch exposes ``to_pyarrow()``, allowing conversion to a PyArrow
+table. ``pa.table(df)`` collects the entire DataFrame eagerly into a
+PyArrow table:
+
+.. code-block:: python
+
+ import pyarrow as pa
+ table = pa.table(df)
+
+Asynchronous iteration is supported as well, allowing integration with
+``asyncio`` event loops:
+
+.. code-block:: python
+
+ async for batch in df:
+ ... # process each batch as it is produced
+
+To work with the stream directly, use ``execute_stream()``, which returns a
+:class:`~datafusion.RecordBatchStream`.
+
+.. code-block:: python
+
+ stream = df.execute_stream()
+ for batch in stream:
+ ...
+
+Execute as Stream
+^^^^^^^^^^^^^^^^^
+
+For finer control over streaming execution, use
+:py:meth:`~datafusion.DataFrame.execute_stream` to obtain a
+:py:class:`datafusion.RecordBatchStream`:
+
+.. code-block:: python
+
+ stream = df.execute_stream()
+ for batch in stream:
+ ... # process each batch as it is produced
+
+.. tip::
+
+ To get a PyArrow reader instead, call ``pa.RecordBatchReader.from_stream(df)``.
+
+When partition boundaries are important,
+:py:meth:`~datafusion.DataFrame.execute_stream_partitioned`
+returns an iterable of :py:class:`datafusion.RecordBatchStream` objects, one per
+partition:
+
+.. code-block:: python
+
+ for stream in df.execute_stream_partitioned():
+ for batch in stream:
+ ... # each stream yields RecordBatches
+
+To process partitions concurrently, first collect the streams into a list
+and then poll each one in a separate ``asyncio`` task:
+
+.. code-block:: python
+
+ import asyncio
+
+ async def consume(stream):
+ async for batch in stream:
+ ...
+
+ streams = list(df.execute_stream_partitioned())
+ await asyncio.gather(*(consume(s) for s in streams))
+
+See :doc:`../io/arrow` for additional details on the Arrow interface.
+
HTML Rendering
--------------
diff --git a/docs/source/user-guide/io/arrow.rst b/docs/source/user-guide/io/arrow.rst
index d571aa99..9196fcea 100644
--- a/docs/source/user-guide/io/arrow.rst
+++ b/docs/source/user-guide/io/arrow.rst
@@ -60,14 +60,16 @@ Exporting from DataFusion
DataFusion DataFrames implement ``__arrow_c_stream__`` PyCapsule interface, so any
Python library that accepts these can import a DataFusion DataFrame directly.
-.. warning::
- It is important to note that this will cause the DataFrame execution to happen, which may be
- a time consuming task. That is, you will cause a
- :py:func:`datafusion.dataframe.DataFrame.collect` operation call to occur.
+Invoking ``__arrow_c_stream__`` triggers execution of the underlying query, but
+batches are yielded incrementally rather than materialized all at once in memory.
+Consumers can process the stream as it arrives. The stream executes lazily,
+letting downstream readers pull batches on demand.
.. ipython:: python
+ from datafusion import col, lit
+
df = df.select((col("a") * lit(1.5)).alias("c"), lit("df").alias("d"))
pa.table(df)
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index 8d692aca..c6ff7eda 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -22,7 +22,7 @@ See :ref:`user_guide_concepts` in the online documentation for more information.
from __future__ import annotations
import warnings
-from collections.abc import Iterable, Sequence
+from collections.abc import AsyncIterator, Iterable, Iterator, Sequence
from typing import (
TYPE_CHECKING,
Any,
@@ -50,7 +50,7 @@ from datafusion.expr import (
sort_list_to_raw_sort_list,
)
from datafusion.plan import ExecutionPlan, LogicalPlan
-from datafusion.record_batch import RecordBatchStream
+from datafusion.record_batch import RecordBatch, RecordBatchStream
if TYPE_CHECKING:
import pathlib
@@ -304,6 +304,9 @@ class ParquetColumnOptions:
class DataFrame:
"""Two dimensional table representation of data.
+ DataFrame objects are iterable; iterating over a DataFrame yields
+ :class:`datafusion.RecordBatch` instances lazily.
+
See :ref:`user_guide_concepts` in the online documentation for more
information.
"""
@@ -332,7 +335,7 @@ class DataFrame:
return _Table(self.df.into_view(temporary))
def __getitem__(self, key: str | list[str]) -> DataFrame:
- """Return a new :py:class`DataFrame` with the specified column or
columns.
+ """Return a new :py:class:`DataFrame` with the specified column or
columns.
Args:
key: Column name or list of column names to select.
@@ -1291,21 +1294,54 @@ class DataFrame:
return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
- """Export an Arrow PyCapsule Stream.
+ """Export the DataFrame as an Arrow C Stream.
+
+ The DataFrame is executed using DataFusion's streaming APIs and exposed via
+ Arrow's C Stream interface. Record batches are produced incrementally, so the
+ full result set is never materialized in memory.
- This will execute and collect the DataFrame. We will attempt to respect the
- requested schema, but only trivial transformations will be applied such as only
- returning the fields listed in the requested schema if their data types match
- those in the DataFrame.
+ When ``requested_schema`` is provided, DataFusion applies only simple
+ projections such as selecting a subset of existing columns or reordering
+ them. Column renaming, computed expressions, or type coercion are not
+ supported through this interface.
Args:
- requested_schema: Attempt to provide the DataFrame using this schema.
+ requested_schema: Either a :py:class:`pyarrow.Schema` or an Arrow C
+ Schema capsule (``PyCapsule``) produced by
+ ``schema._export_to_c_capsule()``. The DataFrame will attempt to
+ align its output with the fields and order specified by this schema.
Returns:
- Arrow PyCapsule object.
+ Arrow ``PyCapsule`` object representing an ``ArrowArrayStream``.
+
+ For practical usage patterns, see the Apache Arrow streaming
+ documentation: https://arrow.apache.org/docs/python/ipc.html#streaming.
+
+ For details on DataFusion's Arrow integration and DataFrame streaming,
+ see the user guide (user-guide/io/arrow and user-guide/dataframe/index).
+
+ Notes:
+ The Arrow C Data Interface PyCapsule details are documented by Apache
+ Arrow and can be found at:
+ https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
"""
+ # ``DataFrame.__arrow_c_stream__`` in the Rust extension leverages
+ # ``execute_stream_partitioned`` under the hood to stream batches while
+ # preserving the original partition order.
return self.df.__arrow_c_stream__(requested_schema)
+ def __iter__(self) -> Iterator[RecordBatch]:
+ """Return an iterator over this DataFrame's record batches."""
+ return iter(self.execute_stream())
+
+ def __aiter__(self) -> AsyncIterator[RecordBatch]:
+ """Return an async iterator over this DataFrame's record batches.
+
+ We're using __aiter__ because we support Python < 3.10 where aiter() is not
+ available.
+ """
+ return self.execute_stream().__aiter__()
+
def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
"""Apply a function to the current DataFrame which returns another DataFrame.
diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py
index 556eaa78..c24cde0a 100644
--- a/python/datafusion/record_batch.py
+++ b/python/datafusion/record_batch.py
@@ -46,6 +46,26 @@ class RecordBatch:
"""Convert to :py:class:`pa.RecordBatch`."""
return self.record_batch.to_pyarrow()
+ def __arrow_c_array__(
+ self, requested_schema: object | None = None
+ ) -> tuple[object, object]:
+ """Export the record batch via the Arrow C Data Interface.
+
+ This allows zero-copy interchange with libraries that support the
+ `Arrow PyCapsule interface <https://arrow.apache.org/docs/format/
+ CDataInterface/PyCapsuleInterface.html>`_.
+
+ Args:
+ requested_schema: Attempt to provide the record batch using this
+ schema. Only straightforward projections such as column
+ selection or reordering are applied.
+
+ Returns:
+ Two Arrow PyCapsule objects representing the ``ArrowArray`` and
+ ``ArrowSchema``.
+ """
+ return self.record_batch.__arrow_c_array__(requested_schema)
+
class RecordBatchStream:
"""This class represents a stream of record batches.
@@ -63,19 +83,19 @@ class RecordBatchStream:
return next(self)
async def __anext__(self) -> RecordBatch:
- """Async iterator function."""
+ """Return the next :py:class:`RecordBatch` in the stream
asynchronously."""
next_batch = await self.rbs.__anext__()
return RecordBatch(next_batch)
def __next__(self) -> RecordBatch:
- """Iterator function."""
+ """Return the next :py:class:`RecordBatch` in the stream."""
next_batch = next(self.rbs)
return RecordBatch(next_batch)
def __aiter__(self) -> typing_extensions.Self:
- """Async iterator function."""
+ """Return an asynchronous iterator over record batches."""
return self
def __iter__(self) -> typing_extensions.Self:
- """Iterator function."""
+ """Return an iterator over record batches."""
return self
diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index 9548fbfe..26ed7281 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -17,7 +17,7 @@
import pyarrow as pa
import pytest
-from datafusion import SessionContext
+from datafusion import DataFrame, SessionContext
from pyarrow.csv import write_csv
@@ -49,3 +49,12 @@ def database(ctx, tmp_path):
delimiter=",",
schema_infer_max_records=10,
)
+
+
[email protected]
+def fail_collect(monkeypatch):
+ def _fail_collect(self, *args, **kwargs): # pragma: no cover - failure path
+ msg = "collect should not be called"
+ raise AssertionError(msg)
+
+ monkeypatch.setattr(DataFrame, "collect", _fail_collect)
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index aed477af..101dfc5b 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -32,6 +32,7 @@ from datafusion import (
InsertOp,
ParquetColumnOptions,
ParquetWriterOptions,
+ RecordBatch,
SessionContext,
WindowFrame,
column,
@@ -53,6 +54,8 @@ from datafusion.dataframe_formatter import (
from datafusion.expr import EXPR_TYPE_ERROR, Window
from pyarrow.csv import write_csv
+pa_cffi = pytest.importorskip("pyarrow.cffi")
+
MB = 1024 * 1024
@@ -579,6 +582,41 @@ def test_cast(df):
assert df.schema() == expected
+def test_iter_batches(df):
+ batches = []
+ for batch in df:
+ batches.append(batch) # noqa: PERF402
+
+ # Delete DataFrame to ensure RecordBatches remain valid
+ del df
+
+ assert len(batches) == 1
+
+ batch = batches[0]
+ assert isinstance(batch, RecordBatch)
+ pa_batch = batch.to_pyarrow()
+ assert pa_batch.column(0).to_pylist() == [1, 2, 3]
+ assert pa_batch.column(1).to_pylist() == [4, 5, 6]
+ assert pa_batch.column(2).to_pylist() == [8, 5, 8]
+
+
+def test_iter_returns_datafusion_recordbatch(df):
+ for batch in df:
+ assert isinstance(batch, RecordBatch)
+
+
+def test_execute_stream_basic(df):
+ stream = df.execute_stream()
+ batches = list(stream)
+
+ assert len(batches) == 1
+ assert isinstance(batches[0], RecordBatch)
+ pa_batch = batches[0].to_pyarrow()
+ assert pa_batch.column(0).to_pylist() == [1, 2, 3]
+ assert pa_batch.column(1).to_pylist() == [4, 5, 6]
+ assert pa_batch.column(2).to_pylist() == [8, 5, 8]
+
+
def test_with_column_renamed(df):
df = df.with_column("c", column("a") +
column("b")).with_column_renamed("c", "sum")
@@ -1609,7 +1647,7 @@ def test_execution_plan(aggregate_df):
@pytest.mark.asyncio
async def test_async_iteration_of_df(aggregate_df):
rows_returned = 0
- async for batch in aggregate_df.execute_stream():
+ async for batch in aggregate_df:
assert batch is not None
rows_returned += len(batch.to_pyarrow()[0])
@@ -1887,6 +1925,121 @@ def test_empty_to_arrow_table(df):
assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+def test_iter_batches_dataframe(fail_collect):
+ ctx = SessionContext()
+
+ batch1 = pa.record_batch([pa.array([1])], names=["a"])
+ batch2 = pa.record_batch([pa.array([2])], names=["a"])
+ df = ctx.create_dataframe([[batch1], [batch2]])
+
+ expected = [batch1, batch2]
+ results = [b.to_pyarrow() for b in df]
+
+ assert len(results) == len(expected)
+ for exp in expected:
+ assert any(got.equals(exp) for got in results)
+
+
+def test_arrow_c_stream_to_table_and_reader(fail_collect):
+ ctx = SessionContext()
+
+ # Create a DataFrame with two separate record batches
+ batch1 = pa.record_batch([pa.array([1])], names=["a"])
+ batch2 = pa.record_batch([pa.array([2])], names=["a"])
+ df = ctx.create_dataframe([[batch1], [batch2]])
+
+ table = pa.Table.from_batches(batch.to_pyarrow() for batch in df)
+ batches = table.to_batches()
+
+ assert len(batches) == 2
+ expected = [batch1, batch2]
+ for exp in expected:
+ assert any(got.equals(exp) for got in batches)
+ assert table.schema == df.schema()
+ assert table.column("a").num_chunks == 2
+
+ reader = pa.RecordBatchReader.from_stream(df)
+ assert isinstance(reader, pa.RecordBatchReader)
+ reader_table = pa.Table.from_batches(reader)
+ expected = pa.Table.from_batches([batch1, batch2])
+ assert reader_table.equals(expected)
+
+
+def test_arrow_c_stream_order():
+ ctx = SessionContext()
+
+ batch1 = pa.record_batch([pa.array([1])], names=["a"])
+ batch2 = pa.record_batch([pa.array([2])], names=["a"])
+
+ df = ctx.create_dataframe([[batch1, batch2]])
+
+ table = pa.Table.from_batches(batch.to_pyarrow() for batch in df)
+ expected = pa.Table.from_batches([batch1, batch2])
+
+ assert table.equals(expected)
+ col = table.column("a")
+ assert col.chunk(0)[0].as_py() == 1
+ assert col.chunk(1)[0].as_py() == 2
+
+
+def test_arrow_c_stream_schema_selection(fail_collect):
+ ctx = SessionContext()
+
+ batch = pa.RecordBatch.from_arrays(
+ [
+ pa.array([1, 2]),
+ pa.array([3, 4]),
+ pa.array([5, 6]),
+ ],
+ names=["a", "b", "c"],
+ )
+ df = ctx.create_dataframe([[batch]])
+
+ requested_schema = pa.schema([("c", pa.int64()), ("a", pa.int64())])
+
+ c_schema = pa_cffi.ffi.new("struct ArrowSchema*")
+ address = int(pa_cffi.ffi.cast("uintptr_t", c_schema))
+ requested_schema._export_to_c(address)
+ capsule_new = ctypes.pythonapi.PyCapsule_New
+ capsule_new.restype = ctypes.py_object
+ capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
+
+ reader = pa.RecordBatchReader.from_stream(df, schema=requested_schema)
+
+ assert reader.schema == requested_schema
+
+ batches = list(reader)
+
+ assert len(batches) == 1
+ expected_batch = pa.record_batch(
+ [pa.array([5, 6]), pa.array([1, 2])], names=["c", "a"]
+ )
+ assert batches[0].equals(expected_batch)
+
+
+def test_arrow_c_stream_schema_mismatch(fail_collect):
+ ctx = SessionContext()
+
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2]), pa.array([3, 4])], names=["a", "b"]
+ )
+ df = ctx.create_dataframe([[batch]])
+
+ bad_schema = pa.schema([("a", pa.string())])
+
+ c_schema = pa_cffi.ffi.new("struct ArrowSchema*")
+ address = int(pa_cffi.ffi.cast("uintptr_t", c_schema))
+ bad_schema._export_to_c(address)
+
+ capsule_new = ctypes.pythonapi.PyCapsule_New
+ capsule_new.restype = ctypes.py_object
+ capsule_new.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
+ bad_capsule = capsule_new(ctypes.c_void_p(address), b"arrow_schema", None)
+
+ with pytest.raises(Exception, match="Fail to merge schema"):
+ df.__arrow_c_stream__(bad_capsule)
+
+
def test_to_pylist(df):
# Convert datafusion dataframe to Python list
pylist = df.to_pylist()
@@ -3053,6 +3206,110 @@ def test_collect_interrupted():
interrupt_thread.join(timeout=1.0)
+def test_arrow_c_stream_interrupted():
+ """__arrow_c_stream__ responds to ``KeyboardInterrupt`` signals.
+
+ Similar to ``test_collect_interrupted`` this test issues a long running
+ query, but consumes the results via ``__arrow_c_stream__``. It then raises
+ ``KeyboardInterrupt`` in the main thread and verifies that the stream
+ iteration stops promptly with the appropriate exception.
+ """
+
+ ctx = SessionContext()
+
+ batches = []
+ for i in range(10):
+ batch = pa.RecordBatch.from_arrays(
+ [
+ pa.array(list(range(i * 1000, (i + 1) * 1000))),
+ pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) *
1000)]),
+ ],
+ names=["a", "b"],
+ )
+ batches.append(batch)
+
+ ctx.register_record_batches("t1", [batches])
+ ctx.register_record_batches("t2", [batches])
+
+ df = ctx.sql(
+ """
+ WITH t1_expanded AS (
+ SELECT
+ a,
+ b,
+ CAST(a AS DOUBLE) / 1.5 AS c,
+ CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+ FROM t1
+ CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+ ),
+ t2_expanded AS (
+ SELECT
+ a,
+ b,
+ CAST(a AS DOUBLE) * 2.5 AS e,
+ CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+ FROM t2
+ CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+ )
+ SELECT
+ t1.a, t1.b, t1.c, t1.d,
+ t2.a AS a2, t2.b AS b2, t2.e, t2.f
+ FROM t1_expanded t1
+ JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+ WHERE t1.a > 100 AND t2.a > 100
+ """
+ )
+
+ reader = pa.RecordBatchReader.from_stream(df)
+
+ interrupted = False
+ interrupt_error = None
+ query_started = threading.Event()
+ max_wait_time = 5.0
+
+ def trigger_interrupt():
+ start_time = time.time()
+ while not query_started.is_set():
+ time.sleep(0.1)
+ if time.time() - start_time > max_wait_time:
+ msg = f"Query did not start within {max_wait_time} seconds"
+ raise RuntimeError(msg)
+
+ thread_id = threading.main_thread().ident
+ if thread_id is None:
+ msg = "Cannot get main thread ID"
+ raise RuntimeError(msg)
+
+ exception = ctypes.py_object(KeyboardInterrupt)
+ res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+ ctypes.c_long(thread_id), exception
+ )
+ if res != 1:
+ ctypes.pythonapi.PyThreadState_SetAsyncExc(
+ ctypes.c_long(thread_id), ctypes.py_object(0)
+ )
+ msg = "Failed to raise KeyboardInterrupt in main thread"
+ raise RuntimeError(msg)
+
+ interrupt_thread = threading.Thread(target=trigger_interrupt)
+ interrupt_thread.daemon = True
+ interrupt_thread.start()
+
+ try:
+ query_started.set()
+ # consume the reader which should block and be interrupted
+ reader.read_all()
+ except KeyboardInterrupt:
+ interrupted = True
+ except Exception as e: # pragma: no cover - unexpected errors
+ interrupt_error = e
+
+ if not interrupted:
+ pytest.fail(f"Stream was not interrupted; got error:
{interrupt_error}")
+
+ interrupt_thread.join(timeout=1.0)
+
+
def test_show_select_where_no_rows(capsys) -> None:
ctx = SessionContext()
df = ctx.sql("SELECT 1 WHERE 1=0")
diff --git a/python/tests/test_io.py b/python/tests/test_io.py
index 7ca50968..9f56f74d 100644
--- a/python/tests/test_io.py
+++ b/python/tests/test_io.py
@@ -14,12 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from pathlib import Path
import pyarrow as pa
+import pytest
from datafusion import column
from datafusion.io import read_avro, read_csv, read_json, read_parquet
+from .utils import range_table
+
def test_read_json_global_ctx(ctx):
path = Path(__file__).parent.resolve()
@@ -92,3 +96,43 @@ def test_read_avro():
path = Path.cwd() / "testing/data/avro/alltypes_plain.avro"
avro_df = read_avro(path=path)
assert avro_df is not None
+
+
+def test_arrow_c_stream_large_dataset(ctx):
+ """DataFrame streaming yields batches incrementally using Arrow APIs.
+
+ This test constructs a DataFrame that would be far larger than available
+ memory if materialized. Use the public API
+ ``pa.RecordBatchReader.from_stream(df)`` (which is the same as
+ ``pa.RecordBatchReader._import_from_c_capsule(df.__arrow_c_stream__())``)
+ to read record batches incrementally without collecting the full dataset,
+ so reading a handful of batches should not exhaust process memory.
+ """
+ # Create a very large DataFrame using range; this would be terabytes if collected
+ df = range_table(ctx, 0, 1 << 40)
+
+ reader = pa.RecordBatchReader.from_stream(df)
+
+ # Track RSS before consuming batches
+ # RSS is a practical measure of RAM usage visible to the OS. It excludes memory
+ # that has been swapped out and provides a simple cross-platform-ish indicator
+ # (psutil normalizes per-OS sources).
+ psutil = pytest.importorskip("psutil")
+ process = psutil.Process()
+ start_rss = process.memory_info().rss
+
+ for _ in range(5):
+ batch = reader.read_next_batch()
+ assert batch is not None
+ assert len(batch) > 0
+ current_rss = process.memory_info().rss
+ # Ensure memory usage hasn't grown substantially (>50MB)
+ assert current_rss - start_rss < 50 * 1024 * 1024
+
+
+def test_table_from_arrow_c_stream(ctx, fail_collect):
+ df = range_table(ctx, 0, 10)
+
+ table = pa.table(df)
+ assert table.shape == (10, 1)
+ assert table.column_names == ["value"]
diff --git a/python/tests/utils.py b/python/tests/utils.py
new file mode 100644
index 00000000..00efb655
--- /dev/null
+++ b/python/tests/utils.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Testing-only helpers for datafusion-python.
+
+This module contains utilities used by the test-suite that should not be
+exposed as part of the public API. Keep the implementation minimal and
+documented so reviewers can easily see it's test-only.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from datafusion import DataFrame
+ from datafusion.context import SessionContext
+
+
+def range_table(
+ ctx: SessionContext,
+ start: int,
+ stop: int | None = None,
+ step: int = 1,
+ partitions: int | None = None,
+) -> DataFrame:
+ """Create a DataFrame containing a sequence of numbers using SQL RANGE.
+
+ This mirrors the previous ``SessionContext.range`` convenience method but
+ lives in a testing-only module so it doesn't expand the public surface.
+
+ Args:
+ ctx: SessionContext instance to run the SQL against.
+ start: Starting value for the sequence or exclusive stop when ``stop``
+ is ``None``.
+ stop: Exclusive upper bound of the sequence.
+ step: Increment between successive values.
+ partitions: Optional number of partitions for the generated data.
+
+ Returns:
+ DataFrame produced by the range table function.
+ """
+ if stop is None:
+ start, stop = 0, start
+
+ parts = f", {int(partitions)}" if partitions is not None else ""
+ sql = f"SELECT * FROM range({int(start)}, {int(stop)}, {int(step)}{parts})"
+ return ctx.sql(sql)
diff --git a/src/context.rs b/src/context.rs
index dc18a767..e8d87580 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
@@ -46,7 +46,7 @@ use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
-use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
+use crate::utils::{get_global_ctx, spawn_future, validate_pycapsule, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
@@ -67,14 +67,12 @@ use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
-use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions,
ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
-use tokio::task::JoinHandle;
/// Configuration options for a SessionContext
#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
@@ -1107,12 +1105,8 @@ impl PySessionContext {
py: Python,
) -> PyDataFusionResult<PyRecordBatchStream> {
let ctx: TaskContext = TaskContext::from(&self.ctx.state());
- // create a Tokio runtime to run the async code
- let rt = &get_tokio_runtime().0;
let plan = plan.plan.clone();
- let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
- rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
- let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+ let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
Ok(PyRecordBatchStream::new(stream))
}
}
diff --git a/src/dataframe.rs b/src/dataframe.rs
index df3e9d31..a93aa018 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -15,17 +15,18 @@
// specific language governing permissions and limitations
// under the License.
+use cstr::cstr;
use std::collections::HashMap;
-use std::ffi::CString;
+use std::ffi::{CStr, CString};
use std::sync::Arc;
-use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
+use arrow::array::{new_null_array, RecordBatch, RecordBatchReader};
use arrow::compute::can_cast_types;
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::pyarrow::FromPyArrow;
-use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::datatypes::{Schema, SchemaRef};
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::catalog::TableProvider;
@@ -43,16 +44,16 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
-use tokio::task::JoinHandle;
+use pyo3::PyErr;
-use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
+use crate::errors::{py_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
-use crate::record_batch::PyRecordBatchStream;
+use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
use crate::sql::logical::PyLogicalPlan;
use crate::table::{PyTable, TempViewTable};
use crate::utils::{
- get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
+ is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
};
use crate::{
errors::PyDataFusionResult,
@@ -61,6 +62,9 @@ use crate::{
use parking_lot::Mutex;
+/// File-level static CStr for the Arrow array stream capsule name.
+static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
+
// Type aliases to simplify very complex types used in this file and
// avoid compiler complaints about deeply nested types in struct fields.
type CachedBatches = Option<(Vec<RecordBatch>, bool)>;
@@ -341,6 +345,63 @@ impl PyDataFrame {
}
}
+/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
+/// for the `__arrow_c_stream__` implementation.
+///
+/// It drains each partition's stream sequentially, yielding record batches in
+/// their original partition order. When a `projection` is set, each batch is
+/// converted via `record_batch_into_schema` to apply schema changes per batch.
+struct PartitionedDataFrameStreamReader {
+ streams: Vec<SendableRecordBatchStream>,
+ schema: SchemaRef,
+ projection: Option<SchemaRef>,
+ current: usize,
+}
+
+impl Iterator for PartitionedDataFrameStreamReader {
+ type Item = Result<RecordBatch, ArrowError>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ while self.current < self.streams.len() {
+ let stream = &mut self.streams[self.current];
+ let fut = poll_next_batch(stream);
+ let result = Python::with_gil(|py| wait_for_future(py, fut));
+
+ match result {
+ Ok(Ok(Some(batch))) => {
+ let batch = if let Some(ref schema) = self.projection {
+ match record_batch_into_schema(batch, schema.as_ref()) {
+ Ok(b) => b,
+ Err(e) => return Some(Err(e)),
+ }
+ } else {
+ batch
+ };
+ return Some(Ok(batch));
+ }
+ Ok(Ok(None)) => {
+ self.current += 1;
+ continue;
+ }
+ Ok(Err(e)) => {
+ return Some(Err(ArrowError::ExternalError(Box::new(e))));
+ }
+ Err(e) => {
+ return Some(Err(ArrowError::ExternalError(Box::new(e))));
+ }
+ }
+ }
+
+ None
+ }
+}
+
+impl RecordBatchReader for PartitionedDataFrameStreamReader {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
#[pymethods]
impl PyDataFrame {
/// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
@@ -924,8 +985,11 @@ impl PyDataFrame {
py: Python<'py>,
requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
- let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
+ let df = self.df.as_ref().clone();
+ let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
+
let mut schema: Schema = self.df.schema().to_owned().into();
+ let mut projection: Option<SchemaRef> = None;
if let Some(schema_capsule) = requested_schema {
validate_pycapsule(&schema_capsule, "arrow_schema")?;
@@ -934,44 +998,38 @@ impl PyDataFrame {
let desired_schema = Schema::try_from(schema_ptr)?;
schema = project_schema(schema, desired_schema)?;
-
- batches = batches
- .into_iter()
- .map(|record_batch| record_batch_into_schema(record_batch, &schema))
- .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
+ projection = Some(Arc::new(schema.clone()));
}
- let batches_wrapped = batches.into_iter().map(Ok);
+ let schema_ref = Arc::new(schema.clone());
- let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
+ let reader = PartitionedDataFrameStreamReader {
+ streams,
+ schema: schema_ref,
+ projection,
+ current: 0,
+ };
let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
- let ffi_stream = FFI_ArrowArrayStream::new(reader);
- let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
- PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
+ // Create the Arrow stream and wrap it in a PyCapsule. The default
+ // destructor provided by PyO3 will drop the stream unless ownership is
+ // transferred to PyArrow during import.
+ let stream = FFI_ArrowArrayStream::new(reader);
+ let name = CString::new(ARROW_ARRAY_STREAM_NAME.to_bytes()).unwrap();
+ let capsule = PyCapsule::new(py, stream, Some(name))?;
+ Ok(capsule)
}
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
- // create a Tokio runtime to run the async code
- let rt = &get_tokio_runtime().0;
let df = self.df.as_ref().clone();
- let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
- rt.spawn(async move { df.execute_stream().await });
- let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+ let stream = spawn_future(py, async move { df.execute_stream().await })?;
Ok(PyRecordBatchStream::new(stream))
}
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
- // create a Tokio runtime to run the async code
- let rt = &get_tokio_runtime().0;
let df = self.df.as_ref().clone();
- let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
- rt.spawn(async move { df.execute_stream_partitioned().await });
- let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
- .map_err(py_datafusion_err)?
- .map_err(py_datafusion_err)?;
-
- Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
+ let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
+ Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
}
/// Convert to pandas dataframe with pyarrow
@@ -1132,7 +1190,11 @@ fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, Arro
merged_schema.project(&project_indices)
}
-
+// NOTE: `arrow::compute::cast` in combination with `RecordBatch::try_select` or
+// DataFusion's `schema::cast_record_batch` do not fully cover the required
+// transformations here. They will not create missing columns and may insert
+// nulls for non-nullable fields without erroring. To maintain current behavior
+// we perform the casting and null checks manually.
fn record_batch_into_schema(
record_batch: RecordBatch,
schema: &Schema,
diff --git a/src/record_batch.rs b/src/record_batch.rs
index c3658cf4..00d023b7 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -84,15 +84,21 @@ impl PyRecordBatchStream {
}
}
+/// Polls the next batch from a `SendableRecordBatchStream`, converting the `Option<Result<_>>` form.
+pub(crate) async fn poll_next_batch(
+ stream: &mut SendableRecordBatchStream,
+) -> datafusion::error::Result<Option<RecordBatch>> {
+ stream.next().await.transpose()
+}
+
async fn next_stream(
stream: Arc<Mutex<SendableRecordBatchStream>>,
sync: bool,
) -> PyResult<PyRecordBatch> {
let mut stream = stream.lock().await;
- match stream.next().await {
- Some(Ok(batch)) => Ok(batch.into()),
- Some(Err(e)) => Err(PyDataFusionError::from(e))?,
- None => {
+ match poll_next_batch(&mut stream).await {
+ Ok(Some(batch)) => Ok(batch.into()),
+ Ok(None) => {
// Depending on whether the iteration is sync or not, we raise either a
// StopIteration or a StopAsyncIteration
if sync {
@@ -101,5 +107,6 @@ async fn next_stream(
Err(PyStopAsyncIteration::new_err("stream exhausted"))
}
}
+ Err(e) => Err(PyDataFusionError::from(e))?,
}
}
diff --git a/src/utils.rs b/src/utils.rs
index 0fcfadce..9624f7d7 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -18,7 +18,7 @@
use crate::errors::py_datafusion_err;
use crate::{
common::data_type::PyScalarValue,
- errors::{PyDataFusionError, PyDataFusionResult},
+ errors::{to_datafusion_err, PyDataFusionError, PyDataFusionResult},
TokioRuntime,
};
use datafusion::{
@@ -33,7 +33,7 @@ use std::{
sync::{Arc, OnceLock},
time::Duration,
};
-use tokio::{runtime::Runtime, time::sleep};
+use tokio::{runtime::Runtime, task::JoinHandle, time::sleep};
/// Utility to get the Tokio Runtime from Python
#[inline]
@@ -92,6 +92,35 @@ where
})
}
+/// Spawn a [`Future`] on the Tokio runtime and wait for completion
+/// while respecting Python signal handling.
+pub(crate) fn spawn_future<F, T>(py: Python, fut: F) -> PyDataFusionResult<T>
+where
+ F: Future<Output = datafusion::common::Result<T>> + Send + 'static,
+ T: Send + 'static,
+{
+ let rt = &get_tokio_runtime().0;
+ let handle: JoinHandle<datafusion::common::Result<T>> = rt.spawn(fut);
+ // Wait for the join handle while respecting Python signal handling.
+ // We handle errors in two steps so `?` maps the error types correctly:
+ // 1) convert any Python-related error from `wait_for_future` into `PyDataFusionError`
+ // 2) convert any DataFusion error (inner result) into `PyDataFusionError`
+ let inner_result = wait_for_future(py, async {
+ // handle.await yields `Result<datafusion::common::Result<T>, JoinError>`
+ // map JoinError into a DataFusion error so the async block returns
+ // `datafusion::common::Result<T>` (i.e. Result<T, DataFusionError>)
+ match handle.await {
+ Ok(inner) => inner,
+ Err(join_err) => Err(to_datafusion_err(join_err)),
+ }
+ })?; // converts PyErr -> PyDataFusionError
+
+ // `inner_result` is `datafusion::common::Result<T>`; use `?` to convert
+ // the inner DataFusion error into `PyDataFusionError` via `From` and
+ // return the inner `T` on success.
+ Ok(inner_result?)
+}
+
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
Ok(match value {
"immutable" => Volatility::Immutable,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]