This is an automated email from the ASF dual-hosted git repository.
wjones127 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new aff86e704d Implement Arrow PyCapsule Interface (#5070)
aff86e704d is described below
commit aff86e704dabecbf99edd1e0ad62c216819dbc15
Author: Kyle Barron <[email protected]>
AuthorDate: Wed Nov 15 13:18:45 2023 -0500
Implement Arrow PyCapsule Interface (#5070)
* arrow ffi array copy
* remove copy_ffi_array
* docstring
* wip: pycapsule support
* return
* Update arrow/src/pyarrow.rs
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* remove sync impl
* Update arrow/src/pyarrow.rs
Co-authored-by: Will Jones <[email protected]>
* Remove copy()
* Need &mut FFI_ArrowArray for std::mem::replace
* Use std::ptr::replace
* update comments
* Minimize unsafe block
* revert pub release functions
* Add RecordBatch and Stream conversion
* fix returns
* Fix return type
* Fix name
* fix ci
* Add tests
* Add table test
* skip if pre pyarrow 14
* bump python version in CI to use pyarrow 14
* Add record batch test
* Update arrow/src/pyarrow.rs
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
* run on pyarrow 13 and 14
* Update .github/workflows/integration.yml
Co-authored-by: Will Jones <[email protected]>
---------
Co-authored-by: Raphael Taylor-Davies
<[email protected]>
Co-authored-by: Will Jones <[email protected]>
---
.github/workflows/integration.yml | 6 +-
arrow-pyarrow-integration-testing/README.md | 2 +
.../tests/test_sql.py | 138 ++++++++++++++++++++-
arrow-schema/src/ffi.rs | 2 +
arrow/src/pyarrow.rs | 134 +++++++++++++++++++-
5 files changed, 274 insertions(+), 8 deletions(-)
diff --git a/.github/workflows/integration.yml
b/.github/workflows/integration.yml
index 6e2b442040..f939a6a13b 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -106,6 +106,8 @@ jobs:
strategy:
matrix:
rust: [ stable ]
+ # PyArrow 13 was the last version prior to introduction to Arrow PyCapsules
+ pyarrow: [ "13", "14" ]
steps:
- uses: actions/checkout@v4
with:
@@ -128,14 +130,14 @@ jobs:
key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{
matrix.rust }}-
- uses: actions/setup-python@v4
with:
- python-version: '3.7'
+ python-version: '3.8'
- name: Upgrade pip and setuptools
run: pip install --upgrade pip setuptools wheel virtualenv
- name: Create virtualenv and install dependencies
run: |
virtualenv venv
source venv/bin/activate
- pip install maturin toml pytest pytz pyarrow>=5.0
+ pip install maturin toml pytest pytz pyarrow==${{ matrix.pyarrow }}
- name: Run Rust tests
run: |
source venv/bin/activate
diff --git a/arrow-pyarrow-integration-testing/README.md
b/arrow-pyarrow-integration-testing/README.md
index e63953ad79..5ca2ea76b8 100644
--- a/arrow-pyarrow-integration-testing/README.md
+++ b/arrow-pyarrow-integration-testing/README.md
@@ -25,6 +25,7 @@ Note that this crate uses two languages and an external ABI:
* `Rust`
* `Python`
* C ABI privately exposed by `Pyarrow`.
+* PyCapsule ABI publicly exposed by `pyarrow`
## Basic idea
@@ -36,6 +37,7 @@ we can use pyarrow's interface to move pointers from and to
Rust.
## Relevant literature
* [Arrow's
CDataInterface](https://arrow.apache.org/docs/format/CDataInterface.html)
+* [Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html)
* [Rust's FFI](https://doc.rust-lang.org/nomicon/ffi.html)
* [Pyarrow private
binds](https://github.com/apache/arrow/blob/ae1d24efcc3f1ac2a876d8d9f544a34eb04ae874/python/pyarrow/array.pxi#L1226)
* [PyO3](https://docs.rs/pyo3/0.12.1/pyo3/index.html)
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py
b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 1748fd3ffb..16d4e0f12f 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -27,6 +27,8 @@ import pytz
import arrow_pyarrow_integration_testing as rust
+PYARROW_PRE_14 = int(pa.__version__.split('.')[0]) < 14
+
@contextlib.contextmanager
def no_pyarrow_leak():
@@ -113,6 +115,34 @@ _supported_pyarrow_types = [
_unsupported_pyarrow_types = [
]
+# As of pyarrow 14, pyarrow implements the Arrow PyCapsule interface
+# (https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
+# This defines that Arrow consumers should allow any object that has specific "dunder"
+# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
+# _any_ class, without pyarrow-specific handling.
+class SchemaWrapper:
+ def __init__(self, schema):
+ self.schema = schema
+
+ def __arrow_c_schema__(self):
+ return self.schema.__arrow_c_schema__()
+
+
+class ArrayWrapper:
+ def __init__(self, array):
+ self.array = array
+
+ def __arrow_c_array__(self):
+ return self.array.__arrow_c_array__()
+
+
+class StreamWrapper:
+ def __init__(self, stream):
+ self.stream = stream
+
+ def __arrow_c_stream__(self):
+ return self.stream.__arrow_c_stream__()
+
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip(pyarrow_type):
@@ -120,6 +150,14 @@ def test_type_roundtrip(pyarrow_type):
assert restored == pyarrow_type
assert restored is not pyarrow_type
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
+def test_type_roundtrip_pycapsule(pyarrow_type):
+ wrapped = SchemaWrapper(pyarrow_type)
+ restored = rust.round_trip_type(wrapped)
+ assert restored == pyarrow_type
+ assert restored is not pyarrow_type
+
@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str)
def test_type_roundtrip_raises(pyarrow_type):
@@ -138,6 +176,20 @@ def test_field_roundtrip(pyarrow_type):
field = rust.round_trip_field(pyarrow_field)
assert field == pyarrow_field
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
+def test_field_roundtrip_pycapsule(pyarrow_type):
+ pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
+ wrapped = SchemaWrapper(pyarrow_field)
+ field = rust.round_trip_field(wrapped)
+ assert field == wrapped.schema
+
+ if pyarrow_type != pa.null():
+ # A null type field may not be non-nullable
+ pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
+ field = rust.round_trip_field(wrapped)
+ assert field == wrapped.schema
+
def test_field_metadata_roundtrip():
metadata = {"hello": "World! 😊", "x": "2"}
pyarrow_field = pa.field("test", pa.int32(), metadata=metadata)
@@ -163,6 +215,17 @@ def test_primitive_python():
del b
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_primitive_python_pycapsule():
+ """
+ Python -> Rust -> Python
+ """
+ a = pa.array([1, 2, 3])
+ wrapped = ArrayWrapper(a)
+ b = rust.double(wrapped)
+ assert b == pa.array([2, 4, 6])
+
+
def test_primitive_rust():
"""
Rust -> Python -> Rust
@@ -433,6 +496,33 @@ def test_record_batch_reader():
got_batches = list(b)
assert got_batches == batches
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_record_batch_reader_pycapsule():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1':
b'value1'})
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ a = pa.RecordBatchReader.from_batches(schema, batches)
+ wrapped = StreamWrapper(a)
+ b = rust.round_trip_record_batch_reader(wrapped)
+
+ assert b.schema == schema
+ got_batches = list(b)
+ assert got_batches == batches
+
+ # Also try the boxed reader variant
+ a = pa.RecordBatchReader.from_batches(schema, batches)
+ wrapped = StreamWrapper(a)
+ b = rust.boxed_reader_roundtrip(wrapped)
+ assert b.schema == schema
+ got_batches = list(b)
+ assert got_batches == batches
+
+
def test_record_batch_reader_error():
schema = pa.schema([('ints', pa.list_(pa.int32()))])
@@ -453,24 +543,64 @@ def test_record_batch_reader_error():
with pytest.raises(ValueError, match="invalid utf-8"):
rust.round_trip_record_batch_reader(reader)
+
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_record_batch_pycapsule():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1':
b'value1'})
+ batch = pa.record_batch([[[1], [2, 42]]], schema)
+ wrapped = StreamWrapper(batch)
+ b = rust.round_trip_record_batch_reader(wrapped)
+ new_table = b.read_all()
+ new_batches = new_table.to_batches()
+
+ assert len(new_batches) == 1
+ new_batch = new_batches[0]
+
+ assert batch == new_batch
+ assert batch.schema == new_batch.schema
+
+
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_table_pycapsule():
+ """
+ Python -> Rust -> Python
+ """
+ schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1':
b'value1'})
+ batches = [
+ pa.record_batch([[[1], [2, 42]]], schema),
+ pa.record_batch([[None, [], [5, 6]]], schema),
+ ]
+ table = pa.Table.from_batches(batches)
+ wrapped = StreamWrapper(table)
+ b = rust.round_trip_record_batch_reader(wrapped)
+ new_table = b.read_all()
+
+ assert table.schema == new_table.schema
+ assert table == new_table
+ assert len(table.to_batches()) == len(new_table.to_batches())
+
+
def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
with pytest.raises(TypeError, match="Expected instance of
pyarrow.lib.Array, got builtins.list"):
rust.round_trip_array(not_pyarrow)
-
+
with pytest.raises(TypeError, match="Expected instance of
pyarrow.lib.Schema, got builtins.list"):
rust.round_trip_schema(not_pyarrow)
-
+
with pytest.raises(TypeError, match="Expected instance of
pyarrow.lib.Field, got builtins.list"):
rust.round_trip_field(not_pyarrow)
-
+
with pytest.raises(TypeError, match="Expected instance of
pyarrow.lib.DataType, got builtins.list"):
rust.round_trip_type(not_pyarrow)
with pytest.raises(TypeError, match="Expected instance of
pyarrow.lib.RecordBatch, got builtins.list"):
rust.round_trip_record_batch(not_pyarrow)
-
+
with pytest.raises(TypeError, match="Expected instance of
pyarrow.lib.RecordBatchReader, got builtins.list"):
rust.round_trip_record_batch_reader(not_pyarrow)
diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs
index 7e33a78fec..640a7de798 100644
--- a/arrow-schema/src/ffi.rs
+++ b/arrow-schema/src/ffi.rs
@@ -351,6 +351,8 @@ impl Drop for FFI_ArrowSchema {
}
}
+unsafe impl Send for FFI_ArrowSchema {}
+
impl TryFrom<&FFI_ArrowSchema> for DataType {
type Error = ArrowError;
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 517c333add..4d262b0d10 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -59,12 +59,12 @@ use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;
-use arrow_array::{RecordBatchIterator, RecordBatchReader};
+use arrow_array::{RecordBatchIterator, RecordBatchReader, StructArray};
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
-use pyo3::types::{PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyList, PyTuple};
use crate::array::{make_array, ArrayData};
use crate::datatypes::{DataType, Field, Schema};
@@ -118,8 +118,40 @@ fn validate_class(expected: &str, value: &PyAny) ->
PyResult<()> {
Ok(())
}
+fn validate_pycapsule(capsule: &PyCapsule, name: &str) -> PyResult<()> {
+ let capsule_name = capsule.name()?;
+ if capsule_name.is_none() {
+ return Err(PyValueError::new_err(
+ "Expected schema PyCapsule to have name set.",
+ ));
+ }
+
+ let capsule_name = capsule_name.unwrap().to_str()?;
+ if capsule_name != name {
+ return Err(PyValueError::new_err(format!(
+ "Expected name '{}' in PyCapsule, instead got '{}'",
+ name, capsule_name
+ )));
+ }
+
+ Ok(())
+}
+
impl FromPyArrow for DataType {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // Newer versions of PyArrow as well as other libraries with Arrow
data implement this
+ // method, so prefer it over _export_to_c.
+ // See
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+ if value.hasattr("__arrow_c_schema__")? {
+ let capsule: &PyCapsule =
+
PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?;
+ validate_pycapsule(capsule, "arrow_schema")?;
+
+ let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
+ let dtype = DataType::try_from(schema_ptr).map_err(to_py_err)?;
+ return Ok(dtype);
+ }
+
validate_class("DataType", value)?;
let c_schema = FFI_ArrowSchema::empty();
@@ -143,6 +175,19 @@ impl ToPyArrow for DataType {
impl FromPyArrow for Field {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // Newer versions of PyArrow as well as other libraries with Arrow
data implement this
+ // method, so prefer it over _export_to_c.
+ // See
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+ if value.hasattr("__arrow_c_schema__")? {
+ let capsule: &PyCapsule =
+
PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?;
+ validate_pycapsule(capsule, "arrow_schema")?;
+
+ let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
+ let field = Field::try_from(schema_ptr).map_err(to_py_err)?;
+ return Ok(field);
+ }
+
validate_class("Field", value)?;
let c_schema = FFI_ArrowSchema::empty();
@@ -166,6 +211,19 @@ impl ToPyArrow for Field {
impl FromPyArrow for Schema {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // Newer versions of PyArrow as well as other libraries with Arrow
data implement this
+ // method, so prefer it over _export_to_c.
+ // See
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+ if value.hasattr("__arrow_c_schema__")? {
+ let capsule: &PyCapsule =
+
PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?;
+ validate_pycapsule(capsule, "arrow_schema")?;
+
+ let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
+ let schema = Schema::try_from(schema_ptr).map_err(to_py_err)?;
+ return Ok(schema);
+ }
+
validate_class("Schema", value)?;
let c_schema = FFI_ArrowSchema::empty();
@@ -189,6 +247,30 @@ impl ToPyArrow for Schema {
impl FromPyArrow for ArrayData {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // Newer versions of PyArrow as well as other libraries with Arrow
data implement this
+ // method, so prefer it over _export_to_c.
+ // See
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+ if value.hasattr("__arrow_c_array__")? {
+ let tuple = value.getattr("__arrow_c_array__")?.call0()?;
+
+ if !tuple.is_instance_of::<PyTuple>() {
+ return Err(PyTypeError::new_err(
+ "Expected __arrow_c_array__ to return a tuple.",
+ ));
+ }
+
+ let schema_capsule: &PyCapsule =
PyTryInto::try_into(tuple.get_item(0)?)?;
+ let array_capsule: &PyCapsule =
PyTryInto::try_into(tuple.get_item(1)?)?;
+
+ validate_pycapsule(schema_capsule, "arrow_schema")?;
+ validate_pycapsule(array_capsule, "arrow_array")?;
+
+ let schema_ptr = unsafe {
schema_capsule.reference::<FFI_ArrowSchema>() };
+ let array_ptr = array_capsule.pointer() as *mut FFI_ArrowArray;
+ let array = unsafe { std::ptr::replace(array_ptr,
FFI_ArrowArray::empty()) };
+ return ffi::from_ffi(array, schema_ptr).map_err(to_py_err);
+ }
+
validate_class("Array", value)?;
// prepare a pointer to receive the Array struct
@@ -247,6 +329,37 @@ impl<T: ToPyArrow> ToPyArrow for Vec<T> {
impl FromPyArrow for RecordBatch {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // Newer versions of PyArrow as well as other libraries with Arrow
data implement this
+ // method, so prefer it over _export_to_c.
+ // See
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+ if value.hasattr("__arrow_c_array__")? {
+ let tuple = value.getattr("__arrow_c_array__")?.call0()?;
+
+ if !tuple.is_instance_of::<PyTuple>() {
+ return Err(PyTypeError::new_err(
+ "Expected __arrow_c_array__ to return a tuple.",
+ ));
+ }
+
+ let schema_capsule: &PyCapsule =
PyTryInto::try_into(tuple.get_item(0)?)?;
+ let array_capsule: &PyCapsule =
PyTryInto::try_into(tuple.get_item(1)?)?;
+
+ validate_pycapsule(schema_capsule, "arrow_schema")?;
+ validate_pycapsule(array_capsule, "arrow_array")?;
+
+ let schema_ptr = unsafe {
schema_capsule.reference::<FFI_ArrowSchema>() };
+ let array_ptr = array_capsule.pointer() as *mut FFI_ArrowArray;
+ let ffi_array = unsafe { std::ptr::replace(array_ptr,
FFI_ArrowArray::empty()) };
+ let array_data = ffi::from_ffi(ffi_array,
schema_ptr).map_err(to_py_err)?;
+ if !matches!(array_data.data_type(), DataType::Struct(_)) {
+ return Err(PyTypeError::new_err(
+ "Expected Struct type from __arrow_c_array.",
+ ));
+ }
+ let array = StructArray::from(array_data);
+ return Ok(array.into());
+ }
+
validate_class("RecordBatch", value)?;
// TODO(kszucs): implement the FFI conversions in arrow-rs for
RecordBatches
let schema = value.getattr("schema")?;
@@ -276,6 +389,23 @@ impl ToPyArrow for RecordBatch {
/// Supports conversion from `pyarrow.RecordBatchReader` to
[ArrowArrayStreamReader].
impl FromPyArrow for ArrowArrayStreamReader {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ // Newer versions of PyArrow as well as other libraries with Arrow
data implement this
+ // method, so prefer it over _export_to_c.
+ // See
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+ if value.hasattr("__arrow_c_stream__")? {
+ let capsule: &PyCapsule =
+
PyTryInto::try_into(value.getattr("__arrow_c_stream__")?.call0()?)?;
+ validate_pycapsule(capsule, "arrow_array_stream")?;
+
+ let stream_ptr = capsule.pointer() as *mut FFI_ArrowArrayStream;
+ let stream = unsafe { std::ptr::replace(stream_ptr,
FFI_ArrowArrayStream::empty()) };
+
+ let stream_reader = ArrowArrayStreamReader::try_new(stream)
+ .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+ return Ok(stream_reader);
+ }
+
validate_class("RecordBatchReader", value)?;
// prepare a pointer to receive the stream struct