This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9544887f feat(python/adbc_driver_manager): export handles through
python Arrow Capsule interface (#1346)
9544887f is described below
commit 9544887f37edc30bed43a93ee4be615df3f1b014
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Wed Dec 13 20:21:28 2023 +0100
feat(python/adbc_driver_manager): export handles through python Arrow
Capsule interface (#1346)
Addresses https://github.com/apache/arrow-adbc/issues/70
This PR first adds the Arrow PyCapsule dunder methods (__arrow_c_schema__,
__arrow_c_stream__) to the Handle classes of the low-level interface, which
on its own already allows using the low-level interface without pyarrow,
via the capsule protocol.
Second, the places that accept data (e.g. ingest/bind) now also accept
objects implementing __arrow_c_array__ or __arrow_c_stream__, in addition
to the hardcoded support for pyarrow.
---------
Co-authored-by: David Li <[email protected]>
---
.../adbc_driver_manager/_lib.pxd | 9 +-
.../adbc_driver_manager/_lib.pyx | 110 ++++++++++++++++++---
.../adbc_driver_manager/dbapi.py | 51 ++++++----
python/adbc_driver_manager/tests/test_dbapi.py | 28 ++++++
python/adbc_driver_manager/tests/test_lowlevel.py | 72 ++++++++++++++
5 files changed, 237 insertions(+), 33 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 358a09aa..e9ea833c 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t
cdef extern from "adbc.h" nogil:
# C ABI
+
+ ctypedef void (*CArrowSchemaRelease)(void*)
+ ctypedef void (*CArrowArrayRelease)(void*)
+
cdef struct CArrowSchema"ArrowSchema":
- pass
+ CArrowSchemaRelease release
+
cdef struct CArrowArray"ArrowArray":
- pass
+ CArrowArrayRelease release
ctypedef int (*CArrowArrayStreamGetLastError)(void*)
ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ced8870e..91139100 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,15 @@ import threading
import typing
from typing import List, Tuple
+cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
+from cpython.pycapsule cimport (
+ PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+)
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
-from libc.string cimport memset
+from libc.stdlib cimport malloc, free
+from libc.string cimport memcpy, memset
from libcpp.vector cimport vector as c_vector
if typing.TYPE_CHECKING:
@@ -304,9 +309,29 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")
+cdef void pycapsule_schema_deleter(object capsule) noexcept:
+ cdef CArrowSchema* allocated = <CArrowSchema*>PyCapsule_GetPointer(
+ capsule, "arrow_schema"
+ )
+ if allocated.release != NULL:
+ allocated.release(allocated)
+ free(allocated)
+
+
+cdef void pycapsule_stream_deleter(object capsule) noexcept:
+ cdef CArrowArrayStream* allocated = <CArrowArrayStream*>
PyCapsule_GetPointer(
+ capsule, "arrow_array_stream"
+ )
+ if allocated.release != NULL:
+ allocated.release(allocated)
+ free(allocated)
+
+
cdef class ArrowSchemaHandle:
"""
A wrapper for an allocated ArrowSchema.
+
+ This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowSchema schema
@@ -316,23 +341,42 @@ cdef class ArrowSchemaHandle:
"""The address of the ArrowSchema."""
return <uintptr_t> &self.schema
+ def __arrow_c_schema__(self) -> object:
+ """Consume this object to get a PyCapsule."""
+ # Reference:
+ #
https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
+ cdef CArrowSchema* allocated = <CArrowSchema*>
malloc(sizeof(CArrowSchema))
+ allocated.release = NULL
+ capsule = PyCapsule_New(
+ <void*>allocated, "arrow_schema", &pycapsule_schema_deleter,
+ )
+ memcpy(allocated, &self.schema, sizeof(CArrowSchema))
+ self.schema.release = NULL
+ return capsule
+
cdef class ArrowArrayHandle:
"""
A wrapper for an allocated ArrowArray.
+
+ This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArray array
@property
def address(self) -> int:
- """The address of the ArrowArray."""
+ """
+ The address of the ArrowArray.
+ """
return <uintptr_t> &self.array
cdef class ArrowArrayStreamHandle:
"""
A wrapper for an allocated ArrowArrayStream.
+
+ This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArrayStream stream
@@ -342,6 +386,21 @@ cdef class ArrowArrayStreamHandle:
"""The address of the ArrowArrayStream."""
return <uintptr_t> &self.stream
+ def __arrow_c_stream__(self, requested_schema=None) -> object:
+ """Consume this object to get a PyCapsule."""
+ if requested_schema is not None:
+ raise NotImplementedError("requested_schema")
+
+ cdef CArrowArrayStream* allocated = \
+ <CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
+ allocated.release = NULL
+ capsule = PyCapsule_New(
+ <void*>allocated, "arrow_array_stream", &pycapsule_stream_deleter,
+ )
+ memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
+ self.stream.release = NULL
+ return capsule
+
class GetObjectsDepth(enum.IntEnum):
ALL = ADBC_OBJECT_DEPTH_ALL
@@ -1000,32 +1059,47 @@ cdef class AdbcStatement(_AdbcHandle):
connection._open_child()
- def bind(self, data, schema) -> None:
+ def bind(self, data, schema=None) -> None:
"""
Bind an ArrowArray to this statement.
Parameters
----------
- data : int or ArrowArrayHandle
- schema : int or ArrowSchemaHandle
+ data : PyCapsule or int or ArrowArrayHandle
+ schema : PyCapsule or int or ArrowSchemaHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema
- if isinstance(data, ArrowArrayHandle):
+ if hasattr(data, "__arrow_c_array__") and not isinstance(data,
ArrowArrayHandle):
+ if schema is not None:
+ raise ValueError(
+ "Can not provide a schema when passing Arrow-compatible "
+ "data that implements the Arrow PyCapsule Protocol"
+ )
+ schema, data = data.__arrow_c_array__()
+
+ if PyCapsule_CheckExact(data):
+ c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
+ elif isinstance(data, ArrowArrayHandle):
c_array = &(<ArrowArrayHandle> data).array
elif isinstance(data, int):
c_array = <CArrowArray*> data
else:
- raise TypeError(f"data must be int or ArrowArrayHandle, not
{type(data)}")
-
- if isinstance(schema, ArrowSchemaHandle):
+ raise TypeError(
+ "data must be Arrow-compatible data (implementing the Arrow
PyCapsule "
+ f"Protocol), a PyCapsule, int or ArrowArrayHandle, not
{type(data)}"
+ )
+
+ if PyCapsule_CheckExact(schema):
+ c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema,
"arrow_schema")
+ elif isinstance(schema, ArrowSchemaHandle):
c_schema = &(<ArrowSchemaHandle> schema).schema
elif isinstance(schema, int):
c_schema = <CArrowSchema*> schema
else:
- raise TypeError(f"schema must be int or ArrowSchemaHandle, "
+ raise TypeError("schema must be a PyCapsule, int or
ArrowSchemaHandle, "
f"not {type(schema)}")
with nogil:
@@ -1042,17 +1116,27 @@ cdef class AdbcStatement(_AdbcHandle):
Parameters
----------
- stream : int or ArrowArrayStreamHandle
+ stream : PyCapsule or int or ArrowArrayStreamHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream
- if isinstance(stream, ArrowArrayStreamHandle):
+ if (
+ hasattr(stream, "__arrow_c_stream__")
+ and not isinstance(stream, ArrowArrayStreamHandle)
+ ):
+ stream = stream.__arrow_c_stream__()
+
+ if PyCapsule_CheckExact(stream):
+ c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
+ stream, "arrow_array_stream"
+ )
+ elif isinstance(stream, ArrowArrayStreamHandle):
c_stream = &(<ArrowArrayStreamHandle> stream).stream
elif isinstance(stream, int):
c_stream = <CArrowArrayStream*> stream
else:
- raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
+ raise TypeError(f"data must be a PyCapsule, int or
ArrowArrayStreamHandle, "
f"not {type(stream)}")
with nogil:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 8edcdf4f..4c36ad5c 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,21 @@ class Cursor(_Closeable):
self._closed = True
def _bind(self, parameters) -> None:
- if isinstance(parameters, pyarrow.RecordBatch):
+ if hasattr(parameters, "__arrow_c_array__"):
+ self._stmt.bind(parameters)
+ elif hasattr(parameters, "__arrow_c_stream__"):
+ self._stmt.bind_stream(parameters)
+ elif isinstance(parameters, pyarrow.RecordBatch):
arr_handle = _lib.ArrowArrayHandle()
sch_handle = _lib.ArrowSchemaHandle()
parameters._export_to_c(arr_handle.address, sch_handle.address)
self._stmt.bind(arr_handle, sch_handle)
- return
- if isinstance(parameters, pyarrow.Table):
- parameters = parameters.to_reader()
- stream_handle = _lib.ArrowArrayStreamHandle()
- parameters._export_to_c(stream_handle.address)
- self._stmt.bind_stream(stream_handle)
+ else:
+ if isinstance(parameters, pyarrow.Table):
+ parameters = parameters.to_reader()
+ stream_handle = _lib.ArrowArrayStreamHandle()
+ parameters._export_to_c(stream_handle.address)
+ self._stmt.bind_stream(stream_handle)
def _prepare_execute(self, operation, parameters=None) -> None:
self._results = None
@@ -639,9 +643,7 @@ class Cursor(_Closeable):
# Not all drivers support it
pass
- if isinstance(
- parameters, (pyarrow.RecordBatch, pyarrow.Table,
pyarrow.RecordBatchReader)
- ):
+ if _is_arrow_data(parameters):
self._bind(parameters)
elif parameters:
rb = pyarrow.record_batch(
@@ -668,7 +670,6 @@ class Cursor(_Closeable):
self._prepare_execute(operation, parameters)
handle, self._rowcount = self._stmt.execute_query()
self._results = _RowIterator(
- # pyarrow.RecordBatchReader._import_from_c(handle.address)
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
)
@@ -683,7 +684,7 @@ class Cursor(_Closeable):
operation : bytes or str
The query to execute. Pass SQL queries as strings,
(serialized) Substrait plans as bytes.
- parameters
+ seq_of_parameters
Parameters to bind. Can be a list of Python sequences, or
an Arrow record batch, table, or record batch reader. If
None, then the query will be executed once, else it will
@@ -695,10 +696,7 @@ class Cursor(_Closeable):
self._stmt.set_sql_query(operation)
self._stmt.prepare()
- if isinstance(
- seq_of_parameters,
- (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
- ):
+ if _is_arrow_data(seq_of_parameters):
arrow_parameters = seq_of_parameters
elif seq_of_parameters:
arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -806,7 +804,10 @@ class Cursor(_Closeable):
table_name
The table to insert into.
data
- The Arrow data to insert.
+ The Arrow data to insert. This can be a pyarrow RecordBatch, Table
+ or RecordBatchReader, or any Arrow-compatible data that implements
+ the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
+ or ``__arrow_c_stream__`` method).
mode
How to deal with existing data:
@@ -878,7 +879,11 @@ class Cursor(_Closeable):
except NotSupportedError:
pass
- if isinstance(data, pyarrow.RecordBatch):
+ if hasattr(data, "__arrow_c_array__"):
+ self._stmt.bind(data)
+ elif hasattr(data, "__arrow_c_stream__"):
+ self._stmt.bind_stream(data)
+ elif isinstance(data, pyarrow.RecordBatch):
array = _lib.ArrowArrayHandle()
schema = _lib.ArrowSchemaHandle()
data._export_to_c(array.address, schema.address)
@@ -1151,3 +1156,13 @@ def _warn_unclosed(name):
category=ResourceWarning,
stacklevel=2,
)
+
+
+def _is_arrow_data(data):
+ return (
+ hasattr(data, "__arrow_c_array__")
+ or hasattr(data, "__arrow_c_stream__")
+ or isinstance(
+ data, (pyarrow.RecordBatch, pyarrow.Table,
pyarrow.RecordBatchReader)
+ )
+ )
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py
b/python/adbc_driver_manager/tests/test_dbapi.py
index 52b8e131..20990eff 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
assert sqlite.adbc_get_table_types() == ["table", "view"]
+class ArrayWrapper:
+ def __init__(self, array):
+ self.array = array
+
+ def __arrow_c_array__(self, requested_schema=None):
+ return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+ def __init__(self, stream):
+ self.stream = stream
+
+ def __arrow_c_stream__(self, requested_schema=None):
+ return
self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
+
@pytest.mark.parametrize(
"data",
[
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
lambda: pyarrow.table(
[[1, 2], ["foo", ""]], names=["ints", "strs"]
).to_reader(),
+ lambda: ArrayWrapper(
+ pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
+ ),
+ lambda: StreamWrapper(
+ pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
+ ),
],
)
@pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
(1.0, 2),
pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
pyarrow.table([[1.0], [2]], names=["float", "int"]),
+ ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float",
"int"])),
+ StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
],
)
def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float",
"str"]).to_batches()[0],
+ ArrayWrapper(
+ pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
+ ),
+ StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float",
"str"])),
((x, y) for x, y in ((1, "a"), (3, None))),
],
)
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py
b/python/adbc_driver_manager/tests/test_lowlevel.py
index 15d98e53..98c8721c 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -390,3 +390,75 @@ def test_child_tracking(sqlite):
RuntimeError, match="Cannot close AdbcDatabase with open
AdbcConnection"
):
db.close()
+
+
[email protected]
+def test_pycapsule(sqlite):
+ _, conn = sqlite
+ handle = conn.get_table_types()
+ with pyarrow.RecordBatchReader._import_from_c_capsule(
+ handle.__arrow_c_stream__()
+ ) as reader:
+ reader.read_all()
+
+ # set up some data
+ data = pyarrow.record_batch(
+ [
+ [1, 2, 3, 4],
+ ["a", "b", "c", "d"],
+ ],
+ names=["ints", "strs"],
+ )
+ table = pyarrow.Table.from_batches([data])
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE:
"foo"})
+ schema_capsule, array_capsule = data.__arrow_c_array__()
+ stmt.bind(array_capsule, schema_capsule)
+ stmt.execute_update()
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE:
"bar"})
+ stream_capsule = data.__arrow_c_stream__()
+ stmt.bind_stream(stream_capsule)
+ stmt.execute_update()
+
+ # importing a schema
+
+ handle = conn.get_table_schema(catalog=None, db_schema=None,
table_name="foo")
+ assert data.schema == pyarrow.schema(handle)
+ # ensure consumed schema was marked as such
+ with pytest.raises(ValueError, match="Cannot import released ArrowSchema"):
+ pyarrow.schema(handle)
+
+ # smoke test for the capsule calling release
+ capsule = conn.get_table_schema(
+ catalog=None, db_schema=None, table_name="foo"
+ ).__arrow_c_schema__()
+ del capsule
+
+ # importing a stream
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM foo")
+ handle, _ = stmt.execute_query()
+
+ result = pyarrow.table(handle)
+ assert result == table
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM bar")
+ handle, _ = stmt.execute_query()
+
+ result = pyarrow.table(handle)
+ assert result == table
+
+ # ensure consumed stream was marked as such
+ with pytest.raises(ValueError, match="Cannot import released
ArrowArrayStream"):
+ pyarrow.table(handle)
+
+ # smoke test for the capsule calling release
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM foo")
+ capsule = stmt.execute_query()[0].__arrow_c_stream__()
+ del capsule