This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 3956876 feat(c/driver_manager): expose ADBC functionality in DBAPI
layer (#143)
3956876 is described below
commit 3956876c4af2c94540664635f218d4c1c4b1e4e5
Author: David Li <[email protected]>
AuthorDate: Thu Nov 10 11:55:15 2022 -0500
feat(c/driver_manager): expose ADBC functionality in DBAPI layer (#143)
---
.../adbc_driver_manager/dbapi.py | 447 +++++++++++++++++----
python/adbc_driver_manager/tests/test_dbapi.py | 115 ++++++
2 files changed, 485 insertions(+), 77 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 0748a79..1c67986 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -20,11 +20,11 @@ PEP 249 (DB-API 2.0) API wrapper for the ADBC Driver
Manager.
"""
import datetime
-import functools
+import threading
import time
import typing
import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
try:
import pyarrow
@@ -36,6 +36,8 @@ from . import _lib
if typing.TYPE_CHECKING:
from typing import Self
+ import pandas
+
# ----------------------------------------------------------
# Globals
@@ -58,6 +60,15 @@ InternalError = _lib.InternalError
ProgrammingError = _lib.ProgrammingError
NotSupportedError = _lib.NotSupportedError
+_KNOWN_INFO_VALUES = {
+ 0: "vendor_name",
+ 1: "vendor_version",
+ 2: "vendor_arrow_version",
+ 100: "driver_name",
+ 101: "driver_version",
+ 102: "driver_arrow_version",
+}
+
# ----------------------------------------------------------
# Types
@@ -126,7 +137,7 @@ def connect(
driver: str,
entrypoint: str = None,
db_kwargs: Optional[Dict[str, str]] = None,
- conn_kwargs: Optional[Dict[str, str]] = None
+ conn_kwargs: Optional[Dict[str, str]] = None,
) -> "Connection":
"""
Connect to a database via ADBC.
@@ -159,7 +170,7 @@ def connect(
try:
db = _lib.AdbcDatabase(**db_kwargs)
conn = _lib.AdbcConnection(db, **conn_kwargs)
- return Connection(db, conn)
+ return Connection(db, conn, conn_kwargs)
except Exception:
if conn:
conn.close()
@@ -173,6 +184,8 @@ def connect(
class _Closeable:
+ """Base class providing context manager interface."""
+
def __enter__(self) -> "Self":
return self
@@ -180,6 +193,32 @@ class _Closeable:
self.close()
+class _SharedDatabase(_Closeable):
+ """A holder for a shared AdbcDatabase."""
+
+ def __init__(self, db: _lib.AdbcDatabase) -> None:
+ self._db = db
+ self._lock = threading.Lock()
+ self._refcount = 1
+
+ def _inc(self) -> None:
+ with self._lock:
+ self._refcount += 1
+
+ def _dec(self) -> int:
+ with self._lock:
+ self._refcount -= 1
+ return self._refcount
+
+ def clone(self) -> "Self":
+ self._inc()
+ return self
+
+ def close(self) -> None:
+ if self._dec() == 0:
+ self._db.close()
+
+
class Connection(_Closeable):
"""
A DB-API 2.0 (PEP 249) connection.
@@ -199,9 +238,18 @@ class Connection(_Closeable):
ProgrammingError = _lib.ProgrammingError
NotSupportedError = _lib.NotSupportedError
- def __init__(self, db: _lib.AdbcDatabase, conn: _lib.AdbcConnection) ->
None:
- self._db = db
+ def __init__(
+ self,
+ db: Union[_lib.AdbcDatabase, _SharedDatabase],
+ conn: _lib.AdbcConnection,
+ conn_kwargs: Optional[Dict[str, str]] = None,
+ ) -> None:
+ if isinstance(db, _SharedDatabase):
+ self._db = db.clone()
+ else:
+ self._db = _SharedDatabase(db)
self._conn = conn
+ self._conn_kwargs = conn_kwargs
try:
self._conn.set_autocommit(False)
@@ -215,7 +263,14 @@ class Connection(_Closeable):
self._commit_supported = True
def close(self) -> None:
- """Close the connection."""
+ """
+ Close the connection.
+
+ Warnings
+ --------
+ Failure to close a connection may leak memory or database
+ connections.
+ """
self._conn.close()
self._db.close()
@@ -224,14 +279,164 @@ class Connection(_Closeable):
if self._commit_supported:
self._conn.commit()
+ def cursor(self) -> "Cursor":
+ """Create a new cursor for querying the database."""
+ return Cursor(self)
+
def rollback(self) -> None:
"""Explicitly rollback."""
if self._commit_supported:
self._conn.rollback()
- def cursor(self) -> "Cursor":
- """Create a new cursor for querying the database."""
- return Cursor(self)
+ # ------------------------------------------------------------
+ # API Extensions
+ # ------------------------------------------------------------
+
+ def adbc_clone(self) -> "Connection":
+ """
+ Create a new Connection sharing the same underlying database.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ conn = _lib.AdbcConnection(self._db._db, **(self._conn_kwargs or {}))
+ return Connection(self._db, conn)
+
+ def adbc_get_info(self) -> Dict[Union[str, int], Any]:
+ """
+ Get metadata about the database and driver.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ handle = self._conn.get_info()
+ reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
+ info = reader.read_all().to_pylist()
+ return dict(
+ {
+ _KNOWN_INFO_VALUES.get(row["info_name"], row["info_name"]):
row[
+ "info_value"
+ ]
+ for row in info
+ }
+ )
+
+ def adbc_get_objects(
+ self,
+ *,
+ depth: Literal["all", "catalogs", "db_schemas", "tables", "columns"] =
"all",
+ catalog_filter: Optional[str] = None,
+ db_schema_filter: Optional[str] = None,
+ table_name_filter: Optional[str] = None,
+ table_types_filter: Optional[List[str]] = None,
+ column_name_filter: Optional[str] = None,
+ ) -> pyarrow.RecordBatchReader:
+ """
+ List catalogs, schemas, tables, etc. in the database.
+
+ Parameters
+ ----------
+ depth
+ What objects to return info on.
+ catalog_filter
+ An optional filter on the catalog names returned.
+ db_schema_filter
+ An optional filter on the database schema names returned.
+ table_name_filter
+ An optional filter on the table names returned.
+ table_types_filter
+ An optional list of types of tables returned.
+ column_name_filter
+ An optional filter on the column names returned.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ if depth in ("all", "columns"):
+ c_depth = _lib.GetObjectsDepth.ALL
+ elif depth == "catalogs":
+ c_depth = _lib.GetObjectsDepth.CATALOGS
+ elif depth == "db_schemas":
+ c_depth = _lib.GetObjectsDepth.DB_SCHEMAS
+ elif depth == "tables":
+ c_depth = _lib.GetObjectsDepth.TABLES
+ else:
+ raise ValueError(f"Invalid value for 'depth': {depth}")
+ handle = self._conn.get_objects(
+ c_depth,
+ catalog=catalog_filter,
+ db_schema=db_schema_filter,
+ table_name=table_name_filter,
+ table_types=table_types_filter,
+ column_name=column_name_filter,
+ )
+ return pyarrow.RecordBatchReader._import_from_c(handle.address)
+
+ def adbc_get_table_schema(
+ self,
+ table_name: str,
+ *,
+ catalog_filter: Optional[str] = None,
+ db_schema_filter: Optional[str] = None,
+ ) -> pyarrow.Schema:
+ """
+ Get the Arrow schema of a table by name.
+
+ Parameters
+ ----------
+ table_name
+ The table to get the schema of.
+ catalog_filter
+ An optional filter on the catalog name of the table.
+ db_schema_filter
+ An optional filter on the database schema name of the table.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ handle = self._conn.get_table_schema(
+ catalog_filter, db_schema_filter, table_name
+ )
+ return pyarrow.Schema._import_from_c(handle.address)
+
+ def adbc_get_table_types(self) -> List[str]:
+ """
+ List the types of tables that the server knows about.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ handle = self._conn.get_table_types()
+ reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
+ table = reader.read_all()
+ return table[0].to_pylist()
+
+ @property
+ def adbc_connection(self) -> _lib.AdbcConnection:
+ """
+ Get the underlying ADBC connection.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ return self._conn
+
+ @property
+ def adbc_database(self) -> _lib.AdbcDatabase:
+ """
+ Get the underlying ADBC database.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ return self._db._db
class Cursor(_Closeable):
@@ -244,7 +449,7 @@ class Cursor(_Closeable):
def __init__(self, conn: Connection) -> None:
self._conn = conn
self._stmt = _lib.AdbcStatement(conn._conn)
- self._last_query: Optional[str] = None
+ self._last_query: Optional[Union[str, bytes]] = None
self._results: Optional["_RowIterator"] = None
self._arraysize = 1
self._rowcount = -1
@@ -300,7 +505,11 @@ class Cursor(_Closeable):
self._results = None
if operation != self._last_query:
self._last_query = operation
- self._stmt.set_sql_query(operation)
+ if isinstance(operation, bytes):
+ # Serialized Substrait plan
+ self._stmt.set_substrait_plan(operation)
+ else:
+ self._stmt.set_sql_query(operation)
try:
self._stmt.prepare()
except NotSupportedError:
@@ -354,37 +563,6 @@ class Cursor(_Closeable):
self._stmt.bind(arr_handle, sch_handle)
self._rowcount = self._stmt.execute_update()
- def execute_partitions(
- self, operation, parameters=None
- ) -> Tuple[List[bytes], pyarrow.Schema]:
- """
- Execute a query and get the partitions of a distributed result set.
-
- This is an extension method, not present in DBAPI.
-
- Return
- ------
- partitions : list of byte
- A list of partition descriptors, which can be read with
- read_partition.
- schema : pyarrow.Schema
- The schema of the result set.
- """
- self._prepare_execute(operation, parameters)
- partitions, schema, self._rowcount = self._stmt.execute_partitions()
- return partitions, pyarrow.Schema._import_from_c(schema.address())
-
- def read_partition(self, partition: bytes) -> None:
- """
- Read a partition of a distributed result set.
- """
- self._results = None
- handle = self.conn._conn.read_partition(partition)
- self._rowcount = -1
- self._results = _RowIterator(
- pyarrow.RecordBatchReader._import_from_c(handle.address)
- )
-
def fetchone(self) -> tuple:
"""Fetch one row of the result."""
if self._results is None:
@@ -414,11 +592,150 @@ class Cursor(_Closeable):
)
return self._results.fetchall()
+ def next(self):
+ """Fetch the next row, or raise StopIteration."""
+ row = self.fetchone()
+ if row is None:
+ raise StopIteration
+ return row
+
+ def nextset(self):
+ raise NotSupportedError("Cursor.nextset")
+
+ def setinputsizes(self, sizes):
+ # Not used
+ pass
+
+ def setoutputsize(self, size, column=None):
+ # Not used
+ pass
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return self.next()
+
+ # ------------------------------------------------------------
+ # API Extensions
+ # ------------------------------------------------------------
+
+ def adbc_ingest(
+ self,
+ table_name: str,
+ data: Union[pyarrow.RecordBatch, pyarrow.Table,
pyarrow.RecordBatchReader],
+ mode: Literal["append", "create"] = "create",
+ ) -> int:
+ """
+ Ingest Arrow data into a database table.
+
+ Depending on the driver, this can avoid per-row overhead that
+ would result from a typical prepare-bind-insert loop.
+
+ Parameters
+ ----------
+ table_name
+ The table to insert into.
+ data
+ The Arrow data to insert.
+ mode
+ Whether to append data to an existing table, or create a new table.
+
+ Returns
+ -------
+ int
+ The number of rows inserted, or -1 if the driver cannot
+ provide this information.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ if mode == "append":
+ c_mode = _lib.INGEST_OPTION_MODE_APPEND
+ elif mode == "create":
+ c_mode = _lib.INGEST_OPTION_MODE_CREATE
+ else:
+ raise ValueError(f"Invalid value for 'mode': {mode}")
+ self._stmt.set_options(
+ **{
+ _lib.INGEST_OPTION_TARGET_TABLE: table_name,
+ _lib.INGEST_OPTION_MODE: c_mode,
+ }
+ )
+
+ if isinstance(data, pyarrow.RecordBatch):
+ array = _lib.ArrowArrayHandle()
+ schema = _lib.ArrowSchemaHandle()
+ data._export_to_c(array.address, schema.address)
+ self._stmt.bind(array, schema)
+ else:
+ if isinstance(data, pyarrow.Table):
+ data = data.to_reader()
+ handle = _lib.ArrowArrayStreamHandle()
+ data._export_to_c(handle.address)
+ self._stmt.bind_stream(handle)
+
+ self._last_query = None
+ return self._stmt.execute_update()
+
+ def adbc_execute_partitions(
+ self, operation, parameters=None
+ ) -> Tuple[List[bytes], pyarrow.Schema]:
+ """
+ Execute a query and get the partitions of a distributed result set.
+
+        Returns
+        -------
+        partitions : list of bytes
+            A list of partition descriptors, which can be read with
+            adbc_read_partition.
+ schema : pyarrow.Schema
+ The schema of the result set.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ self._prepare_execute(operation, parameters)
+ partitions, schema, self._rowcount = self._stmt.execute_partitions()
+ return partitions, pyarrow.Schema._import_from_c(schema.address)
+
+ def adbc_read_partition(self, partition: bytes) -> None:
+ """
+ Read a partition of a distributed result set.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ self._results = None
+        handle = self._conn._conn.read_partition(partition)
+ self._rowcount = -1
+ self._results = _RowIterator(
+ pyarrow.RecordBatchReader._import_from_c(handle.address)
+ )
+
+ @property
+ def adbc_statement(self) -> _lib.AdbcStatement:
+ """
+ Get the underlying ADBC statement.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ return self._stmt
+
def fetchallarrow(self) -> pyarrow.Table:
"""
Fetch all rows of the result as a PyArrow Table.
This implements a similar API as turbodbc.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
"""
return self.fetch_arrow_table()
@@ -427,6 +744,10 @@ class Cursor(_Closeable):
Fetch all rows of the result as a PyArrow Table.
This implements a similar API as DuckDB.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
"""
if self._results is None:
raise ProgrammingError(
@@ -435,11 +756,15 @@ class Cursor(_Closeable):
)
return self._results.fetch_arrow_table()
- def fetch_df(self):
+ def fetch_df(self) -> "pandas.DataFrame":
"""
Fetch all rows of the result as a Pandas DataFrame.
This implements a similar API as DuckDB.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
"""
if self._results is None:
raise ProgrammingError(
@@ -448,30 +773,6 @@ class Cursor(_Closeable):
)
return self._results.fetch_df()
- def next(self):
- """Fetch the next row, or raise StopIteration."""
- row = self.fetchone()
- if row is None:
- raise StopIteration
- return row
-
- def nextset(self):
- raise NotSupportedError("Cursor.nextset")
-
- def setinputsizes(self, sizes):
- # Not used
- pass
-
- def setoutputsize(self, size, column=None):
- # Not used
- pass
-
- def __iter__(self):
- return self
-
- def __next__(self):
- return self.next()
-
class _RowIterator(_Closeable):
"""Track state needed to iterate over the result set."""
@@ -507,10 +808,7 @@ class _RowIterator(_Closeable):
if self._finished:
return None
- row = tuple(
- _convert_value(arr, row=self._next_row)
- for arr in self._current_batch.columns
- )
+ row = tuple(arr[self._next_row].as_py() for arr in
self._current_batch.columns)
self._next_row += 1
self.rownumber += 1
return row
@@ -538,8 +836,3 @@ class _RowIterator(_Closeable):
def fetch_df(self):
return self._reader.read_pandas()
-
-
[email protected]
-def _convert_value(arr: pyarrow.Array, *, row: int) -> Any:
- return arr[row].as_py()
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py
b/python/adbc_driver_manager/tests/test_dbapi.py
index f8e6d0f..d3d34af 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -62,6 +62,114 @@ def test_attrs(sqlite):
assert cur.rowcount == -1
[email protected]
+def test_info(sqlite):
+ assert sqlite.adbc_get_info() == {
+ "vendor_arrow_version": "Arrow/C++ 9.0.0",
+ "vendor_name": "ADBC C SQLite3",
+ "vendor_version": "0.0.1",
+ }
+
+
[email protected]
+def test_get_underlying(sqlite):
+ assert sqlite.adbc_database
+ assert sqlite.adbc_connection
+ with sqlite.cursor() as cur:
+ assert cur.adbc_statement
+
+
[email protected]
+def test_clone(sqlite):
+ with sqlite.adbc_clone() as sqlite2:
+ with sqlite2.cursor() as cur:
+ cur.execute("CREATE TABLE temporary (ints)")
+ cur.execute("INSERT INTO temporary VALUES (1)")
+ sqlite2.commit()
+
+ with sqlite.cursor() as cur:
+ cur.execute("SELECT * FROM temporary")
+ assert cur.fetchone() == (1,)
+
+
[email protected]
+def test_get_objects(sqlite):
+ with sqlite.cursor() as cur:
+ cur.execute("CREATE TABLE temporary (ints)")
+ cur.execute("INSERT INTO temporary VALUES (1)")
+ metadata = (
+
sqlite.adbc_get_objects(table_name_filter="temporary").read_all().to_pylist()
+ )
+ assert len(metadata) == 1
+ assert metadata[0]["catalog_name"] == "main"
+ schemas = metadata[0]["catalog_db_schemas"]
+ assert len(schemas) == 1
+ assert schemas[0]["db_schema_name"] is None
+ tables = schemas[0]["db_schema_tables"]
+ assert len(tables) == 1
+ assert tables[0]["table_name"] == "temporary"
+ assert tables[0]["table_type"] == "table"
+ assert tables[0]["table_columns"][0]["column_name"] == "ints"
+ assert tables[0]["table_columns"][0]["ordinal_position"] == 1
+ assert tables[0]["table_constraints"] == []
+
+
[email protected]
+def test_get_table_schema(sqlite):
+ with sqlite.cursor() as cur:
+ cur.execute("CREATE TABLE temporary (ints)")
+ cur.execute("INSERT INTO temporary VALUES (1)")
+ assert sqlite.adbc_get_table_schema("temporary") == pyarrow.schema(
+ [("ints", pyarrow.int64())]
+ )
+
+
[email protected]
+def test_get_table_types(sqlite):
+ assert sqlite.adbc_get_table_types() == ["table", "view"]
+
+
[email protected](
+ "data",
+ [
+ lambda: pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints",
"strs"]),
+ lambda: pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"]),
+ lambda: pyarrow.table(
+ [[1, 2], ["foo", ""]], names=["ints", "strs"]
+ ).to_reader(),
+ ],
+)
[email protected]
+def test_ingest(data, sqlite):
+ with sqlite.cursor() as cur:
+ cur.adbc_ingest("bulk_ingest", data())
+
+ with pytest.raises(dbapi.ProgrammingError):
+ cur.adbc_ingest("bulk_ingest", data())
+
+ cur.adbc_ingest("bulk_ingest", data(), mode="append")
+
+ with pytest.raises(dbapi.Error):
+ cur.adbc_ingest("nonexistent", data(), mode="append")
+
+ with pytest.raises(ValueError):
+ cur.adbc_ingest("bulk_ingest", data(), mode="invalid")
+
+ with sqlite.cursor() as cur:
+ cur.execute("SELECT * FROM bulk_ingest")
+ assert cur.fetchone() == (1, "foo")
+ assert cur.fetchone() == (2, "")
+ assert cur.fetchone() == (1, "foo")
+ assert cur.fetchone() == (2, "")
+
+
[email protected]
+def test_partitions(sqlite):
+ with pytest.raises(dbapi.NotSupportedError):
+ with sqlite.cursor() as cur:
+ cur.adbc_execute_partitions("SELECT 1")
+
+
@pytest.mark.sqlite
def test_query_fetch_py(sqlite):
with sqlite.cursor() as cur:
@@ -124,6 +232,13 @@ def test_query_parameters(sqlite):
assert cur.fetchall() == [(2.0, 2)]
[email protected]
+def test_query_substrait(sqlite):
+ with sqlite.cursor() as cur:
+ with pytest.raises(dbapi.NotSupportedError):
+ cur.execute(b"Substrait plan")
+
+
@pytest.mark.sqlite
def test_executemany(sqlite):
with sqlite.cursor() as cur: