This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 9544887f feat(python/adbc_driver_manager): export handles through python Arrow Capsule interface (#1346)
9544887f is described below

commit 9544887f37edc30bed43a93ee4be615df3f1b014
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Wed Dec 13 20:21:28 2023 +0100

    feat(python/adbc_driver_manager): export handles through python Arrow Capsule interface (#1346)
    
    Addresses https://github.com/apache/arrow-adbc/issues/70
    
    This PR adds the dunder methods to the Handle classes of the low-level
    interface (which already enables using the low-level interface without
    pyarrow and with the capsule protocol).
    
    And secondly, in the places that accept data (eg ingest/bind), it now
    also accepts objects that implement the dunders in addition to hardcoded
    support for pyarrow.
    
    ---------
    
    Co-authored-by: David Li <[email protected]>
---
 .../adbc_driver_manager/_lib.pxd                   |   9 +-
 .../adbc_driver_manager/_lib.pyx                   | 110 ++++++++++++++++++---
 .../adbc_driver_manager/dbapi.py                   |  51 ++++++----
 python/adbc_driver_manager/tests/test_dbapi.py     |  28 ++++++
 python/adbc_driver_manager/tests/test_lowlevel.py  |  72 ++++++++++++++
 5 files changed, 237 insertions(+), 33 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 358a09aa..e9ea833c 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t
 
 cdef extern from "adbc.h" nogil:
     # C ABI
+
+    ctypedef void (*CArrowSchemaRelease)(void*)
+    ctypedef void (*CArrowArrayRelease)(void*)
+
     cdef struct CArrowSchema"ArrowSchema":
-        pass
+        CArrowSchemaRelease release
+
     cdef struct CArrowArray"ArrowArray":
-        pass
+        CArrowArrayRelease release
 
     ctypedef int (*CArrowArrayStreamGetLastError)(void*)
     ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ced8870e..91139100 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,15 @@ import threading
 import typing
 from typing import List, Tuple
 
+cimport cpython
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
+from cpython.pycapsule cimport (
+    PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+)
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
-from libc.string cimport memset
+from libc.stdlib cimport malloc, free
+from libc.string cimport memcpy, memset
 from libcpp.vector cimport vector as c_vector
 
 if typing.TYPE_CHECKING:
@@ -304,9 +309,29 @@ cdef class _AdbcHandle:
                     f"with open {self._child_type}")
 
 
+cdef void pycapsule_schema_deleter(object capsule) noexcept:
+    cdef CArrowSchema* allocated = <CArrowSchema*>PyCapsule_GetPointer(
+        capsule, "arrow_schema"
+    )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
+cdef void pycapsule_stream_deleter(object capsule) noexcept:
+    cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
+        capsule, "arrow_array_stream"
+    )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
 cdef class ArrowSchemaHandle:
     """
     A wrapper for an allocated ArrowSchema.
+
+    This object implements the Arrow PyCapsule interface.
     """
     cdef:
         CArrowSchema schema
@@ -316,23 +341,42 @@ cdef class ArrowSchemaHandle:
         """The address of the ArrowSchema."""
         return <uintptr_t> &self.schema
 
+    def __arrow_c_schema__(self) -> object:
+        """Consume this object to get a PyCapsule."""
+        # Reference:
+        # https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
+        cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
+        allocated.release = NULL
+        capsule = PyCapsule_New(
+            <void*>allocated, "arrow_schema", &pycapsule_schema_deleter,
+        )
+        memcpy(allocated, &self.schema, sizeof(CArrowSchema))
+        self.schema.release = NULL
+        return capsule
+
 
 cdef class ArrowArrayHandle:
     """
     A wrapper for an allocated ArrowArray.
+
+    This object implements the Arrow PyCapsule interface.
     """
     cdef:
         CArrowArray array
 
     @property
     def address(self) -> int:
-        """The address of the ArrowArray."""
+        """
+        The address of the ArrowArray.
+        """
         return <uintptr_t> &self.array
 
 
 cdef class ArrowArrayStreamHandle:
     """
     A wrapper for an allocated ArrowArrayStream.
+
+    This object implements the Arrow PyCapsule interface.
     """
     cdef:
         CArrowArrayStream stream
@@ -342,6 +386,21 @@ cdef class ArrowArrayStreamHandle:
         """The address of the ArrowArrayStream."""
         return <uintptr_t> &self.stream
 
+    def __arrow_c_stream__(self, requested_schema=None) -> object:
+        """Consume this object to get a PyCapsule."""
+        if requested_schema is not None:
+            raise NotImplementedError("requested_schema")
+
+        cdef CArrowArrayStream* allocated = \
+            <CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
+        allocated.release = NULL
+        capsule = PyCapsule_New(
+            <void*>allocated, "arrow_array_stream", &pycapsule_stream_deleter,
+        )
+        memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
+        self.stream.release = NULL
+        return capsule
+
 
 class GetObjectsDepth(enum.IntEnum):
     ALL = ADBC_OBJECT_DEPTH_ALL
@@ -1000,32 +1059,47 @@ cdef class AdbcStatement(_AdbcHandle):
 
         connection._open_child()
 
-    def bind(self, data, schema) -> None:
+    def bind(self, data, schema=None) -> None:
         """
         Bind an ArrowArray to this statement.
 
         Parameters
         ----------
-        data : int or ArrowArrayHandle
-        schema : int or ArrowSchemaHandle
+        data : PyCapsule or int or ArrowArrayHandle
+        schema : PyCapsule or int or ArrowSchemaHandle
         """
         cdef CAdbcError c_error = empty_error()
         cdef CArrowArray* c_array
         cdef CArrowSchema* c_schema
 
-        if isinstance(data, ArrowArrayHandle):
+        if hasattr(data, "__arrow_c_array__") and not isinstance(data, ArrowArrayHandle):
+            if schema is not None:
+                raise ValueError(
+                    "Can not provide a schema when passing Arrow-compatible "
+                    "data that implements the Arrow PyCapsule Protocol"
+                )
+            schema, data = data.__arrow_c_array__()
+
+        if PyCapsule_CheckExact(data):
+            c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
+        elif isinstance(data, ArrowArrayHandle):
             c_array = &(<ArrowArrayHandle> data).array
         elif isinstance(data, int):
             c_array = <CArrowArray*> data
         else:
-            raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")
-
-        if isinstance(schema, ArrowSchemaHandle):
+            raise TypeError(
+                "data must be Arrow-compatible data (implementing the Arrow PyCapsule "
+                f"Protocol), a PyCapsule, int or ArrowArrayHandle, not {type(data)}"
+            )
+
+        if PyCapsule_CheckExact(schema):
+            c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
+        elif isinstance(schema, ArrowSchemaHandle):
             c_schema = &(<ArrowSchemaHandle> schema).schema
         elif isinstance(schema, int):
             c_schema = <CArrowSchema*> schema
         else:
-            raise TypeError(f"schema must be int or ArrowSchemaHandle, "
+            raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, "
                             f"not {type(schema)}")
 
         with nogil:
@@ -1042,17 +1116,27 @@ cdef class AdbcStatement(_AdbcHandle):
 
         Parameters
         ----------
-        stream : int or ArrowArrayStreamHandle
+        stream : PyCapsule or int or ArrowArrayStreamHandle
         """
         cdef CAdbcError c_error = empty_error()
         cdef CArrowArrayStream* c_stream
 
-        if isinstance(stream, ArrowArrayStreamHandle):
+        if (
+            hasattr(stream, "__arrow_c_stream__")
+            and not isinstance(stream, ArrowArrayStreamHandle)
+        ):
+            stream = stream.__arrow_c_stream__()
+
+        if PyCapsule_CheckExact(stream):
+            c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
+                stream, "arrow_array_stream"
+            )
+        elif isinstance(stream, ArrowArrayStreamHandle):
             c_stream = &(<ArrowArrayStreamHandle> stream).stream
         elif isinstance(stream, int):
             c_stream = <CArrowArrayStream*> stream
         else:
-            raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
+            raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, "
                             f"not {type(stream)}")
 
         with nogil:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 8edcdf4f..4c36ad5c 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,21 @@ class Cursor(_Closeable):
         self._closed = True
 
     def _bind(self, parameters) -> None:
-        if isinstance(parameters, pyarrow.RecordBatch):
+        if hasattr(parameters, "__arrow_c_array__"):
+            self._stmt.bind(parameters)
+        elif hasattr(parameters, "__arrow_c_stream__"):
+            self._stmt.bind_stream(parameters)
+        elif isinstance(parameters, pyarrow.RecordBatch):
             arr_handle = _lib.ArrowArrayHandle()
             sch_handle = _lib.ArrowSchemaHandle()
             parameters._export_to_c(arr_handle.address, sch_handle.address)
             self._stmt.bind(arr_handle, sch_handle)
-            return
-        if isinstance(parameters, pyarrow.Table):
-            parameters = parameters.to_reader()
-        stream_handle = _lib.ArrowArrayStreamHandle()
-        parameters._export_to_c(stream_handle.address)
-        self._stmt.bind_stream(stream_handle)
+        else:
+            if isinstance(parameters, pyarrow.Table):
+                parameters = parameters.to_reader()
+            stream_handle = _lib.ArrowArrayStreamHandle()
+            parameters._export_to_c(stream_handle.address)
+            self._stmt.bind_stream(stream_handle)
 
     def _prepare_execute(self, operation, parameters=None) -> None:
         self._results = None
@@ -639,9 +643,7 @@ class Cursor(_Closeable):
                 # Not all drivers support it
                 pass
 
-        if isinstance(
-            parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
-        ):
+        if _is_arrow_data(parameters):
             self._bind(parameters)
         elif parameters:
             rb = pyarrow.record_batch(
@@ -668,7 +670,6 @@ class Cursor(_Closeable):
         self._prepare_execute(operation, parameters)
         handle, self._rowcount = self._stmt.execute_query()
         self._results = _RowIterator(
-            # pyarrow.RecordBatchReader._import_from_c(handle.address)
             _reader.AdbcRecordBatchReader._import_from_c(handle.address)
         )
 
@@ -683,7 +684,7 @@ class Cursor(_Closeable):
         operation : bytes or str
             The query to execute.  Pass SQL queries as strings,
             (serialized) Substrait plans as bytes.
-        parameters
+        seq_of_parameters
             Parameters to bind.  Can be a list of Python sequences, or
             an Arrow record batch, table, or record batch reader.  If
             None, then the query will be executed once, else it will
@@ -695,10 +696,7 @@ class Cursor(_Closeable):
             self._stmt.set_sql_query(operation)
             self._stmt.prepare()
 
-        if isinstance(
-            seq_of_parameters,
-            (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
-        ):
+        if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
             arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -806,7 +804,10 @@ class Cursor(_Closeable):
         table_name
             The table to insert into.
         data
-            The Arrow data to insert.
+            The Arrow data to insert. This can be a pyarrow RecordBatch, Table
+            or RecordBatchReader, or any Arrow-compatible data that implements
+            the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
+            or ``__arrow_c_stream__ ``method).
         mode
             How to deal with existing data:
 
@@ -878,7 +879,11 @@ class Cursor(_Closeable):
             except NotSupportedError:
                 pass
 
-        if isinstance(data, pyarrow.RecordBatch):
+        if hasattr(data, "__arrow_c_array__"):
+            self._stmt.bind(data)
+        elif hasattr(data, "__arrow_c_stream__"):
+            self._stmt.bind_stream(data)
+        elif isinstance(data, pyarrow.RecordBatch):
             array = _lib.ArrowArrayHandle()
             schema = _lib.ArrowSchemaHandle()
             data._export_to_c(array.address, schema.address)
@@ -1151,3 +1156,13 @@ def _warn_unclosed(name):
             category=ResourceWarning,
             stacklevel=2,
         )
+
+
+def _is_arrow_data(data):
+    return (
+        hasattr(data, "__arrow_c_array__")
+        or hasattr(data, "__arrow_c_stream__")
+        or isinstance(
+            data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
+        )
+    )
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 52b8e131..20990eff 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
     assert sqlite.adbc_get_table_types() == ["table", "view"]
 
 
+class ArrayWrapper:
+    def __init__(self, array):
+        self.array = array
+
+    def __arrow_c_array__(self, requested_schema=None):
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+    def __init__(self, stream):
+        self.stream = stream
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
+
 @pytest.mark.parametrize(
     "data",
     [
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
         lambda: pyarrow.table(
             [[1, 2], ["foo", ""]], names=["ints", "strs"]
         ).to_reader(),
+        lambda: ArrayWrapper(
+            pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
+        ),
+        lambda: StreamWrapper(
+            pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
+        ),
     ],
 )
 @pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
         (1.0, 2),
         pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
         pyarrow.table([[1.0], [2]], names=["float", "int"]),
+        ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
+        StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
     ],
 )
 def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
         pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
         pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
         pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
+        ArrayWrapper(
+            pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
+        ),
+        StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
         ((x, y) for x, y in ((1, "a"), (3, None))),
     ],
 )
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 15d98e53..98c8721c 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -390,3 +390,75 @@ def test_child_tracking(sqlite):
                 RuntimeError, match="Cannot close AdbcDatabase with open AdbcConnection"
             ):
                 db.close()
+
+
+@pytest.mark.sqlite
+def test_pycapsule(sqlite):
+    _, conn = sqlite
+    handle = conn.get_table_types()
+    with pyarrow.RecordBatchReader._import_from_c_capsule(
+        handle.__arrow_c_stream__()
+    ) as reader:
+        reader.read_all()
+
+    # set up some data
+    data = pyarrow.record_batch(
+        [
+            [1, 2, 3, 4],
+            ["a", "b", "c", "d"],
+        ],
+        names=["ints", "strs"],
+    )
+    table = pyarrow.Table.from_batches([data])
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+        schema_capsule, array_capsule = data.__arrow_c_array__()
+        stmt.bind(array_capsule, schema_capsule)
+        stmt.execute_update()
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "bar"})
+        stream_capsule = data.__arrow_c_stream__()
+        stmt.bind_stream(stream_capsule)
+        stmt.execute_update()
+
+    # importing a schema
+
+    handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
+    assert data.schema == pyarrow.schema(handle)
+    # ensure consumed schema was marked as such
+    with pytest.raises(ValueError, match="Cannot import released ArrowSchema"):
+        pyarrow.schema(handle)
+
+    # smoke test for the capsule calling release
+    capsule = conn.get_table_schema(
+        catalog=None, db_schema=None, table_name="foo"
+    ).__arrow_c_schema__()
+    del capsule
+
+    # importing a stream
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT * FROM foo")
+        handle, _ = stmt.execute_query()
+
+    result = pyarrow.table(handle)
+    assert result == table
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT * FROM bar")
+        handle, _ = stmt.execute_query()
+
+    result = pyarrow.table(handle)
+    assert result == table
+
+    # ensure consumed schema was marked as such
+    with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"):
+        pyarrow.table(handle)
+
+    # smoke test for the capsule calling release
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT * FROM foo")
+        capsule = stmt.execute_query()[0].__arrow_c_stream__()
+    del capsule

Reply via email to