This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 687da3753 feat(python/adbc_driver_manager): enable DB-API without PyArrow (#2609)
687da3753 is described below

commit 687da37536b370821be6ac95bd5c5de53207b2e7
Author: David Li <[email protected]>
AuthorDate: Fri Mar 21 04:35:21 2025 -0400

    feat(python/adbc_driver_manager): enable DB-API without PyArrow (#2609)
    
    Enable limited use of the DB-API interface without PyArrow for people
    using other Arrow-based libraries, like polars (which is used here to
    test the new functionality).
    
    This doesn't enable everything (e.g. get_objects), but it does enable
    (parameterized) queries and ingestion, which are likely the most
    important use cases.
    
    Fixes #2413.
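
A minimal sketch of the newly enabled flow, mirroring the tests added below
(it assumes the SQLite driver and polars are installed, and PyArrow is not):

    import polars
    from adbc_driver_manager import dbapi

    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
        with conn.cursor() as cursor:
            # Ingestion: any object implementing the Arrow PyCapsule
            # interface (here, a polars DataFrame) can be written.
            cursor.adbc_ingest("mytable", polars.DataFrame({"x": [1]}))
            # Parameterized queries: bind parameters are passed the same way.
            cursor.execute(
                "SELECT 1 + ? AS y",
                parameters=polars.DataFrame({"$0": [1]}),
            )
            print(cursor.fetch_polars())  # one row, y == 2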
---
 ci/conda_env_docs.txt                              |   1 +
 ci/scripts/python_sdist_test.sh                    |   2 +-
 ci/scripts/python_util.sh                          |  29 +-
 ci/scripts/python_wheel_unix_test.sh               |   7 +-
 ci/scripts/python_wheel_windows_test.bat           |   2 +-
 docs/source/conf.py                                |  17 ++
 .../adbc_driver_manager/_lib.pyi                   |   1 +
 .../adbc_driver_manager/_lib.pyx                   |  14 +-
 .../adbc_driver_manager/dbapi.py                   | 292 ++++++++++++++-------
 python/adbc_driver_manager/pyproject.toml          |   1 +
 python/adbc_driver_manager/tests/test_dbapi.py     |  42 +++
 .../tests/test_dbapi_nopyarrow.py                  | 156 +++++++++++
 12 files changed, 461 insertions(+), 103 deletions(-)

diff --git a/ci/conda_env_docs.txt b/ci/conda_env_docs.txt
index 26d20f8c6..9badc3f31 100644
--- a/ci/conda_env_docs.txt
+++ b/ci/conda_env_docs.txt
@@ -21,6 +21,7 @@ make
 # Needed to install mermaid
 nodejs
 numpydoc
+polars
 pytest
 sphinx>=8.1
 sphinx-autobuild
diff --git a/ci/scripts/python_sdist_test.sh b/ci/scripts/python_sdist_test.sh
index 46ce43835..38e71e692 100755
--- a/ci/scripts/python_sdist_test.sh
+++ b/ci/scripts/python_sdist_test.sh
@@ -47,7 +47,7 @@ echo "=== Installing sdists ==="
 for component in ${COMPONENTS}; do
     pip install --no-deps --force-reinstall ${source_dir}/python/${component}/dist/*.tar.gz
 done
-pip install importlib-resources pytest pyarrow pandas protobuf
+pip install importlib-resources pytest pyarrow pandas polars protobuf
 
 echo "=== (${PYTHON_VERSION}) Testing sdists ==="
 test_packages
diff --git a/ci/scripts/python_util.sh b/ci/scripts/python_util.sh
index abeabd78b..3bf1dee2e 100644
--- a/ci/scripts/python_util.sh
+++ b/ci/scripts/python_util.sh
@@ -161,10 +161,33 @@ import $component.dbapi
 
         # --import-mode required, else tries to import from the source dir instead of installed package
         if [[ "${component}" = "adbc_driver_manager" ]]; then
-            export PYTEST_ADDOPTS="-k 'not duckdb and not sqlite'"
-        elif [[ "${component}" = "adbc_driver_postgresql" ]]; then
-            export PYTEST_ADDOPTS="-k 'not polars'"
+            export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} -k 'not duckdb and not sqlite'"
         fi
         python -m pytest -vvx --import-mode append ${source_dir}/python/$component/tests
     done
 }
+
+function test_packages_pyarrowless {
+    local -r driver_path=$(python -c "import os; import adbc_driver_sqlite; print(os.path.dirname(adbc_driver_sqlite._driver_path()))")
+    export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${driver_path}"
+    export DYLD_LIBRARY_PATH="${DYLD_LIBRARY_PATH}:${driver_path}"
+    # For macOS (because we name the file ".so" on every platform regardless of the actual type)
+    ln -s "${driver_path}/libadbc_driver_sqlite.so" "${driver_path}/libadbc_driver_sqlite.dylib"
+    for component in ${COMPONENTS}; do
+        echo "=== Testing $component (no PyArrow) ==="
+
+        python -c "
+import $component
+import $component.dbapi
+"
+
+        local test_files=$(find ${source_dir}/python/$component/tests -type f |
+                               grep -e 'nopyarrow\.py$')
+        if [[ -z "${test_files}" ]]; then
+            continue
+        fi
+
+        # --import-mode required, else tries to import from the source dir instead of installed package
+        python -m pytest -vvx --import-mode append "${test_files[@]}"
+    done
+}
diff --git a/ci/scripts/python_wheel_unix_test.sh b/ci/scripts/python_wheel_unix_test.sh
index 15eea984d..3cc72fd9b 100755
--- a/ci/scripts/python_wheel_unix_test.sh
+++ b/ci/scripts/python_wheel_unix_test.sh
@@ -49,8 +49,13 @@ for component in ${COMPONENTS}; do
         echo "NOTE: assuming wheels are already installed"
     fi
 done
-pip install importlib-resources pytest pyarrow pandas protobuf
+pip install importlib-resources pytest pyarrow pandas polars protobuf
 
 
 echo "=== (${PYTHON_VERSION}) Testing wheels ==="
 test_packages
+
+echo "=== (${PYTHON_VERSION}) Testing wheels (no PyArrow) ==="
+pip uninstall -y pyarrow
+export PYTEST_ADDOPTS="${PYTEST_ADDOPTS} -k pyarrowless"
+test_packages_pyarrowless
diff --git a/ci/scripts/python_wheel_windows_test.bat b/ci/scripts/python_wheel_windows_test.bat
index 963067b7b..852991bfc 100644
--- a/ci/scripts/python_wheel_windows_test.bat
+++ b/ci/scripts/python_wheel_windows_test.bat
@@ -27,7 +27,7 @@ FOR %%c IN (adbc_driver_bigquery adbc_driver_manager adbc_driver_flightsql adbc_
     )
 )
 
-pip install importlib-resources pytest pyarrow pandas protobuf
+pip install importlib-resources pytest pyarrow pandas polars protobuf
 
 echo "=== (%PYTHON_VERSION%) Testing wheels ==="
 
diff --git a/docs/source/conf.py b/docs/source/conf.py
index bcbefaf96..10771e168 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -54,6 +54,23 @@ extensions = [
 ]
 templates_path = ["_templates"]
 
+
+def on_missing_reference(app, env, node, contnode):
+    if str(contnode) == "polars.DataFrame":
+        # Polars does something odd with Sphinx such that polars.DataFrame
+        # isn't xrefable; suppress the warning.
+        return contnode
+    elif str(contnode) == "CapsuleType":
+        # CapsuleType is only in 3.13+
+        return contnode
+    else:
+        return None
+
+
+def setup(app):
+    app.connect("missing-reference", on_missing_reference)
+
+
 # -- Options for autodoc ----------------------------------------------------
 
 try:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 0a19f92ed..6c7d6c4ab 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -214,3 +214,4 @@ def _blocking_call(
     kwargs: dict,
     cancel: typing.Callable[[], None],
 ) -> _T: ...
+def is_pycapsule(obj: Any, name: bytes) -> bool: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 21afe9d3c..d2ac2401b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -31,7 +31,7 @@ from typing import List, Optional, Tuple
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
 from cpython.pycapsule cimport (
-    PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+    PyCapsule_GetPointer, PyCapsule_IsValid, PyCapsule_New
 )
 from libc.stdint cimport int64_t, uint8_t, uint32_t, uintptr_t
 from libc.stdlib cimport malloc, free
@@ -337,6 +337,12 @@ cdef class _AdbcHandle:
                     f"with open {self._child_type}")
 
 
+def is_pycapsule(obj, bytes name) -> bool:
+    """Check if an object is a PyCapsule of a specific type."""
+    # Taken from nanoarrow
+    return PyCapsule_IsValid(obj, name) == 1
+
+
 cdef void pycapsule_schema_deleter(object capsule) noexcept:
     cdef CArrowSchema* allocated = <CArrowSchema*>PyCapsule_GetPointer(
         capsule, "arrow_schema"
@@ -1125,7 +1131,7 @@ cdef class AdbcStatement(_AdbcHandle):
                 )
             schema, data = data.__arrow_c_array__()
 
-        if PyCapsule_CheckExact(data):
+        if is_pycapsule(data, b"arrow_array"):
             c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
         elif isinstance(data, ArrowArrayHandle):
             c_array = &(<ArrowArrayHandle> data).array
@@ -1137,7 +1143,7 @@ cdef class AdbcStatement(_AdbcHandle):
                 f"Protocol), a PyCapsule, int or ArrowArrayHandle, not 
{type(data)}"
             )
 
-        if PyCapsule_CheckExact(schema):
+        if is_pycapsule(schema, b"arrow_schema"):
             c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
         elif isinstance(schema, ArrowSchemaHandle):
             c_schema = &(<ArrowSchemaHandle> schema).schema
@@ -1172,7 +1178,7 @@ cdef class AdbcStatement(_AdbcHandle):
         ):
             stream = stream.__arrow_c_stream__()
 
-        if PyCapsule_CheckExact(stream):
+        if is_pycapsule(stream, b"arrow_array_stream"):
             c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
                 stream, "arrow_array_stream"
             )
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 9dfc4e551..679cc1bff 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -17,6 +17,15 @@
 
 """PEP 249 (DB-API 2.0) API wrapper for the ADBC Driver Manager.
 
+PyArrow Requirement
+===================
+
+This module requires PyArrow for full functionality.  If PyArrow is not
+installed, most functionality that actually reads or writes data will be
+unavailable.  You can still execute (parameterized) queries and get results
+as a PyCapsule, but many other methods will raise.  Also, the DB-API type
+definitions (``BINARY``, ``DATETIME``, etc.) will be present, but invalid.
+
 Resource Management
 ===================
 
@@ -40,26 +49,23 @@ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 try:
     import pyarrow
-except ImportError as e:
-    raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e
-
-try:
     import pyarrow.dataset
 except ImportError:
-    _pya_dataset = ()
-    _pya_scanner = ()
+    _has_pyarrow = False
 else:
-    _pya_dataset = (pyarrow.dataset.Dataset,)
-    _pya_scanner = (pyarrow.dataset.Scanner,)
+    _has_pyarrow = True
+    from . import _reader
 
 import adbc_driver_manager
 
-from . import _lib, _reader
+from . import _lib
 from ._lib import _blocking_call
 
 if typing.TYPE_CHECKING:
     import pandas
-    from typing_extensions import Self
+    import polars
+    import pyarrow
+    from typing_extensions import CapsuleType, Self
 
 # ----------------------------------------------------------
 # Globals
@@ -131,37 +137,44 @@ class _TypeSet(frozenset):
         return False
 
 
-#: The type of binary columns.
-BINARY = _TypeSet({pyarrow.binary().id, pyarrow.large_binary().id})
-#: The type of datetime columns.
-DATETIME = _TypeSet(
-    [
-        pyarrow.date32().id,
-        pyarrow.date64().id,
-        pyarrow.time32("s").id,
-        pyarrow.time64("ns").id,
-        pyarrow.timestamp("s").id,
-    ]
-)
-#: The type of numeric columns.
-NUMBER = _TypeSet(
-    [
-        pyarrow.int8().id,
-        pyarrow.int16().id,
-        pyarrow.int32().id,
-        pyarrow.int64().id,
-        pyarrow.uint8().id,
-        pyarrow.uint16().id,
-        pyarrow.uint32().id,
-        pyarrow.uint64().id,
-        pyarrow.float32().id,
-        pyarrow.float64().id,
-    ]
-)
-#: The type of "row ID" columns.
-ROWID = _TypeSet([pyarrow.int64().id])
-#: The type of string columns.
-STRING = _TypeSet([pyarrow.string().id, pyarrow.large_string().id])
+if _has_pyarrow:
+    #: The type of binary columns.
+    BINARY = _TypeSet({pyarrow.binary().id, pyarrow.large_binary().id})
+    #: The type of datetime columns.
+    DATETIME = _TypeSet(
+        [
+            pyarrow.date32().id,
+            pyarrow.date64().id,
+            pyarrow.time32("s").id,
+            pyarrow.time64("ns").id,
+            pyarrow.timestamp("s").id,
+        ]
+    )
+    #: The type of numeric columns.
+    NUMBER = _TypeSet(
+        [
+            pyarrow.int8().id,
+            pyarrow.int16().id,
+            pyarrow.int32().id,
+            pyarrow.int64().id,
+            pyarrow.uint8().id,
+            pyarrow.uint16().id,
+            pyarrow.uint32().id,
+            pyarrow.uint64().id,
+            pyarrow.float32().id,
+            pyarrow.float64().id,
+        ]
+    )
+    #: The type of "row ID" columns.
+    ROWID = _TypeSet([pyarrow.int64().id])
+    #: The type of string columns.
+    STRING = _TypeSet([pyarrow.string().id, pyarrow.large_string().id])
+else:
+    BINARY = _TypeSet()
+    DATETIME = _TypeSet()
+    NUMBER = _TypeSet()
+    ROWID = _TypeSet()
+    STRING = _TypeSet()
 
 # ----------------------------------------------------------
 # Functions
@@ -396,6 +409,8 @@ class Connection(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
+
         handle = _blocking_call(self._conn.get_info, (), {}, self._conn.cancel)
         reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
         table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
@@ -418,7 +433,7 @@ class Connection(_Closeable):
         table_name_filter: Optional[str] = None,
         table_types_filter: Optional[List[str]] = None,
         column_name_filter: Optional[str] = None,
-    ) -> pyarrow.RecordBatchReader:
+    ) -> "pyarrow.RecordBatchReader":
         """
         List catalogs, schemas, tables, etc. in the database.
 
@@ -441,6 +456,8 @@ class Connection(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
+
         if depth in ("all", "columns"):
             c_depth = _lib.GetObjectsDepth.ALL
         elif depth == "catalogs":
@@ -471,7 +488,7 @@ class Connection(_Closeable):
         *,
         catalog_filter: Optional[str] = None,
         db_schema_filter: Optional[str] = None,
-    ) -> pyarrow.Schema:
+    ) -> "pyarrow.Schema":
         """
         Get the Arrow schema of a table by name.
 
@@ -488,6 +505,8 @@ class Connection(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
+
         handle = _blocking_call(
             self._conn.get_table_schema,
             (
@@ -508,6 +527,8 @@ class Connection(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
+
         handle = _blocking_call(
             self._conn.get_table_types,
             (),
@@ -660,17 +681,10 @@ class Cursor(_Closeable):
             self._stmt.bind(parameters)
         elif hasattr(parameters, "__arrow_c_stream__"):
             self._stmt.bind_stream(parameters)
-        elif isinstance(parameters, pyarrow.RecordBatch):
-            arr_handle = _lib.ArrowArrayHandle()
-            sch_handle = _lib.ArrowSchemaHandle()
-            parameters._export_to_c(arr_handle.address, sch_handle.address)
-            self._stmt.bind(arr_handle, sch_handle)
+        elif _lib.is_pycapsule(parameters, b"arrow_array_stream"):
+            self._stmt.bind_stream(parameters)
         else:
-            if isinstance(parameters, pyarrow.Table):
-                parameters = parameters.to_reader()
-            stream_handle = _lib.ArrowArrayStreamHandle()
-            parameters._export_to_c(stream_handle.address)
-            self._stmt.bind_stream(stream_handle)
+            raise TypeError(f"Cannot bind {type(parameters)}")
 
     def _prepare_execute(self, operation, parameters=None) -> None:
         self._results = None
@@ -690,6 +704,7 @@ class Cursor(_Closeable):
         if _is_arrow_data(parameters):
             self._bind(parameters)
         elif parameters:
+            _requires_pyarrow()
             rb = pyarrow.record_batch(
                 [[param_value] for param_value in parameters],
                 names=[str(i) for i in range(len(parameters))],
@@ -716,9 +731,7 @@ class Cursor(_Closeable):
         handle, self._rowcount = _blocking_call(
             self._stmt.execute_query, (), {}, self._stmt.cancel
         )
-        self._results = _RowIterator(
-            self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address)
-        )
+        self._results = _RowIterator(self._stmt, handle)
 
     def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         """
@@ -746,6 +759,7 @@ class Cursor(_Closeable):
         if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
+            _requires_pyarrow()
             arrow_parameters = pyarrow.RecordBatch.from_pydict(
                 {
                     str(col_idx): pyarrow.array(x)
@@ -753,6 +767,7 @@ class Cursor(_Closeable):
                 },
             )
         else:
+            _requires_pyarrow()
             arrow_parameters = pyarrow.record_batch([])
 
         self._bind(arrow_parameters)
@@ -836,7 +851,12 @@ class Cursor(_Closeable):
     def adbc_ingest(
         self,
         table_name: str,
-        data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader],
+        data: Union[
+            "pyarrow.RecordBatch",
+            "pyarrow.Table",
+            "pyarrow.RecordBatchReader",
+            "CapsuleType",
+        ],
         mode: Literal["append", "create", "replace", "create_append"] = "create",
         *,
         catalog_name: Optional[str] = None,
@@ -932,24 +952,24 @@ class Cursor(_Closeable):
             self._stmt.bind(data)
         elif hasattr(data, "__arrow_c_stream__"):
             self._stmt.bind_stream(data)
-        elif isinstance(data, pyarrow.RecordBatch):
-            array = _lib.ArrowArrayHandle()
-            schema = _lib.ArrowSchemaHandle()
-            data._export_to_c(array.address, schema.address)
-            self._stmt.bind(array, schema)
+        elif _lib.is_pycapsule(data, b"arrow_array_stream"):
+            self._stmt.bind_stream(data)
         else:
-            if isinstance(data, pyarrow.Table):
-                data = data.to_reader()
-            elif isinstance(data, pyarrow.dataset.Dataset):
-                data = data.scanner().to_reader()
+            _requires_pyarrow()
+            if isinstance(data, pyarrow.dataset.Dataset):
+                data = typing.cast(pyarrow.dataset.Dataset, data).scanner().to_reader()
             elif isinstance(data, pyarrow.dataset.Scanner):
-                data = data.to_reader()
+                data = typing.cast(pyarrow.dataset.Scanner, data).to_reader()
             elif not hasattr(data, "_export_to_c"):
-                data = pyarrow.Table.from_batches(data)
-                data = data.to_reader()
-            handle = _lib.ArrowArrayStreamHandle()
-            data._export_to_c(handle.address)
-            self._stmt.bind_stream(handle)
+                data = pyarrow.Table.from_batches(data).to_reader()
+            if hasattr(data, "_export_to_c"):
+                handle = _lib.ArrowArrayStreamHandle()
+                # pyright doesn't seem to handle flow-sensitive typing here
+                data._export_to_c(handle.address)  # type: ignore
+                self._stmt.bind_stream(handle)
+            else:
+                # Should be impossible from above but let's be explicit
+                raise TypeError(f"Cannot bind {type(data)}")
 
         self._last_query = None
         return _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)
@@ -958,7 +978,7 @@ class Cursor(_Closeable):
         self,
         operation,
         parameters=None,
-    ) -> Tuple[List[bytes], pyarrow.Schema]:
+    ) -> Tuple[List[bytes], "pyarrow.Schema"]:
         """
         Execute a query and get the partitions of a distributed result set.
 
@@ -975,6 +995,7 @@ class Cursor(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
         self._prepare_execute(operation, parameters)
         partitions, schema_handle, self._rowcount = _blocking_call(
             self._stmt.execute_partitions, (), {}, self._stmt.cancel
@@ -985,7 +1006,7 @@ class Cursor(_Closeable):
             schema = None
         return partitions, schema
 
-    def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema:
+    def adbc_execute_schema(self, operation, parameters=None) -> "pyarrow.Schema":
         """
         Get the schema of the result set of a query without executing it.
 
@@ -998,11 +1019,12 @@ class Cursor(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
         self._prepare_execute(operation, parameters)
         schema = _blocking_call(self._stmt.execute_schema, (), {}, self._stmt.cancel)
         return pyarrow.Schema._import_from_c(schema.address)
 
-    def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]:
+    def adbc_prepare(self, operation: Union[bytes, str]) -> Optional["pyarrow.Schema"]:
         """
         Prepare a query without executing it.
 
@@ -1020,6 +1042,7 @@ class Cursor(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
         self._prepare_execute(operation)
 
         try:
@@ -1038,14 +1061,13 @@ class Cursor(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
         self._results = None
         handle = _blocking_call(
             self._conn._conn.read_partition, (partition,), {}, self._stmt.cancel
         )
         self._rowcount = -1
-        self._results = _RowIterator(
-            self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address)
-        )
+        self._results = _RowIterator(self._stmt, handle)
 
     @property
     def adbc_statement(self) -> _lib.AdbcStatement:
@@ -1076,7 +1098,7 @@ class Cursor(_Closeable):
         self._stmt.set_sql_query(operation)
         _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)
 
-    def fetchallarrow(self) -> pyarrow.Table:
+    def fetchallarrow(self) -> "pyarrow.Table":
         """
         Fetch all rows of the result as a PyArrow Table.
 
@@ -1088,7 +1110,7 @@ class Cursor(_Closeable):
         """
         return self.fetch_arrow_table()
 
-    def fetch_arrow_table(self) -> pyarrow.Table:
+    def fetch_arrow_table(self) -> "pyarrow.Table":
         """
         Fetch all rows of the result as a PyArrow Table.
 
@@ -1122,7 +1144,22 @@ class Cursor(_Closeable):
             )
         return self._results.fetch_df()
 
-    def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
+    def fetch_polars(self) -> "polars.DataFrame":
+        """
+        Fetch all rows of the result as a Polars DataFrame.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        if self._results is None:
+            raise ProgrammingError(
+                "Cannot fetch_polars() before execute()",
+                status_code=_lib.AdbcStatusCode.INVALID_STATE,
+            )
+        return self._results.fetch_polars()
+
+    def fetch_record_batch(self) -> "pyarrow.RecordBatchReader":
         """
         Fetch the result as a PyArrow RecordBatchReader.
 
@@ -1133,6 +1170,7 @@ class Cursor(_Closeable):
         -----
         This is an extension and not part of the DBAPI standard.
         """
+        _requires_pyarrow()
         if self._results is None:
             raise ProgrammingError(
                 "Cannot fetch_record_batch() before execute()",
@@ -1141,7 +1179,27 @@ class Cursor(_Closeable):
         # XXX(https://github.com/apache/arrow-adbc/issues/1523): return the
         # "real" PyArrow reader since PyArrow may try to poke the internal C++
         # reader pointer
-        return self._results._reader._reader
+        return self._results.reader._reader
+
+    def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
+        """
+        Fetch the result as an object implementing the Arrow PyCapsule interface.
+
+        This can only be called once.  It must be called before any other
+        method that inspect the data (e.g. description, fetchone,
+        fetch_arrow_table, etc.).  Once this is called, other methods that
+        inspect the data may not be called.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        if self._results is None:
+            raise ProgrammingError(
+                "Cannot fetch_arrow() before execute()",
+                status_code=_lib.AdbcStatusCode.INVALID_STATE,
+            )
+        return self._results.fetch_arrow()
 
 
 # ----------------------------------------------------------
@@ -1151,24 +1209,41 @@ class Cursor(_Closeable):
 class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
 
-    def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
+    def __init__(self, stmt, handle: _lib.ArrowArrayStreamHandle) -> None:
         self._stmt = stmt
-        self._reader = reader
+        self._handle: Optional[_lib.ArrowArrayStreamHandle] = handle
+        self._reader: Optional["_reader.AdbcRecordBatchReader"] = None
         self._current_batch = None
         self._next_row = 0
         self._finished = False
         self.rownumber = 0
 
     def close(self) -> None:
-        if hasattr(self._reader, "close"):
+        if self._reader is not None and hasattr(self._reader, "close"):
             # Only in recent PyArrow
             self._reader.close()
+        self._reader = None
+
+    @property
+    def reader(self) -> "_reader.AdbcRecordBatchReader":
+        if self._reader is None:
+            _requires_pyarrow()
+            if self._handle is None:
+                raise ProgrammingError(
+                    "Result set has been closed or consumed",
+                    status_code=_lib.AdbcStatusCode.INVALID_STATE,
+                )
+            else:
+                handle, self._handle = self._handle, None
+                klass = _reader.AdbcRecordBatchReader  # type: ignore
+                self._reader = klass._import_from_c(handle.address)
+        return self._reader
 
     @property
     def description(self) -> List[tuple]:
         return [
             (field.name, field.type, None, None, None, None, None)
-            for field in self._reader.schema
+            for field in self.reader.schema
         ]
 
     def fetchone(self) -> Optional[tuple]:
@@ -1176,7 +1251,7 @@ class _RowIterator(_Closeable):
             try:
                 while True:
                     self._current_batch = _blocking_call(
-                        self._reader.read_next_batch, (), {}, self._stmt.cancel
+                        self.reader.read_next_batch, (), {}, self._stmt.cancel
                     )
                     if self._current_batch.num_rows > 0:
                         break
@@ -1211,11 +1286,33 @@ class _RowIterator(_Closeable):
             rows.append(row)
         return rows
 
-    def fetch_arrow_table(self) -> pyarrow.Table:
-        return _blocking_call(self._reader.read_all, (), {}, self._stmt.cancel)
+    def fetch_arrow_table(self) -> "pyarrow.Table":
+        return _blocking_call(self.reader.read_all, (), {}, self._stmt.cancel)
 
     def fetch_df(self) -> "pandas.DataFrame":
-        return _blocking_call(self._reader.read_pandas, (), {}, self._stmt.cancel)
+        return _blocking_call(self.reader.read_pandas, (), {}, self._stmt.cancel)
+
+    def fetch_polars(self) -> "polars.DataFrame":
+        import polars
+
+        return _blocking_call(
+            lambda: typing.cast(
+                polars.DataFrame,
+                polars.from_arrow(self.fetch_arrow()),
+            ),
+            (),
+            {},
+            self._stmt.cancel,
+        )
+
+    def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
+        if self._handle is None:
+            raise ProgrammingError(
+                "Result set has been closed or consumed",
+                status_code=_lib.AdbcStatusCode.INVALID_STATE,
+            )
+        handle, self._handle = self._handle, None
+        return handle
 
 
 _PYTEST_ENV_VAR = "PYTEST_CURRENT_TEST"
@@ -1234,10 +1331,19 @@ def _warn_unclosed(name):
 
 
 def _is_arrow_data(data):
+    # No need to check for PyArrow types explicitly since they support the
+    # dunder methods
     return (
         hasattr(data, "__arrow_c_array__")
         or hasattr(data, "__arrow_c_stream__")
-        or isinstance(
-            data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
-        )
+        or _lib.is_pycapsule(data, b"arrow_array")
+        or _lib.is_pycapsule(data, b"arrow_array_stream")
     )
+
+
+def _requires_pyarrow():
+    if not _has_pyarrow:
+        raise ProgrammingError(
+            "This API requires PyArrow to be installed",
+            status_code=_lib.AdbcStatusCode.INVALID_STATE,
+        )
diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml
index 744024c97..a99be7c9b 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -41,6 +41,7 @@ build-backend = "setuptools.build_meta"
 markers = [
     "duckdb: tests that require DuckDB",
     "panicdummy: tests that require the testing-only panicdummy driver",
+    "pyarrowless: tests of functionality when PyArrow is NOT installed",
     "sqlite: tests that require the SQLite driver",
 ]
 xfail_strict = true
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 2db92388f..8325a2d19 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -16,6 +16,8 @@
 # under the License.
 
 import pandas
+import polars
+import polars.testing
 import pyarrow
 import pyarrow.dataset
 import pytest
@@ -165,6 +167,9 @@ class StreamWrapper:
         lambda: StreamWrapper(
             pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
         ),
+        lambda: pyarrow.table(
+            [[1, 2], ["foo", ""]], names=["ints", "strs"]
+        ).__arrow_c_stream__(),
     ],
 )
 @pytest.mark.sqlite
@@ -226,6 +231,27 @@ def test_query_fetch_py(sqlite):
 
 @pytest.mark.sqlite
 def test_query_fetch_arrow(sqlite):
+    with sqlite.cursor() as cur:
+        with pytest.raises(sqlite.ProgrammingError):
+            cur.fetch_arrow()
+
+        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
+        capsule = cur.fetch_arrow().__arrow_c_stream__()
+        reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
+        assert reader.read_all() == pyarrow.table(
+            {
+                "1": [1],
+                "foo": ["foo"],
+                "2.0": [2.0],
+            }
+        )
+
+        with pytest.raises(sqlite.ProgrammingError):
+            cur.fetch_arrow()
+
+
[email protected]
+def test_query_fetch_arrow_table(sqlite):
     with sqlite.cursor() as cur:
         cur.execute("SELECT 1, 'foo' AS foo, 2.0")
         assert cur.fetch_arrow_table() == pyarrow.table(
@@ -253,6 +279,22 @@ def test_query_fetch_df(sqlite):
         )
 
 
[email protected]
+def test_query_fetch_polars(sqlite):
+    with sqlite.cursor() as cur:
+        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
+        polars.testing.assert_frame_equal(
+            cur.fetch_polars(),
+            polars.DataFrame(
+                {
+                    "1": [1],
+                    "foo": ["foo"],
+                    "2.0": [2.0],
+                }
+            ),
+        )
+
+
 @pytest.mark.sqlite
 @pytest.mark.parametrize(
     "parameters",
diff --git a/python/adbc_driver_manager/tests/test_dbapi_nopyarrow.py b/python/adbc_driver_manager/tests/test_dbapi_nopyarrow.py
new file mode 100644
index 000000000..f65763fd4
--- /dev/null
+++ b/python/adbc_driver_manager/tests/test_dbapi_nopyarrow.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import typing
+
+import polars
+import polars.testing
+import pytest
+
+from adbc_driver_manager import dbapi
+
+pytestmark = pytest.mark.pyarrowless
+
+
[email protected](scope="module", autouse=True)
+def no_pyarrow() -> None:
+    try:
+        import pyarrow  # noqa:F401
+    except ImportError:
+        return
+    else:
+        pytest.skip("Skipping because pyarrow is installed")
+
+
[email protected]
+def sqlite() -> typing.Generator[dbapi.Connection, None, None]:
+    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
+        yield conn
+
+
[email protected](
+    "data",
+    [
+        pytest.param(polars.DataFrame({"theresult": [1]}), 
id="polars.DataFrame"),
+        pytest.param(polars.Series([{"theresult": 1}]), id="polars.Series"),
+        pytest.param(
+            polars.DataFrame({"theresult": [1]}).__arrow_c_stream__(),
+            id="PyCapsule_Stream",
+        ),
+    ],
+)
+def test_ingest(sqlite: dbapi.Connection, data: typing.Any) -> None:
+    with sqlite.cursor() as cursor:
+        cursor.adbc_ingest("mytable", data)
+        cursor.execute("SELECT * FROM mytable")
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "theresult": [1],
+                }
+            ),
+        )
+
+
+def test_query(sqlite: dbapi.Connection) -> None:
+    with sqlite.cursor() as cursor:
+        cursor.execute("SELECT 1 AS theresult")
+        capsule = cursor.fetch_arrow()
+        df = typing.cast(polars.DataFrame, polars.from_arrow(capsule))
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "theresult": [1],
+                }
+            ),
+        )
+
+        cursor.execute("SELECT 1 AS theresult")
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "theresult": [1],
+                }
+            ),
+        )
+
+
[email protected](
+    "parameters",
+    [
+        pytest.param(polars.DataFrame({"$0": [1]}), id="polars.DataFrame"),
+        pytest.param(polars.Series([{"$0": 1}]), id="polars.Series"),
+        pytest.param(
+            polars.DataFrame({"$0": [1]}).__arrow_c_stream__(), 
id="PyCapsule_Stream"
+        ),
+    ],
+)
+def test_query_bind(sqlite: dbapi.Connection, parameters: typing.Any) -> None:
+    with sqlite.cursor() as cursor:
+        cursor.execute("SELECT 1 + ? AS theresult", parameters=parameters)
+
+        df = cursor.fetch_polars()
+        polars.testing.assert_frame_equal(
+            df,
+            polars.DataFrame(
+                {
+                    "theresult": [2],
+                }
+            ),
+        )
+
+
+def test_query_not_permitted(sqlite: dbapi.Connection) -> None:
+    with sqlite.cursor() as cursor:
+        cursor.execute("SELECT 1 AS theresult")
+
+        with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+            cursor.fetchone()
+
+        with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+            cursor.fetchall()
+
+        with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+            cursor.fetchallarrow()
+
+        with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+            cursor.fetch_arrow_table()
+
+        with pytest.raises(dbapi.ProgrammingError, match="requires PyArrow"):
+            cursor.fetch_df()
+
+        capsule = cursor.fetch_arrow()
+        # Import the result to free memory
+        polars.from_arrow(capsule)
+
+
+def test_query_double_capsule(sqlite: dbapi.Connection) -> None:
+    with sqlite.cursor() as cursor:
+        cursor.execute("SELECT 1 AS theresult")
+
+        capsule = cursor.fetch_arrow()
+
+        with pytest.raises(dbapi.ProgrammingError, match="has been closed"):
+            cursor.fetch_arrow()
+
+        # Import the result to free memory
+        polars.from_arrow(capsule)
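
For reference, a minimal sketch of the one-shot PyCapsule handoff exercised
by test_query_double_capsule above (same assumptions: the SQLite driver and
polars are installed, and PyArrow is not):

    import polars
    from adbc_driver_manager import dbapi

    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
        with conn.cursor() as cursor:
            cursor.execute("SELECT 1 AS theresult")
            # fetch_arrow() hands off the underlying ArrowArrayStreamHandle
            # and can only be called once; importing it (e.g. with
            # polars.from_arrow) consumes the stream and frees the memory.
            handle = cursor.fetch_arrow()
            df = polars.from_arrow(handle)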

