This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b262be  Support async iteration of RecordBatchStream (#975)
4b262be is described below

commit 4b262be15202f5efb9a963faf66452f7fb0bbad3
Author: Kyle Barron <[email protected]>
AuthorDate: Thu Jan 9 03:54:38 2025 -0800

    Support async iteration of RecordBatchStream (#975)
    
    * Support async iteration of RecordBatchStream
    
    * use __anext__
    
    * use await
    
    * fix failing test
    
    * Since we are raising an error instead of returning a None, we can update the type hint.
    
    ---------
    
    Co-authored-by: Tim Saucer <[email protected]>
---
 Cargo.lock                        | 14 +++++++++++
 Cargo.toml                        |  3 ++-
 python/datafusion/record_batch.py | 16 +++++++-----
 python/tests/test_dataframe.py    |  4 +--
 src/record_batch.rs               | 51 +++++++++++++++++++++++++++++++--------
 5 files changed, 69 insertions(+), 19 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index d1f291b..352771c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1303,6 +1303,7 @@ dependencies = [
  "prost",
  "prost-types",
  "pyo3",
+ "pyo3-async-runtimes",
  "pyo3-build-config",
  "tokio",
  "url",
@@ -2672,6 +2673,19 @@ dependencies = [
  "unindent",
 ]
 
+[[package]]
+name = "pyo3-async-runtimes"
+version = "0.22.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2529f0be73ffd2be0cc43c013a640796558aa12d7ca0aab5cc14f375b4733031"
+dependencies = [
+ "futures",
+ "once_cell",
+ "pin-project-lite",
+ "pyo3",
+ "tokio",
+]
+
 [[package]]
 name = "pyo3-build-config"
 version = "0.22.6"
diff --git a/Cargo.toml b/Cargo.toml
index 703fc5a..d288446 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"]
 [dependencies]
 tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
+pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
 arrow = { version = "53", features = ["pyarrow"] }
 datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
 datafusion-substrait = { version = "43.0.0", optional = true }
@@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"]
 
 [profile.release]
 lto = true
-codegen-units = 1
\ No newline at end of file
+codegen-units = 1
diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py
index 44936f7..75e5899 100644
--- a/python/datafusion/record_batch.py
+++ b/python/datafusion/record_batch.py
@@ -57,20 +57,24 @@ class RecordBatchStream:
         """This constructor is typically not called by the end user."""
         self.rbs = record_batch_stream
 
-    def next(self) -> RecordBatch | None:
+    def next(self) -> RecordBatch:
         """See :py:func:`__next__` for the iterator function."""
-        try:
-            next_batch = next(self)
-        except StopIteration:
-            return None
+        return next(self)
 
-        return next_batch
+    async def __anext__(self) -> RecordBatch:
+        """Async iterator function."""
+        next_batch = await self.rbs.__anext__()
+        return RecordBatch(next_batch)
 
     def __next__(self) -> RecordBatch:
         """Iterator function."""
         next_batch = next(self.rbs)
         return RecordBatch(next_batch)
 
+    def __aiter__(self) -> typing_extensions.Self:
+        """Async iterator function."""
+        return self
+
     def __iter__(self) -> typing_extensions.Self:
         """Iterator function."""
         return self
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index b82f95e..e3bd1b2 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df):
     batch = stream.next()
     assert batch is not None
     # there should be no more batches
-    batch = stream.next()
-    assert batch is None
+    with pytest.raises(StopIteration):
+        stream.next()
 
 
 def test_repartition(df):
diff --git a/src/record_batch.rs b/src/record_batch.rs
index 427807f..eacdb58 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -15,13 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::sync::Arc;
+
 use crate::utils::wait_for_future;
 use datafusion::arrow::pyarrow::ToPyArrow;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::physical_plan::SendableRecordBatchStream;
 use futures::StreamExt;
+use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
 use pyo3::prelude::*;
 use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
+use tokio::sync::Mutex;
 
 #[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
 pub struct PyRecordBatch {
@@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {
 
 #[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
 pub struct PyRecordBatchStream {
-    stream: SendableRecordBatchStream,
+    stream: Arc<Mutex<SendableRecordBatchStream>>,
 }
 
 impl PyRecordBatchStream {
     pub fn new(stream: SendableRecordBatchStream) -> Self {
-        Self { stream }
+        Self {
+            stream: Arc::new(Mutex::new(stream)),
+        }
     }
 }
 
 #[pymethods]
 impl PyRecordBatchStream {
-    fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
-        let result = self.stream.next();
-        match wait_for_future(py, result) {
-            None => Ok(None),
-            Some(Ok(b)) => Ok(Some(b.into())),
-            Some(Err(e)) => Err(e.into()),
-        }
+    fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
+        let stream = self.stream.clone();
+        wait_for_future(py, next_stream(stream, true))
     }
 
-    fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
+    fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
         self.next(py)
     }
 
+    fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
+        let stream = self.stream.clone();
+        pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
+    }
+
     fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
         slf
     }
+
+    fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+}
+
+async fn next_stream(
+    stream: Arc<Mutex<SendableRecordBatchStream>>,
+    sync: bool,
+) -> PyResult<PyRecordBatch> {
+    let mut stream = stream.lock().await;
+    match stream.next().await {
+        Some(Ok(batch)) => Ok(batch.into()),
+        Some(Err(e)) => Err(e.into()),
+        None => {
+            // Depending on whether the iteration is sync or not, we raise either a
+            // StopIteration or a StopAsyncIteration
+            if sync {
+                Err(PyStopIteration::new_err("stream exhausted"))
+            } else {
+                Err(PyStopAsyncIteration::new_err("stream exhausted"))
+            }
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to