This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 3956876  feat(c/driver_manager): expose ADBC functionality in DBAPI layer (#143)
3956876 is described below

commit 3956876c4af2c94540664635f218d4c1c4b1e4e5
Author: David Li <[email protected]>
AuthorDate: Thu Nov 10 11:55:15 2022 -0500

    feat(c/driver_manager): expose ADBC functionality in DBAPI layer (#143)
---
 .../adbc_driver_manager/dbapi.py                   | 447 +++++++++++++++++----
 python/adbc_driver_manager/tests/test_dbapi.py     | 115 ++++++
 2 files changed, 485 insertions(+), 77 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 0748a79..1c67986 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -20,11 +20,11 @@ PEP 249 (DB-API 2.0) API wrapper for the ADBC Driver Manager.
 """
 
 import datetime
-import functools
+import threading
 import time
 import typing
 import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
 
 try:
     import pyarrow
@@ -36,6 +36,8 @@ from . import _lib
 if typing.TYPE_CHECKING:
     from typing import Self
 
+    import pandas
+
 # ----------------------------------------------------------
 # Globals
 
@@ -58,6 +60,15 @@ InternalError = _lib.InternalError
 ProgrammingError = _lib.ProgrammingError
 NotSupportedError = _lib.NotSupportedError
 
+_KNOWN_INFO_VALUES = {
+    0: "vendor_name",
+    1: "vendor_version",
+    2: "vendor_arrow_version",
+    100: "vendor_name",
+    101: "vendor_version",
+    102: "vendor_arrow_version",
+}
+
 # ----------------------------------------------------------
 # Types
 
@@ -126,7 +137,7 @@ def connect(
     driver: str,
     entrypoint: str = None,
     db_kwargs: Optional[Dict[str, str]] = None,
-    conn_kwargs: Optional[Dict[str, str]] = None
+    conn_kwargs: Optional[Dict[str, str]] = None,
 ) -> "Connection":
     """
     Connect to a database via ADBC.
@@ -159,7 +170,7 @@ def connect(
     try:
         db = _lib.AdbcDatabase(**db_kwargs)
         conn = _lib.AdbcConnection(db, **conn_kwargs)
-        return Connection(db, conn)
+        return Connection(db, conn, conn_kwargs)
     except Exception:
         if conn:
             conn.close()
@@ -173,6 +184,8 @@ def connect(
 
 
 class _Closeable:
+    """Base class providing context manager interface."""
+
     def __enter__(self) -> "Self":
         return self
 
@@ -180,6 +193,32 @@ class _Closeable:
         self.close()
 
 
+class _SharedDatabase(_Closeable):
+    """A holder for a shared AdbcDatabase."""
+
+    def __init__(self, db: _lib.AdbcDatabase) -> None:
+        self._db = db
+        self._lock = threading.Lock()
+        self._refcount = 1
+
+    def _inc(self) -> None:
+        with self._lock:
+            self._refcount += 1
+
+    def _dec(self) -> int:
+        with self._lock:
+            self._refcount -= 1
+            return self._refcount
+
+    def clone(self) -> "Self":
+        self._inc()
+        return self
+
+    def close(self) -> None:
+        if self._dec() == 0:
+            self._db.close()
+
+
 class Connection(_Closeable):
     """
     A DB-API 2.0 (PEP 249) connection.
@@ -199,9 +238,18 @@ class Connection(_Closeable):
     ProgrammingError = _lib.ProgrammingError
     NotSupportedError = _lib.NotSupportedError
 
-    def __init__(self, db: _lib.AdbcDatabase, conn: _lib.AdbcConnection) -> None:
-        self._db = db
+    def __init__(
+        self,
+        db: Union[_lib.AdbcDatabase, _SharedDatabase],
+        conn: _lib.AdbcConnection,
+        conn_kwargs: Optional[Dict[str, str]] = None,
+    ) -> None:
+        if isinstance(db, _SharedDatabase):
+            self._db = db.clone()
+        else:
+            self._db = _SharedDatabase(db)
         self._conn = conn
+        self._conn_kwargs = conn_kwargs
 
         try:
             self._conn.set_autocommit(False)
@@ -215,7 +263,14 @@ class Connection(_Closeable):
             self._commit_supported = True
 
     def close(self) -> None:
-        """Close the connection."""
+        """
+        Close the connection.
+
+        Warnings
+        --------
+        Failure to close a connection may leak memory or database
+        connections.
+        """
         self._conn.close()
         self._db.close()
 
@@ -224,14 +279,164 @@ class Connection(_Closeable):
         if self._commit_supported:
             self._conn.commit()
 
+    def cursor(self) -> "Cursor":
+        """Create a new cursor for querying the database."""
+        return Cursor(self)
+
     def rollback(self) -> None:
         """Explicitly rollback."""
         if self._commit_supported:
             self._conn.rollback()
 
-    def cursor(self) -> "Cursor":
-        """Create a new cursor for querying the database."""
-        return Cursor(self)
+    # ------------------------------------------------------------
+    # API Extensions
+    # ------------------------------------------------------------
+
+    def adbc_clone(self) -> "Connection":
+        """
+        Create a new Connection sharing the same underlying database.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        conn = _lib.AdbcConnection(self._db._db, **(self._conn_kwargs or {}))
+        return Connection(self._db, conn)
+
+    def adbc_get_info(self) -> Dict[Union[str, int], Any]:
+        """
+        Get metadata about the database and driver.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        handle = self._conn.get_info()
+        reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
+        info = reader.read_all().to_pylist()
+        return dict(
+            {
+                _KNOWN_INFO_VALUES.get(row["info_name"], row["info_name"]): row[
+                    "info_value"
+                ]
+                for row in info
+            }
+        )
+
+    def adbc_get_objects(
+        self,
+        *,
+        depth: Literal["all", "catalogs", "db_schemas", "tables", "columns"] = "all",
+        catalog_filter: Optional[str] = None,
+        db_schema_filter: Optional[str] = None,
+        table_name_filter: Optional[str] = None,
+        table_types_filter: Optional[List[str]] = None,
+        column_name_filter: Optional[str] = None,
+    ) -> pyarrow.RecordBatchReader:
+        """
+        List catalogs, schemas, tables, etc. in the database.
+
+        Parameters
+        ----------
+        depth
+            What objects to return info on.
+        catalog_filter
+            An optional filter on the catalog names returned.
+        db_schema_filter
+            An optional filter on the database schema names returned.
+        table_name_filter
+            An optional filter on the table names returned.
+        table_types_filter
+            An optional list of types of tables returned.
+        column_name_filter
+            An optional filter on the column names returned.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        if depth in ("all", "columns"):
+            c_depth = _lib.GetObjectsDepth.ALL
+        elif depth == "catalogs":
+            c_depth = _lib.GetObjectsDepth.CATALOGS
+        elif depth == "db_schemas":
+            c_depth = _lib.GetObjectsDepth.DB_SCHEMAS
+        elif depth == "tables":
+            c_depth = _lib.GetObjectsDepth.TABLES
+        else:
+            raise ValueError(f"Invalid value for 'depth': {depth}")
+        handle = self._conn.get_objects(
+            c_depth,
+            catalog=catalog_filter,
+            db_schema=db_schema_filter,
+            table_name=table_name_filter,
+            table_types=table_types_filter,
+            column_name=column_name_filter,
+        )
+        return pyarrow.RecordBatchReader._import_from_c(handle.address)
+
+    def adbc_get_table_schema(
+        self,
+        table_name: str,
+        *,
+        catalog_filter: Optional[str] = None,
+        db_schema_filter: Optional[str] = None,
+    ) -> pyarrow.Schema:
+        """
+        Get the Arrow schema of a table by name.
+
+        Parameters
+        ----------
+        table_name
+            The table to get the schema of.
+        catalog_filter
+            An optional filter on the catalog name of the table.
+        db_schema_filter
+            An optional filter on the database schema name of the table.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        handle = self._conn.get_table_schema(
+            catalog_filter, db_schema_filter, table_name
+        )
+        return pyarrow.Schema._import_from_c(handle.address)
+
+    def adbc_get_table_types(self) -> List[str]:
+        """
+        List the types of tables that the server knows about.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        handle = self._conn.get_table_types()
+        reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
+        table = reader.read_all()
+        return table[0].to_pylist()
+
+    @property
+    def adbc_connection(self) -> _lib.AdbcConnection:
+        """
+        Get the underlying ADBC connection.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        return self._conn
+
+    @property
+    def adbc_database(self) -> _lib.AdbcDatabase:
+        """
+        Get the underlying ADBC database.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        return self._db._db
 
 
 class Cursor(_Closeable):
@@ -244,7 +449,7 @@ class Cursor(_Closeable):
     def __init__(self, conn: Connection) -> None:
         self._conn = conn
         self._stmt = _lib.AdbcStatement(conn._conn)
-        self._last_query: Optional[str] = None
+        self._last_query: Optional[Union[str, bytes]] = None
         self._results: Optional["_RowIterator"] = None
         self._arraysize = 1
         self._rowcount = -1
@@ -300,7 +505,11 @@ class Cursor(_Closeable):
         self._results = None
         if operation != self._last_query:
             self._last_query = operation
-            self._stmt.set_sql_query(operation)
+            if isinstance(operation, bytes):
+                # Serialized Substrait plan
+                self._stmt.set_substrait_plan(operation)
+            else:
+                self._stmt.set_sql_query(operation)
             try:
                 self._stmt.prepare()
             except NotSupportedError:
@@ -354,37 +563,6 @@ class Cursor(_Closeable):
         self._stmt.bind(arr_handle, sch_handle)
         self._rowcount = self._stmt.execute_update()
 
-    def execute_partitions(
-        self, operation, parameters=None
-    ) -> Tuple[List[bytes], pyarrow.Schema]:
-        """
-        Execute a query and get the partitions of a distributed result set.
-
-        This is an extension method, not present in DBAPI.
-
-        Return
-        ------
-        partitions : list of byte
-            A list of partition descriptors, which can be read with
-            read_partition.
-        schema : pyarrow.Schema
-            The schema of the result set.
-        """
-        self._prepare_execute(operation, parameters)
-        partitions, schema, self._rowcount = self._stmt.execute_partitions()
-        return partitions, pyarrow.Schema._import_from_c(schema.address())
-
-    def read_partition(self, partition: bytes) -> None:
-        """
-        Read a partition of a distributed result set.
-        """
-        self._results = None
-        handle = self.conn._conn.read_partition(partition)
-        self._rowcount = -1
-        self._results = _RowIterator(
-            pyarrow.RecordBatchReader._import_from_c(handle.address)
-        )
-
     def fetchone(self) -> tuple:
         """Fetch one row of the result."""
         if self._results is None:
@@ -414,11 +592,150 @@ class Cursor(_Closeable):
             )
         return self._results.fetchall()
 
+    def next(self):
+        """Fetch the next row, or raise StopIteration."""
+        row = self.fetchone()
+        if row is None:
+            raise StopIteration
+        return row
+
+    def nextset(self):
+        raise NotSupportedError("Cursor.nextset")
+
+    def setinputsizes(self, sizes):
+        # Not used
+        pass
+
+    def setoutputsize(self, size, column=None):
+        # Not used
+        pass
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.next()
+
+    # ------------------------------------------------------------
+    # API Extensions
+    # ------------------------------------------------------------
+
+    def adbc_ingest(
+        self,
+        table_name: str,
+        data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader],
+        mode: Literal["append", "create"] = "create",
+    ) -> int:
+        """
+        Ingest Arrow data into a database table.
+
+        Depending on the driver, this can avoid per-row overhead that
+        would result from a typical prepare-bind-insert loop.
+
+        Parameters
+        ----------
+        table_name
+            The table to insert into.
+        data
+            The Arrow data to insert.
+        mode
+            Whether to append data to an existing table, or create a new table.
+
+        Returns
+        -------
+        int
+            The number of rows inserted, or -1 if the driver cannot
+            provide this information.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        if mode == "append":
+            c_mode = _lib.INGEST_OPTION_MODE_APPEND
+        elif mode == "create":
+            c_mode = _lib.INGEST_OPTION_MODE_CREATE
+        else:
+            raise ValueError(f"Invalid value for 'mode': {mode}")
+        self._stmt.set_options(
+            **{
+                _lib.INGEST_OPTION_TARGET_TABLE: table_name,
+                _lib.INGEST_OPTION_MODE: c_mode,
+            }
+        )
+
+        if isinstance(data, pyarrow.RecordBatch):
+            array = _lib.ArrowArrayHandle()
+            schema = _lib.ArrowSchemaHandle()
+            data._export_to_c(array.address, schema.address)
+            self._stmt.bind(array, schema)
+        else:
+            if isinstance(data, pyarrow.Table):
+                data = data.to_reader()
+            handle = _lib.ArrowArrayStreamHandle()
+            data._export_to_c(handle.address)
+            self._stmt.bind_stream(handle)
+
+        self._last_query = None
+        return self._stmt.execute_update()
+
+    def adbc_execute_partitions(
+        self, operation, parameters=None
+    ) -> Tuple[List[bytes], pyarrow.Schema]:
+        """
+        Execute a query and get the partitions of a distributed result set.
+
+        Returns
+        -------
+        partitions : list of bytes
+            A list of partition descriptors, which can be read with
+            adbc_read_partition.
+        schema : pyarrow.Schema
+            The schema of the result set.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        self._prepare_execute(operation, parameters)
+        partitions, schema, self._rowcount = self._stmt.execute_partitions()
+        return partitions, pyarrow.Schema._import_from_c(schema.address)
+
+    def adbc_read_partition(self, partition: bytes) -> None:
+        """
+        Read a partition of a distributed result set.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        self._results = None
+        handle = self.conn._conn.read_partition(partition)
+        self._rowcount = -1
+        self._results = _RowIterator(
+            pyarrow.RecordBatchReader._import_from_c(handle.address)
+        )
+
+    @property
+    def adbc_statement(self) -> _lib.AdbcStatement:
+        """
+        Get the underlying ADBC statement.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        return self._stmt
+
     def fetchallarrow(self) -> pyarrow.Table:
         """
         Fetch all rows of the result as a PyArrow Table.
 
         This implements a similar API as turbodbc.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
         """
         return self.fetch_arrow_table()
 
@@ -427,6 +744,10 @@ class Cursor(_Closeable):
         Fetch all rows of the result as a PyArrow Table.
 
         This implements a similar API as DuckDB.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
         """
         if self._results is None:
             raise ProgrammingError(
@@ -435,11 +756,15 @@ class Cursor(_Closeable):
             )
         return self._results.fetch_arrow_table()
 
-    def fetch_df(self):
+    def fetch_df(self) -> "pandas.DataFrame":
         """
         Fetch all rows of the result as a Pandas DataFrame.
 
         This implements a similar API as DuckDB.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
         """
         if self._results is None:
             raise ProgrammingError(
@@ -448,30 +773,6 @@ class Cursor(_Closeable):
             )
         return self._results.fetch_df()
 
-    def next(self):
-        """Fetch the next row, or raise StopIteration."""
-        row = self.fetchone()
-        if row is None:
-            raise StopIteration
-        return row
-
-    def nextset(self):
-        raise NotSupportedError("Cursor.nextset")
-
-    def setinputsizes(self, sizes):
-        # Not used
-        pass
-
-    def setoutputsize(self, size, column=None):
-        # Not used
-        pass
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        return self.next()
-
 
 class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
@@ -507,10 +808,7 @@ class _RowIterator(_Closeable):
         if self._finished:
             return None
 
-        row = tuple(
-            _convert_value(arr, row=self._next_row)
-            for arr in self._current_batch.columns
-        )
+        row = tuple(arr[self._next_row].as_py() for arr in self._current_batch.columns)
         self._next_row += 1
         self.rownumber += 1
         return row
@@ -538,8 +836,3 @@ class _RowIterator(_Closeable):
 
     def fetch_df(self):
         return self._reader.read_pandas()
-
-
-@functools.singledispatch
-def _convert_value(arr: pyarrow.Array, *, row: int) -> Any:
-    return arr[row].as_py()
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index f8e6d0f..d3d34af 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -62,6 +62,114 @@ def test_attrs(sqlite):
         assert cur.rowcount == -1
 
 
+@pytest.mark.sqlite
+def test_info(sqlite):
+    assert sqlite.adbc_get_info() == {
+        "vendor_arrow_version": "Arrow/C++ 9.0.0",
+        "vendor_name": "ADBC C SQLite3",
+        "vendor_version": "0.0.1",
+    }
+
+
+@pytest.mark.sqlite
+def test_get_underlying(sqlite):
+    assert sqlite.adbc_database
+    assert sqlite.adbc_connection
+    with sqlite.cursor() as cur:
+        assert cur.adbc_statement
+
+
+@pytest.mark.sqlite
+def test_clone(sqlite):
+    with sqlite.adbc_clone() as sqlite2:
+        with sqlite2.cursor() as cur:
+            cur.execute("CREATE TABLE temporary (ints)")
+            cur.execute("INSERT INTO temporary VALUES (1)")
+        sqlite2.commit()
+
+    with sqlite.cursor() as cur:
+        cur.execute("SELECT * FROM temporary")
+        assert cur.fetchone() == (1,)
+
+
+@pytest.mark.sqlite
+def test_get_objects(sqlite):
+    with sqlite.cursor() as cur:
+        cur.execute("CREATE TABLE temporary (ints)")
+        cur.execute("INSERT INTO temporary VALUES (1)")
+    metadata = (
+        sqlite.adbc_get_objects(table_name_filter="temporary").read_all().to_pylist()
+    )
+    assert len(metadata) == 1
+    assert metadata[0]["catalog_name"] == "main"
+    schemas = metadata[0]["catalog_db_schemas"]
+    assert len(schemas) == 1
+    assert schemas[0]["db_schema_name"] is None
+    tables = schemas[0]["db_schema_tables"]
+    assert len(tables) == 1
+    assert tables[0]["table_name"] == "temporary"
+    assert tables[0]["table_type"] == "table"
+    assert tables[0]["table_columns"][0]["column_name"] == "ints"
+    assert tables[0]["table_columns"][0]["ordinal_position"] == 1
+    assert tables[0]["table_constraints"] == []
+
+
+@pytest.mark.sqlite
+def test_get_table_schema(sqlite):
+    with sqlite.cursor() as cur:
+        cur.execute("CREATE TABLE temporary (ints)")
+        cur.execute("INSERT INTO temporary VALUES (1)")
+    assert sqlite.adbc_get_table_schema("temporary") == pyarrow.schema(
+        [("ints", pyarrow.int64())]
+    )
+
+
+@pytest.mark.sqlite
+def test_get_table_types(sqlite):
+    assert sqlite.adbc_get_table_types() == ["table", "view"]
+
+
+@pytest.mark.parametrize(
+    "data",
+    [
+        lambda: pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"]),
+        lambda: pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"]),
+        lambda: pyarrow.table(
+            [[1, 2], ["foo", ""]], names=["ints", "strs"]
+        ).to_reader(),
+    ],
+)
+@pytest.mark.sqlite
+def test_ingest(data, sqlite):
+    with sqlite.cursor() as cur:
+        cur.adbc_ingest("bulk_ingest", data())
+
+        with pytest.raises(dbapi.ProgrammingError):
+            cur.adbc_ingest("bulk_ingest", data())
+
+        cur.adbc_ingest("bulk_ingest", data(), mode="append")
+
+        with pytest.raises(dbapi.Error):
+            cur.adbc_ingest("nonexistent", data(), mode="append")
+
+        with pytest.raises(ValueError):
+            cur.adbc_ingest("bulk_ingest", data(), mode="invalid")
+
+    with sqlite.cursor() as cur:
+        cur.execute("SELECT * FROM bulk_ingest")
+        assert cur.fetchone() == (1, "foo")
+        assert cur.fetchone() == (2, "")
+        assert cur.fetchone() == (1, "foo")
+        assert cur.fetchone() == (2, "")
+
+
+@pytest.mark.sqlite
+def test_partitions(sqlite):
+    with pytest.raises(dbapi.NotSupportedError):
+        with sqlite.cursor() as cur:
+            cur.adbc_execute_partitions("SELECT 1")
+
+
 @pytest.mark.sqlite
 def test_query_fetch_py(sqlite):
     with sqlite.cursor() as cur:
@@ -124,6 +232,13 @@ def test_query_parameters(sqlite):
         assert cur.fetchall() == [(2.0, 2)]
 
 
+@pytest.mark.sqlite
+def test_query_substrait(sqlite):
+    with sqlite.cursor() as cur:
+        with pytest.raises(dbapi.NotSupportedError):
+            cur.execute(b"Substrait plan")
+
+
 @pytest.mark.sqlite
 def test_executemany(sqlite):
     with sqlite.cursor() as cur:

Reply via email to