This is an automated email from the ASF dual-hosted git repository.

wjones127 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 50f161eafb fix ownership of c stream error (#4660)
50f161eafb is described below

commit 50f161eafb4062ffa13e3399c49d8f98e8dbfb6d
Author: Will Jones <[email protected]>
AuthorDate: Mon Aug 7 15:16:23 2023 -0700

    fix ownership of c stream error (#4660)
    
    * fix ownership of c stream error
    
    * add pyarrow integration test
---
 arrow-pyarrow-integration-testing/src/lib.rs       | 13 ++++
 .../tests/test_sql.py                              | 12 ++++
 arrow/src/ffi_stream.rs                            | 75 ++++++++++++++++------
 3 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 89395bd2ed..adcec769f2 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -21,6 +21,7 @@
 use std::sync::Arc;
 
 use arrow::array::new_empty_array;
+use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
 
@@ -140,6 +141,17 @@ fn round_trip_record_batch_reader(
     Ok(obj)
 }
 
+#[pyfunction]
+fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
+    // This makes sure we can correctly consume a RBR and return the error,
+    // ensuring the error can live beyond the lifetime of the RBR.
+    let batches = obj.0.collect::<Result<Vec<RecordBatch>, ArrowError>>();
+    match batches {
+        Ok(_) => Ok(()),
+        Err(err) => Err(PyValueError::new_err(err.to_string())),
+    }
+}
+
 #[pymodule]
 fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(double))?;
@@ -153,5 +165,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+    m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
     Ok(())
 }
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index a7c6b34a44..92782b9ed4 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -409,6 +409,18 @@ def test_record_batch_reader():
     got_batches = list(b)
     assert got_batches == batches
 
+def test_record_batch_reader_error():
+    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+
+    def iter_batches():
+        yield pa.record_batch([[[1], [2, 42]]], schema)
+        raise ValueError("test error")
+
+    reader = pa.RecordBatchReader.from_batches(schema, iter_batches())
+
+    with pytest.raises(ValueError, match="test error"):
+        rust.reader_return_errors(reader)
+
 def test_reject_other_classes():
     # Arbitrary type that is not a PyArrow type
     not_pyarrow = ["hello"]
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index 7d6689a890..a9d2e8ab6b 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -54,6 +54,7 @@
 //! }
 //! ```
 
+use std::ffi::CStr;
 use std::ptr::addr_of;
 use std::{
     convert::TryFrom,
@@ -120,7 +121,7 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
 
 struct StreamPrivateData {
     batch_reader: Box<dyn RecordBatchReader + Send>,
-    last_error: String,
+    last_error: Option<CString>,
 }
 
 // The callback used to get array schema
@@ -142,8 +143,12 @@ unsafe extern "C" fn get_next(
 // The callback used to get the error from last operation on the `FFI_ArrowArrayStream`
 unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
     let mut ffi_stream = ExportedArrayStream { stream };
-    let last_error = ffi_stream.get_last_error();
-    CString::new(last_error.as_str()).unwrap().into_raw()
+    // The consumer should not take ownership of this string, we should return
+    // a const pointer to it.
+    match ffi_stream.get_last_error() {
+        Some(err_string) => err_string.as_ptr(),
+        None => std::ptr::null(),
+    }
 }
 
 impl Drop for FFI_ArrowArrayStream {
@@ -160,7 +165,7 @@ impl FFI_ArrowArrayStream {
     pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
         let private_data = Box::new(StreamPrivateData {
             batch_reader,
-            last_error: String::new(),
+            last_error: None,
         });
 
         Self {
@@ -206,7 +211,10 @@ impl ExportedArrayStream {
                 0
             }
             Err(ref err) => {
-                private_data.last_error = err.to_string();
+                private_data.last_error = Some(
+                    CString::new(err.to_string())
+                        .expect("Error string has a null byte in it."),
+                );
                 get_error_code(err)
             }
         }
@@ -231,15 +239,18 @@ impl ExportedArrayStream {
                     0
                 } else {
                     let err = &next_batch.unwrap_err();
-                    private_data.last_error = err.to_string();
+                    private_data.last_error = Some(
+                        CString::new(err.to_string())
+                            .expect("Error string has a null byte in it."),
+                    );
                     get_error_code(err)
                 }
             }
         }
     }
 
-    pub fn get_last_error(&mut self) -> &String {
-        &self.get_private_data().last_error
+    pub fn get_last_error(&mut self) -> Option<&CString> {
+        self.get_private_data().last_error.as_ref()
     }
 }
 
@@ -312,19 +323,15 @@ impl ArrowArrayStreamReader {
 
     /// Get the last error from `ArrowArrayStreamReader`
     fn get_stream_last_error(&mut self) -> Option<String> {
-        self.stream.get_last_error?;
-
-        let error_str = unsafe {
-            let c_str =
-                self.stream.get_last_error.unwrap()(&mut self.stream) as *mut c_char;
-            CString::from_raw(c_str).into_string()
-        };
+        let get_last_error = self.stream.get_last_error?;
 
-        if let Err(err) = error_str {
-            Some(err.to_string())
-        } else {
-            Some(error_str.unwrap())
+        let error_str = unsafe { get_last_error(&mut self.stream) };
+        if error_str.is_null() {
+            return None;
         }
+
+        let error_str = unsafe { CStr::from_ptr(error_str) };
+        Some(error_str.to_string_lossy().to_string())
     }
 }
 
@@ -381,6 +388,8 @@ pub unsafe fn export_reader_into_raw(
 
 #[cfg(test)]
 mod tests {
+    use arrow_schema::DataType;
+
     use super::*;
 
     use crate::array::Int32Array;
@@ -503,4 +512,32 @@ mod tests {
 
         _test_round_trip_import(vec![array.clone(), array.clone(), array])
     }
+
+    #[test]
+    fn test_error_import() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+
+        let iter =
+            Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());
+
+        let reader = TestRecordBatchReader::new(schema.clone(), iter);
+
+        // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader`
+        let stream = FFI_ArrowArrayStream::new(reader);
+        let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();
+
+        let imported_schema = stream_reader.schema();
+        assert_eq!(imported_schema, schema);
+
+        let mut produced_batches = vec![];
+        for batch in stream_reader {
+            produced_batches.push(batch);
+        }
+
+        // The results should outlive the lifetime of the stream itself.
+        assert_eq!(produced_batches.len(), 1);
+        assert!(produced_batches[0].is_err());
+
+        Ok(())
+    }
 }

Reply via email to