This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 4b262be Support async iteration of RecordBatchStream (#975)
4b262be is described below
commit 4b262be15202f5efb9a963faf66452f7fb0bbad3
Author: Kyle Barron <[email protected]>
AuthorDate: Thu Jan 9 03:54:38 2025 -0800
Support async iteration of RecordBatchStream (#975)
* Support async iteration of RecordBatchStream
* use __anext__
* use await
* fix failing test
* Since we are raising an error instead of returning a None, we can update
the type hint.
---------
Co-authored-by: Tim Saucer <[email protected]>
---
Cargo.lock | 14 +++++++++++
Cargo.toml | 3 ++-
python/datafusion/record_batch.py | 16 +++++++-----
python/tests/test_dataframe.py | 4 +--
src/record_batch.rs | 51 +++++++++++++++++++++++++++++++--------
5 files changed, 69 insertions(+), 19 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index d1f291b..352771c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1303,6 +1303,7 @@ dependencies = [
"prost",
"prost-types",
"pyo3",
+ "pyo3-async-runtimes",
"pyo3-build-config",
"tokio",
"url",
@@ -2672,6 +2673,19 @@ dependencies = [
"unindent",
]
+[[package]]
+name = "pyo3-async-runtimes"
+version = "0.22.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2529f0be73ffd2be0cc43c013a640796558aa12d7ca0aab5cc14f375b4733031"
+dependencies = [
+ "futures",
+ "once_cell",
+ "pin-project-lite",
+ "pyo3",
+ "tokio",
+]
+
[[package]]
name = "pyo3-build-config"
version = "0.22.6"
diff --git a/Cargo.toml b/Cargo.toml
index 703fc5a..d288446 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
+pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
arrow = { version = "53", features = ["pyarrow"] }
datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "43.0.0", optional = true }
@@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"]
[profile.release]
lto = true
-codegen-units = 1
\ No newline at end of file
+codegen-units = 1
diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py
index 44936f7..75e5899 100644
--- a/python/datafusion/record_batch.py
+++ b/python/datafusion/record_batch.py
@@ -57,20 +57,24 @@ class RecordBatchStream:
"""This constructor is typically not called by the end user."""
self.rbs = record_batch_stream
- def next(self) -> RecordBatch | None:
+ def next(self) -> RecordBatch:
"""See :py:func:`__next__` for the iterator function."""
- try:
- next_batch = next(self)
- except StopIteration:
- return None
+ return next(self)
- return next_batch
+ async def __anext__(self) -> RecordBatch:
+ """Async iterator function."""
+ next_batch = await self.rbs.__anext__()
+ return RecordBatch(next_batch)
def __next__(self) -> RecordBatch:
"""Iterator function."""
next_batch = next(self.rbs)
return RecordBatch(next_batch)
+ def __aiter__(self) -> typing_extensions.Self:
+ """Async iterator function."""
+ return self
+
def __iter__(self) -> typing_extensions.Self:
"""Iterator function."""
return self
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index b82f95e..e3bd1b2 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df):
batch = stream.next()
assert batch is not None
# there should be no more batches
- batch = stream.next()
- assert batch is None
+ with pytest.raises(StopIteration):
+ stream.next()
def test_repartition(df):
diff --git a/src/record_batch.rs b/src/record_batch.rs
index 427807f..eacdb58 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -15,13 +15,17 @@
// specific language governing permissions and limitations
// under the License.
+use std::sync::Arc;
+
use crate::utils::wait_for_future;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
+use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
+use tokio::sync::Mutex;
#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
pub struct PyRecordBatch {
@@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {
#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
pub struct PyRecordBatchStream {
- stream: SendableRecordBatchStream,
+ stream: Arc<Mutex<SendableRecordBatchStream>>,
}
impl PyRecordBatchStream {
pub fn new(stream: SendableRecordBatchStream) -> Self {
- Self { stream }
+ Self {
+ stream: Arc::new(Mutex::new(stream)),
+ }
}
}
#[pymethods]
impl PyRecordBatchStream {
- fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
- let result = self.stream.next();
- match wait_for_future(py, result) {
- None => Ok(None),
- Some(Ok(b)) => Ok(Some(b.into())),
- Some(Err(e)) => Err(e.into()),
- }
+ fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
+ let stream = self.stream.clone();
+ wait_for_future(py, next_stream(stream, true))
}
- fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
+ fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
self.next(py)
}
+ fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+ let stream = self.stream.clone();
+ pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
+ }
+
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
+
+ fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+ slf
+ }
+}
+
+async fn next_stream(
+ stream: Arc<Mutex<SendableRecordBatchStream>>,
+ sync: bool,
+) -> PyResult<PyRecordBatch> {
+ let mut stream = stream.lock().await;
+ match stream.next().await {
+ Some(Ok(batch)) => Ok(batch.into()),
+ Some(Err(e)) => Err(e.into()),
+ None => {
+ // Depending on whether the iteration is sync or not, we raise either a
+ // StopIteration or a StopAsyncIteration
+ if sync {
+ Err(PyStopIteration::new_err("stream exhausted"))
+ } else {
+ Err(PyStopAsyncIteration::new_err("stream exhausted"))
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]