This is an automated email from the ASF dual-hosted git repository.
wjones127 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 50f161eafb fix ownership of c stream error (#4660)
50f161eafb is described below
commit 50f161eafb4062ffa13e3399c49d8f98e8dbfb6d
Author: Will Jones <[email protected]>
AuthorDate: Mon Aug 7 15:16:23 2023 -0700
fix ownership of c stream error (#4660)
* fix ownership of c stream error
* add pyarrow integration test
---
arrow-pyarrow-integration-testing/src/lib.rs | 13 ++++
.../tests/test_sql.py | 12 ++++
arrow/src/ffi_stream.rs | 75 ++++++++++++++++------
3 files changed, 81 insertions(+), 19 deletions(-)
diff --git a/arrow-pyarrow-integration-testing/src/lib.rs
b/arrow-pyarrow-integration-testing/src/lib.rs
index 89395bd2ed..adcec769f2 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -21,6 +21,7 @@
use std::sync::Arc;
use arrow::array::new_empty_array;
+use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
@@ -140,6 +141,17 @@ fn round_trip_record_batch_reader(
Ok(obj)
}
+#[pyfunction]
+fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
+ // This makes sure we can correctly consume a RBR and return the error,
+ // ensuring the error can live beyond the lifetime of the RBR.
+ let batches = obj.0.collect::<Result<Vec<RecordBatch>, ArrowError>>();
+ match batches {
+ Ok(_) => Ok(()),
+ Err(err) => Err(PyValueError::new_err(err.to_string())),
+ }
+}
+
#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
@@ -153,5 +165,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+ m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
Ok(())
}
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py
b/arrow-pyarrow-integration-testing/tests/test_sql.py
index a7c6b34a44..92782b9ed4 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -409,6 +409,18 @@ def test_record_batch_reader():
got_batches = list(b)
assert got_batches == batches
+def test_record_batch_reader_error():
+ schema = pa.schema([('ints', pa.list_(pa.int32()))])
+
+ def iter_batches():
+ yield pa.record_batch([[[1], [2, 42]]], schema)
+ raise ValueError("test error")
+
+ reader = pa.RecordBatchReader.from_batches(schema, iter_batches())
+
+ with pytest.raises(ValueError, match="test error"):
+ rust.reader_return_errors(reader)
+
def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index 7d6689a890..a9d2e8ab6b 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -54,6 +54,7 @@
//! }
//! ```
+use std::ffi::CStr;
use std::ptr::addr_of;
use std::{
convert::TryFrom,
@@ -120,7 +121,7 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
struct StreamPrivateData {
batch_reader: Box<dyn RecordBatchReader + Send>,
- last_error: String,
+ last_error: Option<CString>,
}
// The callback used to get array schema
@@ -142,8 +143,12 @@ unsafe extern "C" fn get_next(
// The callback used to get the error from last operation on the `FFI_ArrowArrayStream`
unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
let mut ffi_stream = ExportedArrayStream { stream };
- let last_error = ffi_stream.get_last_error();
- CString::new(last_error.as_str()).unwrap().into_raw()
+ // The consumer should not take ownership of this string, we should return
+ // a const pointer to it.
+ match ffi_stream.get_last_error() {
+ Some(err_string) => err_string.as_ptr(),
+ None => std::ptr::null(),
+ }
}
impl Drop for FFI_ArrowArrayStream {
@@ -160,7 +165,7 @@ impl FFI_ArrowArrayStream {
pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
let private_data = Box::new(StreamPrivateData {
batch_reader,
- last_error: String::new(),
+ last_error: None,
});
Self {
@@ -206,7 +211,10 @@ impl ExportedArrayStream {
0
}
Err(ref err) => {
- private_data.last_error = err.to_string();
+ private_data.last_error = Some(
+ CString::new(err.to_string())
+ .expect("Error string has a null byte in it."),
+ );
get_error_code(err)
}
}
@@ -231,15 +239,18 @@ impl ExportedArrayStream {
0
} else {
let err = &next_batch.unwrap_err();
- private_data.last_error = err.to_string();
+ private_data.last_error = Some(
+ CString::new(err.to_string())
+ .expect("Error string has a null byte in it."),
+ );
get_error_code(err)
}
}
}
}
- pub fn get_last_error(&mut self) -> &String {
- &self.get_private_data().last_error
+ pub fn get_last_error(&mut self) -> Option<&CString> {
+ self.get_private_data().last_error.as_ref()
}
}
@@ -312,19 +323,15 @@ impl ArrowArrayStreamReader {
/// Get the last error from `ArrowArrayStreamReader`
fn get_stream_last_error(&mut self) -> Option<String> {
- self.stream.get_last_error?;
-
- let error_str = unsafe {
- let c_str =
- self.stream.get_last_error.unwrap()(&mut self.stream) as *mut c_char;
- CString::from_raw(c_str).into_string()
- };
+ let get_last_error = self.stream.get_last_error?;
- if let Err(err) = error_str {
- Some(err.to_string())
- } else {
- Some(error_str.unwrap())
+ let error_str = unsafe { get_last_error(&mut self.stream) };
+ if error_str.is_null() {
+ return None;
}
+
+ let error_str = unsafe { CStr::from_ptr(error_str) };
+ Some(error_str.to_string_lossy().to_string())
}
}
@@ -381,6 +388,8 @@ pub unsafe fn export_reader_into_raw(
#[cfg(test)]
mod tests {
+ use arrow_schema::DataType;
+
use super::*;
use crate::array::Int32Array;
@@ -503,4 +512,32 @@ mod tests {
_test_round_trip_import(vec![array.clone(), array.clone(), array])
}
+
+ #[test]
+ fn test_error_import() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+
+ let iter =
+ Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());
+
+ let reader = TestRecordBatchReader::new(schema.clone(), iter);
+
+ // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader`
+ let stream = FFI_ArrowArrayStream::new(reader);
+ let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();
+
+ let imported_schema = stream_reader.schema();
+ assert_eq!(imported_schema, schema);
+
+ let mut produced_batches = vec![];
+ for batch in stream_reader {
+ produced_batches.push(batch);
+ }
+
+ // The results should outlive the lifetime of the stream itself.
+ assert_eq!(produced_batches.len(), 1);
+ assert!(produced_batches[0].is_err());
+
+ Ok(())
+ }
}