This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new cedaf8a6a Add PyArrow integration test for C Stream Interface (#1848)
cedaf8a6a is described below
commit cedaf8a6ab55826c34f3b1bc9a21dbaf3e0328bc
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jun 13 15:29:46 2022 -0700
Add PyArrow integration test for C Stream Interface (#1848)
* Add PyArrow integration test for ArrowArrayStream
* Trigger Build
---
arrow-pyarrow-integration-testing/src/lib.rs | 9 +++++
.../tests/test_sql.py | 16 +++++++++
arrow/src/ffi_stream.rs | 24 +++++--------
arrow/src/pyarrow.rs | 42 +++++++++++++++++++++-
4 files changed, 74 insertions(+), 17 deletions(-)
diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 26c09d64d..086b21834 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -27,6 +27,7 @@ use arrow::array::{ArrayData, ArrayRef, Int64Array};
use arrow::compute::kernels;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
+use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::PyArrowConvert;
use arrow::record_batch::RecordBatch;
@@ -111,6 +112,13 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
Ok(obj)
}
+#[pyfunction]
+fn round_trip_record_batch_reader(
+ obj: ArrowArrayStreamReader,
+) -> PyResult<ArrowArrayStreamReader> {
+ Ok(obj)
+}
+
#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
@@ -122,5 +130,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
+ m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
Ok(())
}
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 324956c9c..a17ba6d06 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -303,3 +303,19 @@ def test_dictionary_python():
assert a == b
del a
del b
+
+def test_record_batch_reader():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ a = pa.RecordBatchReader.from_batches(schema, batches)
+ b = rust.round_trip_record_batch_reader(a)
+
+ assert b.schema == schema
+ got_batches = list(b)
+ assert got_batches == batches
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index ab4caea36..bfc62b888 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -198,13 +198,6 @@ impl ExportedArrayStream {
}
pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
- unsafe {
- match (*out).release {
- None => (),
- Some(release) => release(out),
- };
- };
-
let mut private_data = self.get_private_data();
let reader = &private_data.batch_reader;
@@ -224,18 +217,17 @@ impl ExportedArrayStream {
}
pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
- unsafe {
- match (*out).release {
- None => (),
- Some(release) => release(out),
- };
- };
-
let mut private_data = self.get_private_data();
let reader = &mut private_data.batch_reader;
let ret_code = match reader.next() {
- None => 0,
+ None => {
+ // Marks ArrowArray released to indicate reaching the end of stream.
+ unsafe {
+ (*out).release = None;
+ }
+ 0
+ }
Some(next_batch) => {
if let Ok(batch) = next_batch {
let struct_array = StructArray::from(batch);
@@ -275,7 +267,7 @@ fn get_error_code(err: &ArrowError) -> i32 {
/// Struct used to fetch `RecordBatch` from the C Stream Interface.
/// Its main responsibility is to expose `RecordBatchReader` functionality
/// that requires [FFI_ArrowArrayStream].
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct ArrowArrayStreamReader {
stream: Arc<FFI_ArrowArrayStream>,
schema: SchemaRef,
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 62e6316b6..3ae5b3b99 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -24,13 +24,16 @@ use std::sync::Arc;
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
-use pyo3::types::PyList;
+use pyo3::types::{PyList, PyTuple};
use crate::array::{Array, ArrayData, ArrayRef};
use crate::datatypes::{DataType, Field, Schema};
use crate::error::ArrowError;
use crate::ffi;
use crate::ffi::FFI_ArrowSchema;
+use crate::ffi_stream::{
+ export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream,
+};
use crate::record_batch::RecordBatch;
import_exception!(pyarrow, ArrowException);
@@ -198,6 +201,42 @@ impl PyArrowConvert for RecordBatch {
}
}
+impl PyArrowConvert for ArrowArrayStreamReader {
+ fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // prepare a pointer to receive the stream struct
+ let stream = Box::new(FFI_ArrowArrayStream::empty());
+ let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;
+
+ // make the conversion through PyArrow's private API
+ // this changes the pointer's memory and is thus unsafe.
+ // In particular, `_export_to_c` can go out of bounds
+ let args = PyTuple::new(value.py(), &[stream_ptr as Py_uintptr_t]);
+ value.call_method1("_export_to_c", args)?;
+
+ let stream_reader =
+ unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() };
+
+ unsafe {
+ Box::from_raw(stream_ptr);
+ }
+
+ Ok(stream_reader)
+ }
+
+ fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+ let stream = Box::new(FFI_ArrowArrayStream::empty());
+ let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;
+
+ unsafe { export_reader_into_raw(Box::new(self.clone()), stream_ptr) };
+
+ let module = py.import("pyarrow")?;
+ let class = module.getattr("RecordBatchReader")?;
+ let args = PyTuple::new(py, &[stream_ptr as Py_uintptr_t]);
+ let reader = class.call_method1("_import_from_c", args)?;
+ Ok(PyObject::from(reader))
+ }
+}
+
macro_rules! add_conversion {
($typ:ty) => {
impl<'source> FromPyObject<'source> for $typ {
@@ -219,3 +258,4 @@ add_conversion!(Field);
add_conversion!(Schema);
add_conversion!(ArrayData);
add_conversion!(RecordBatch);
+add_conversion!(ArrowArrayStreamReader);