This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new cedaf8a6a Add PyArrow integration test for C Stream Interface (#1848)
cedaf8a6a is described below

commit cedaf8a6ab55826c34f3b1bc9a21dbaf3e0328bc
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jun 13 15:29:46 2022 -0700

    Add PyArrow integration test for C Stream Interface (#1848)
    
    * Add PyArrow integration test for ArrowArrayStream
    
    * Trigger Build
---
 arrow-pyarrow-integration-testing/src/lib.rs       |  9 +++++
 .../tests/test_sql.py                              | 16 +++++++++
 arrow/src/ffi_stream.rs                            | 24 +++++--------
 arrow/src/pyarrow.rs                               | 42 +++++++++++++++++++++-
 4 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 26c09d64d..086b21834 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -27,6 +27,7 @@ use arrow::array::{ArrayData, ArrayRef, Int64Array};
 use arrow::compute::kernels;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::error::ArrowError;
+use arrow::ffi_stream::ArrowArrayStreamReader;
 use arrow::pyarrow::PyArrowConvert;
 use arrow::record_batch::RecordBatch;
 
@@ -111,6 +112,13 @@ fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
     Ok(obj)
 }
 
+#[pyfunction]
+fn round_trip_record_batch_reader(
+    obj: ArrowArrayStreamReader,
+) -> PyResult<ArrowArrayStreamReader> {
+    Ok(obj)
+}
+
 #[pymodule]
 fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(double))?;
@@ -122,5 +130,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
     m.add_wrapped(wrap_pyfunction!(round_trip_schema))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
+    m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
     Ok(())
 }
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 324956c9c..a17ba6d06 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -303,3 +303,19 @@ def test_dictionary_python():
     assert a == b
     del a
     del b
+
+def test_record_batch_reader():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    a = pa.RecordBatchReader.from_batches(schema, batches)
+    b = rust.round_trip_record_batch_reader(a)
+
+    assert b.schema == schema
+    got_batches = list(b)
+    assert got_batches == batches
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index ab4caea36..bfc62b888 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -198,13 +198,6 @@ impl ExportedArrayStream {
     }
 
     pub fn get_schema(&mut self, out: *mut FFI_ArrowSchema) -> i32 {
-        unsafe {
-            match (*out).release {
-                None => (),
-                Some(release) => release(out),
-            };
-        };
-
         let mut private_data = self.get_private_data();
         let reader = &private_data.batch_reader;
 
@@ -224,18 +217,17 @@ impl ExportedArrayStream {
     }
 
     pub fn get_next(&mut self, out: *mut FFI_ArrowArray) -> i32 {
-        unsafe {
-            match (*out).release {
-                None => (),
-                Some(release) => release(out),
-            };
-        };
-
         let mut private_data = self.get_private_data();
         let reader = &mut private_data.batch_reader;
 
         let ret_code = match reader.next() {
-            None => 0,
+            None => {
+                // Marks ArrowArray released to indicate reaching the end of stream.
+                unsafe {
+                    (*out).release = None;
+                }
+                0
+            }
             Some(next_batch) => {
                 if let Ok(batch) = next_batch {
                     let struct_array = StructArray::from(batch);
@@ -275,7 +267,7 @@ fn get_error_code(err: &ArrowError) -> i32 {
 /// Struct used to fetch `RecordBatch` from the C Stream Interface.
 /// Its main responsibility is to expose `RecordBatchReader` functionality
 /// that requires [FFI_ArrowArrayStream].
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ArrowArrayStreamReader {
     stream: Arc<FFI_ArrowArrayStream>,
     schema: SchemaRef,
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 62e6316b6..3ae5b3b99 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -24,13 +24,16 @@ use std::sync::Arc;
 use pyo3::ffi::Py_uintptr_t;
 use pyo3::import_exception;
 use pyo3::prelude::*;
-use pyo3::types::PyList;
+use pyo3::types::{PyList, PyTuple};
 
 use crate::array::{Array, ArrayData, ArrayRef};
 use crate::datatypes::{DataType, Field, Schema};
 use crate::error::ArrowError;
 use crate::ffi;
 use crate::ffi::FFI_ArrowSchema;
+use crate::ffi_stream::{
+    export_reader_into_raw, ArrowArrayStreamReader, FFI_ArrowArrayStream,
+};
 use crate::record_batch::RecordBatch;
 
 import_exception!(pyarrow, ArrowException);
@@ -198,6 +201,42 @@ impl PyArrowConvert for RecordBatch {
     }
 }
 
+impl PyArrowConvert for ArrowArrayStreamReader {
+    fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // prepare a pointer to receive the stream struct
+        let stream = Box::new(FFI_ArrowArrayStream::empty());
+        let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;
+
+        // make the conversion through PyArrow's private API
+        // this changes the pointer's memory and is thus unsafe.
+        // In particular, `_export_to_c` can go out of bounds
+        let args = PyTuple::new(value.py(), &[stream_ptr as Py_uintptr_t]);
+        value.call_method1("_export_to_c", args)?;
+
+        let stream_reader =
+            unsafe { ArrowArrayStreamReader::from_raw(stream_ptr).unwrap() };
+
+        unsafe {
+            Box::from_raw(stream_ptr);
+        }
+
+        Ok(stream_reader)
+    }
+
+    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+        let stream = Box::new(FFI_ArrowArrayStream::empty());
+        let stream_ptr = Box::into_raw(stream) as *mut FFI_ArrowArrayStream;
+
+        unsafe { export_reader_into_raw(Box::new(self.clone()), stream_ptr) };
+
+        let module = py.import("pyarrow")?;
+        let class = module.getattr("RecordBatchReader")?;
+        let args = PyTuple::new(py, &[stream_ptr as Py_uintptr_t]);
+        let reader = class.call_method1("_import_from_c", args)?;
+        Ok(PyObject::from(reader))
+    }
+}
+
 macro_rules! add_conversion {
     ($typ:ty) => {
         impl<'source> FromPyObject<'source> for $typ {
@@ -219,3 +258,4 @@ add_conversion!(Field);
 add_conversion!(Schema);
 add_conversion!(ArrayData);
 add_conversion!(RecordBatch);
+add_conversion!(ArrowArrayStreamReader);

Reply via email to