This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch spec-1.1.0
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git

commit 9c41a16c530b93cdbc4929f43479713582c3c5e6
Author: David Li <[email protected]>
AuthorDate: Thu Jul 27 11:02:37 2023 -0400

    feat(python): expose ADBC 1.1.0 features (#937)
---
 c/driver/postgresql/statement.cc                   |  26 +-
 docs/source/python/api/adbc_driver_manager.rst     |   2 +
 .../adbc_driver_manager/__init__.py                |  10 +
 .../adbc_driver_manager/_lib.pyi                   |  23 +-
 .../adbc_driver_manager/_lib.pyx                   | 449 +++++++++++++++++++--
 .../adbc_driver_manager/dbapi.py                   |  87 +++-
 python/adbc_driver_postgresql/tests/test_dbapi.py  |  71 ++++
 7 files changed, 626 insertions(+), 42 deletions(-)

diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 6521fce9..93290202 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -660,15 +660,19 @@ int TupleReader::GetNext(struct ArrowArray* out) {
 
   // Check the server-side response
   result_ = PQgetResult(conn_);
-  const int pq_status = PQresultStatus(result_);
+  const ExecStatusType pq_status = PQresultStatus(result_);
   if (pq_status != PGRES_COMMAND_OK) {
-    StringBuilderAppend(&error_builder_, "[libpq] Query failed [%d]: %s", pq_status,
-                        PQresultErrorMessage(result_));
+    const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE);
+    StringBuilderAppend(&error_builder_, "[libpq] Query failed [%s]: %s",
+                        PQresStatus(pq_status), PQresultErrorMessage(result_));
 
     if (tmp.release != nullptr) {
       tmp.release(&tmp);
     }
 
+    if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) {
+      return ECANCELED;
+    }
     return EIO;
   }
 
@@ -1078,7 +1082,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
   } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
     result = std::to_string(reader_.batch_size_hint_bytes_);
   } else {
-    SetError(error, "[libq] Unknown statement option '%s'", key);
+    SetError(error, "[libpq] Unknown statement option '%s'", key);
     return ADBC_STATUS_NOT_FOUND;
   }
 
@@ -1092,13 +1096,13 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
 AdbcStatusCode PostgresStatement::GetOptionBytes(const char* key, uint8_t* value,
                                                  size_t* length,
                                                  struct AdbcError* error) {
-  SetError(error, "[libq] Unknown statement option '%s'", key);
+  SetError(error, "[libpq] Unknown statement option '%s'", key);
   return ADBC_STATUS_NOT_FOUND;
 }
 
 AdbcStatusCode PostgresStatement::GetOptionDouble(const char* key, double* value,
                                                   struct AdbcError* error) {
-  SetError(error, "[libq] Unknown statement option '%s'", key);
+  SetError(error, "[libpq] Unknown statement option '%s'", key);
   return ADBC_STATUS_NOT_FOUND;
 }
 
@@ -1109,7 +1113,7 @@ AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value,
     *value = reader_.batch_size_hint_bytes_;
     return ADBC_STATUS_OK;
   }
-  SetError(error, "[libq] Unknown statement option '%s'", key);
+  SetError(error, "[libpq] Unknown statement option '%s'", key);
   return ADBC_STATUS_NOT_FOUND;
 }
 
@@ -1173,7 +1177,7 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
 
     this->reader_.batch_size_hint_bytes_ = int_value;
   } else {
-    SetError(error, "[libq] Unknown statement option '%s'", key);
+    SetError(error, "[libpq] Unknown statement option '%s'", key);
     return ADBC_STATUS_NOT_IMPLEMENTED;
   }
   return ADBC_STATUS_OK;
@@ -1181,13 +1185,13 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
 
 AdbcStatusCode PostgresStatement::SetOptionBytes(const char* key, const uint8_t* value,
                                                  size_t length, struct AdbcError* error) {
-  SetError(error, "%s%s", "[libpq] Unknown option ", key);
+  SetError(error, "%s%s", "[libpq] Unknown statement option ", key);
   return ADBC_STATUS_NOT_IMPLEMENTED;
 }
 
 AdbcStatusCode PostgresStatement::SetOptionDouble(const char* key, double value,
                                                   struct AdbcError* error) {
-  SetError(error, "%s%s", "[libpq] Unknown option ", key);
+  SetError(error, "%s%s", "[libpq] Unknown statement option ", key);
   return ADBC_STATUS_NOT_IMPLEMENTED;
 }
 
@@ -1202,7 +1206,7 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
     this->reader_.batch_size_hint_bytes_ = value;
     return ADBC_STATUS_OK;
   }
-  SetError(error, "%s%s", "[libpq] Unknown option ", key);
+  SetError(error, "[libpq] Unknown statement option '%s'", key);
   return ADBC_STATUS_NOT_IMPLEMENTED;
 }
 
diff --git a/docs/source/python/api/adbc_driver_manager.rst b/docs/source/python/api/adbc_driver_manager.rst
index c0d22b62..7023af6a 100644
--- a/docs/source/python/api/adbc_driver_manager.rst
+++ b/docs/source/python/api/adbc_driver_manager.rst
@@ -31,9 +31,11 @@ Constants & Enums
 
 .. autoclass:: adbc_driver_manager.AdbcStatusCode
    :members:
+   :undoc-members:
 
 .. autoclass:: adbc_driver_manager.GetObjectsDepth
    :members:
+   :undoc-members:
 
 .. autoclass:: adbc_driver_manager.ConnectionOptions
    :members:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py
index e2eaee57..25b821eb 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py
@@ -90,6 +90,8 @@ class DatabaseOptions(enum.Enum):
 
     #: Set the password to use for username-password authentication.
     PASSWORD = "password"
+    #: The URI to connect to.
+    URI = "uri"
     #: Set the username to use for username-password authentication.
     USERNAME = "username"
 
@@ -100,6 +102,10 @@ class ConnectionOptions(enum.Enum):
     Not all drivers support all options.
     """
 
+    #: Get/set the current catalog.
+    CURRENT_CATALOG = "adbc.connection.catalog"
+    #: Get/set the current schema.
+    CURRENT_DB_SCHEMA = "adbc.connection.db_schema"
     #: Set the transaction isolation level.
     ISOLATION_LEVEL = "adbc.connection.transaction.isolation_level"
 
@@ -110,7 +116,11 @@ class StatementOptions(enum.Enum):
     Not all drivers support all options.
     """
 
+    #: Enable incremental execution on ExecutePartitions.
+    INCREMENTAL = "adbc.statement.exec.incremental"
     #: For bulk ingestion, whether to create or append to the table.
     INGEST_MODE = INGEST_OPTION_MODE
     #: For bulk ingestion, the table to ingest into.
     INGEST_TARGET_TABLE = INGEST_OPTION_TARGET_TABLE
+    #: Get progress of a query.
+    PROGRESS = "adbc.statement.exec.progress"
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 8f107369..7723df17 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -26,15 +26,22 @@ import typing
 INGEST_OPTION_MODE: str
 INGEST_OPTION_MODE_APPEND: str
 INGEST_OPTION_MODE_CREATE: str
+INGEST_OPTION_MODE_CREATE_APPEND: str
+INGEST_OPTION_MODE_REPLACE: str
 INGEST_OPTION_TARGET_TABLE: str
 
 class AdbcConnection(_AdbcHandle):
     def __init__(self, database: "AdbcDatabase", **kwargs: str) -> None: ...
+    def cancel(self) -> None: ...
     def close(self) -> None: ...
     def commit(self) -> None: ...
     def get_info(
         self, info_codes: Optional[List[Union[int, "AdbcInfoCode"]]] = None
     ) -> "ArrowArrayStreamHandle": ...
+    def get_option(self, key: str) -> str: ...
+    def get_option_bytes(self, key: str) -> bytes: ...
+    def get_option_float(self, key: str) -> float: ...
+    def get_option_int(self, key: str) -> int: ...
     def get_objects(
         self,
         depth: "GetObjectsDepth",
@@ -54,12 +61,16 @@ class AdbcConnection(_AdbcHandle):
     def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ...
     def rollback(self) -> None: ...
     def set_autocommit(self, enabled: bool) -> None: ...
-    def set_options(self, **kwargs: str) -> None: ...
+    def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
 
 class AdbcDatabase(_AdbcHandle):
     def __init__(self, **kwargs: str) -> None: ...
     def close(self) -> None: ...
-    def set_options(self, **kwargs: str) -> None: ...
+    def get_option(self, key: str) -> str: ...
+    def get_option_bytes(self, key: str) -> bytes: ...
+    def get_option_float(self, key: str) -> float: ...
+    def get_option_int(self, key: str) -> int: ...
+    def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
 
 class AdbcInfoCode(enum.IntEnum):
     DRIVER_ARROW_VERSION = ...
@@ -73,13 +84,19 @@ class AdbcStatement(_AdbcHandle):
     def __init__(self, *args, **kwargs) -> None: ...
     def bind(self, *args, **kwargs) -> Any: ...
     def bind_stream(self, *args, **kwargs) -> Any: ...
+    def cancel(self) -> None: ...
     def close(self) -> None: ...
     def execute_partitions(self, *args, **kwargs) -> Any: ...
     def execute_query(self, *args, **kwargs) -> Any: ...
+    def execute_schema(self) -> "ArrowSchemaHandle": ...
     def execute_update(self, *args, **kwargs) -> Any: ...
+    def get_option(self, key: str) -> str: ...
+    def get_option_bytes(self, key: str) -> bytes: ...
+    def get_option_float(self, key: str) -> float: ...
+    def get_option_int(self, key: str) -> int: ...
     def get_parameter_schema(self, *args, **kwargs) -> Any: ...
     def prepare(self, *args, **kwargs) -> Any: ...
-    def set_options(self, *args, **kwargs) -> Any: ...
+    def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
     def set_sql_query(self, *args, **kwargs) -> Any: ...
     def set_substrait_plan(self, *args, **kwargs) -> Any: ...
     def __reduce__(self) -> Any: ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 406d5778..a5ccc23b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -69,6 +69,8 @@ cdef extern from "adbc.h" nogil:
     cdef const char* ADBC_INGEST_OPTION_MODE
     cdef const char* ADBC_INGEST_OPTION_MODE_APPEND
     cdef const char* ADBC_INGEST_OPTION_MODE_CREATE
+    cdef const char* ADBC_INGEST_OPTION_MODE_REPLACE
+    cdef const char* ADBC_INGEST_OPTION_MODE_CREATE_APPEND
 
     cdef int ADBC_OBJECT_DEPTH_ALL
     cdef int ADBC_OBJECT_DEPTH_CATALOGS
@@ -112,11 +114,22 @@ cdef extern from "adbc.h" nogil:
         CAdbcPartitionsRelease release
 
     CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error)
+    CAdbcStatusCode AdbcDatabaseGetOption(
+        CAdbcDatabase*, const char*, char*, size_t*, CAdbcError*);
+    CAdbcStatusCode AdbcDatabaseGetOptionBytes(
+        CAdbcDatabase*, const char*, uint8_t*, size_t*, CAdbcError*);
+    CAdbcStatusCode AdbcDatabaseGetOptionDouble(
+        CAdbcDatabase*, const char*, double*, CAdbcError*);
+    CAdbcStatusCode AdbcDatabaseGetOptionInt(
+        CAdbcDatabase*, const char*, int64_t*, CAdbcError*);
     CAdbcStatusCode AdbcDatabaseSetOption(
-        CAdbcDatabase* database,
-        const char* key,
-        const char* value,
-        CAdbcError* error)
+        CAdbcDatabase*, const char*, const char*, CAdbcError*)
+    CAdbcStatusCode AdbcDatabaseSetOptionBytes(
+        CAdbcDatabase*, const char*, const uint8_t*, size_t, CAdbcError*)
+    CAdbcStatusCode AdbcDatabaseSetOptionDouble(
+        CAdbcDatabase*, const char*, double, CAdbcError*)
+    CAdbcStatusCode AdbcDatabaseSetOptionInt(
+        CAdbcDatabase*, const char*, int64_t, CAdbcError*)
     CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error)
     CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error)
 
@@ -126,6 +139,7 @@ cdef extern from "adbc.h" nogil:
         CAdbcDriverInitFunc init_func,
         CAdbcError* error)
 
+    CAdbcStatusCode AdbcConnectionCancel(CAdbcConnection*, CAdbcError*)
     CAdbcStatusCode AdbcConnectionCommit(
         CAdbcConnection* connection,
         CAdbcError* error)
@@ -154,6 +168,19 @@ cdef extern from "adbc.h" nogil:
         const char* column_name,
         CArrowArrayStream* stream,
         CAdbcError* error)
+    CAdbcStatusCode AdbcConnectionGetOption(
+        CAdbcConnection*, const char*, char*, size_t*, CAdbcError*);
+    CAdbcStatusCode AdbcConnectionGetOptionBytes(
+        CAdbcConnection*, const char*, uint8_t*, size_t*, CAdbcError*);
+    CAdbcStatusCode AdbcConnectionGetOptionDouble(
+        CAdbcConnection*, const char*, double*, CAdbcError*);
+    CAdbcStatusCode AdbcConnectionGetOptionInt(
+        CAdbcConnection*, const char*, int64_t*, CAdbcError*);
+    CAdbcStatusCode AdbcConnectionGetStatistics(
+        CAdbcConnection*, const char*, const char*, const char*,
+        char, CArrowArrayStream*, CAdbcError*);
+    CAdbcStatusCode AdbcConnectionGetStatisticNames(
+        CAdbcConnection*, CArrowArrayStream*, CAdbcError*);
     CAdbcStatusCode AdbcConnectionGetTableSchema(
         CAdbcConnection* connection,
         const char* catalog,
@@ -172,20 +199,24 @@ cdef extern from "adbc.h" nogil:
     CAdbcStatusCode AdbcConnectionNew(
         CAdbcConnection* connection,
         CAdbcError* error)
-    CAdbcStatusCode AdbcConnectionSetOption(
-        CAdbcConnection* connection,
-        const char* key,
-        const char* value,
-        CAdbcError* error)
     CAdbcStatusCode AdbcConnectionRelease(
         CAdbcConnection* connection,
         CAdbcError* error)
-
+    CAdbcStatusCode AdbcConnectionSetOption(
+        CAdbcConnection*, const char*, const char*, CAdbcError*)
+    CAdbcStatusCode AdbcConnectionSetOptionBytes(
+        CAdbcConnection*, const char*, const uint8_t*, size_t, CAdbcError*)
+    CAdbcStatusCode AdbcConnectionSetOptionDouble(
+        CAdbcConnection*, const char*, double, CAdbcError*)
+    CAdbcStatusCode AdbcConnectionSetOptionInt(
+        CAdbcConnection*, const char*, int64_t, CAdbcError*)
     CAdbcStatusCode AdbcStatementBind(
         CAdbcStatement* statement,
         CArrowArray*,
         CArrowSchema*,
         CAdbcError* error)
+
+    CAdbcStatusCode AdbcStatementCancel(CAdbcStatement*, CAdbcError*)
     CAdbcStatusCode AdbcStatementBindStream(
         CAdbcStatement* statement,
         CArrowArrayStream*,
@@ -199,6 +230,16 @@ cdef extern from "adbc.h" nogil:
         CAdbcStatement* statement,
         CArrowArrayStream* out, int64_t* rows_affected,
         CAdbcError* error)
+    CAdbcStatusCode AdbcStatementExecuteSchema(
+        CAdbcStatement*, CArrowSchema*, CAdbcError*)
+    CAdbcStatusCode AdbcStatementGetOption(
+        CAdbcStatement*, const char*, char*, size_t*, CAdbcError*);
+    CAdbcStatusCode AdbcStatementGetOptionBytes(
+        CAdbcStatement*, const char*, uint8_t*, size_t*, CAdbcError*);
+    CAdbcStatusCode AdbcStatementGetOptionDouble(
+        CAdbcStatement*, const char*, double*, CAdbcError*);
+    CAdbcStatusCode AdbcStatementGetOptionInt(
+        CAdbcStatement*, const char*, int64_t*, CAdbcError*);
     CAdbcStatusCode AdbcStatementGetParameterSchema(
         CAdbcStatement* statement,
         CArrowSchema* schema,
@@ -211,10 +252,13 @@ cdef extern from "adbc.h" nogil:
         CAdbcStatement* statement,
         CAdbcError* error)
     CAdbcStatusCode AdbcStatementSetOption(
-        CAdbcStatement* statement,
-        const char* key,
-        const char* value,
-        CAdbcError* error)
+        CAdbcStatement*, const char*, const char*, CAdbcError*)
+    CAdbcStatusCode AdbcStatementSetOptionBytes(
+        CAdbcStatement*, const char*, const uint8_t*, size_t, CAdbcError*)
+    CAdbcStatusCode AdbcStatementSetOptionDouble(
+        CAdbcStatement*, const char*, double, CAdbcError*)
+    CAdbcStatusCode AdbcStatementSetOptionInt(
+        CAdbcStatement*, const char*, int64_t, CAdbcError*)
     CAdbcStatusCode AdbcStatementSetSqlQuery(
         CAdbcStatement* statement,
         const char* query,
@@ -348,6 +392,8 @@ NotSupportedError.__module__ = "adbc_driver_manager"
 INGEST_OPTION_MODE = ADBC_INGEST_OPTION_MODE.decode("utf-8")
 INGEST_OPTION_MODE_APPEND = ADBC_INGEST_OPTION_MODE_APPEND.decode("utf-8")
 INGEST_OPTION_MODE_CREATE = ADBC_INGEST_OPTION_MODE_CREATE.decode("utf-8")
+INGEST_OPTION_MODE_REPLACE = ADBC_INGEST_OPTION_MODE_REPLACE.decode("utf-8")
+INGEST_OPTION_MODE_CREATE_APPEND = ADBC_INGEST_OPTION_MODE_CREATE_APPEND.decode("utf-8")
 INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8")
 
 
@@ -521,6 +567,11 @@ class GetObjectsDepth(enum.IntEnum):
     COLUMNS = ADBC_OBJECT_DEPTH_COLUMNS
 
 
+# Assume a driver won't return more than 128 MiB of option data at
+# once.
+_MAX_OPTION_SIZE = 2**27
+
+
 cdef class AdbcDatabase(_AdbcHandle):
     """
     An instance of a database.
@@ -581,15 +632,102 @@ cdef class AdbcDatabase(_AdbcHandle):
                 status = AdbcDatabaseRelease(&self.database, &c_error)
             check_error(status, &c_error)
 
+    def get_option(self, key: str) -> str:
+        """Get the value of a string option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef uint8_t* c_value = NULL
+        cdef size_t c_len = 0
+
+        buf = bytearray(1024)
+        while True:
+            c_value = buf
+            c_len = len(buf)
+            check_error(
+                AdbcDatabaseGetOption(
+                    &self.database, c_key, buf, &c_len, &c_error),
+                &c_error)
+            if c_len <= len(buf):
+                # Entire value read
+                break
+            else:
+                # Buffer too small
+                new_len = len(buf) * 2
+                if new_len > _MAX_OPTION_SIZE:
+                    raise RuntimeError(
+                        f"Could not read option {key}: "
+                        f"would need more than {len(buf)} bytes")
+                buf = bytearray(new_len)
+
+        # Remove trailing null terminator
+        if c_len > 0:
+            c_len -= 1
+        return buf[:c_len].decode("utf-8")
+
+    def get_option_bytes(self, key: str) -> bytes:
+        """Get the value of a binary option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef uint8_t* c_value = NULL
+        cdef size_t c_len = 0
+
+        buf = bytearray(1024)
+        while True:
+            c_value = buf
+            c_len = len(buf)
+            check_error(
+                AdbcDatabaseGetOptionBytes(
+                    &self.database, c_key, buf, &c_len, &c_error),
+                &c_error)
+            if c_len <= len(buf):
+                # Entire value read
+                break
+            else:
+                # Buffer too small
+                new_len = len(buf) * 2
+                if new_len > _MAX_OPTION_SIZE:
+                    raise RuntimeError(
+                        f"Could not read option {key}: "
+                        f"would need more than {len(buf)} bytes")
+                buf = bytearray(new_len)
+
+        return bytes(buf[:c_len])
+
+    def get_option_float(self, key: str) -> float:
+        """Get the value of a floating-point option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef double c_value = 0.0
+        check_error(
+            AdbcDatabaseGetOptionDouble(
+                &self.database, c_key, &c_value, &c_error),
+            &c_error)
+        return c_value
+
+    def get_option_int(self, key: str) -> int:
+        """Get the value of an integer option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef int64_t c_value = 0
+        check_error(
+            AdbcDatabaseGetOptionInt(
+                &self.database, c_key, &c_value, &c_error),
+            &c_error)
+        return c_value
+
     def set_options(self, **kwargs) -> None:
-        """Set arbitrary key-value options.
+        """
+        Set arbitrary key-value options.
 
         Note, not all drivers support setting options after creation.
 
         See Also
         --------
         adbc_driver_manager.DatabaseOptions : Standard option names.
-
         """
         cdef CAdbcError c_error = empty_error()
         cdef char* c_key = NULL
@@ -600,12 +738,28 @@ cdef class AdbcDatabase(_AdbcHandle):
 
             if value is None:
                 c_value = NULL
-            else:
+                status = AdbcDatabaseSetOption(
+                    &self.database, c_key, c_value, &c_error)
+            elif isinstance(value, str):
                 value = value.encode("utf-8")
                 c_value = value
+                status = AdbcDatabaseSetOption(
+                    &self.database, c_key, c_value, &c_error)
+            elif isinstance(value, bytes):
+                c_value = value
+                status = AdbcDatabaseSetOptionBytes(
+                    &self.database, c_key, <const uint8_t*> c_value, len(value), &c_error)
+            elif isinstance(value, float):
+                status = AdbcDatabaseSetOptionDouble(
+                    &self.database, c_key, value, &c_error)
+            elif isinstance(value, int):
+                status = AdbcDatabaseSetOptionInt(
+                    &self.database, c_key, value, &c_error)
+            else:
+                raise ValueError(
+                    f"Unsupported type {type(value)} for value {value!r} "
+                    f"of option {key}")
 
-            status = AdbcDatabaseSetOption(
-                &self.database, c_key, c_value, &c_error)
             check_error(status, &c_error)
 
 
@@ -659,6 +813,14 @@ cdef class AdbcConnection(_AdbcHandle):
 
         database._open_child()
 
+    def cancel(self) -> None:
+        """Attempt to cancel any ongoing operations on the connection."""
+        cdef CAdbcError c_error = empty_error()
+        cdef CAdbcStatusCode status
+        with nogil:
+            status = AdbcConnectionCancel(&self.connection, &c_error)
+        check_error(status, &c_error)
+
     def commit(self) -> None:
         """Commit the current transaction."""
         cdef CAdbcError c_error = empty_error()
@@ -747,6 +909,93 @@ cdef class AdbcConnection(_AdbcHandle):
 
         return stream
 
+    def get_option(self, key: str) -> str:
+        """Get the value of a string option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef uint8_t* c_value = NULL
+        cdef size_t c_len = 0
+
+        buf = bytearray(1024)
+        while True:
+            c_value = buf
+            c_len = len(buf)
+            check_error(
+                AdbcConnectionGetOption(
+                    &self.connection, c_key, buf, &c_len, &c_error),
+                &c_error)
+            if c_len <= len(buf):
+                # Entire value read
+                break
+            else:
+                # Buffer too small
+                new_len = len(buf) * 2
+                if new_len > _MAX_OPTION_SIZE:
+                    raise RuntimeError(
+                        f"Could not read option {key}: "
+                        f"would need more than {len(buf)} bytes")
+                buf = bytearray(new_len)
+
+        # Remove trailing null terminator
+        if c_len > 0:
+            c_len -= 1
+        return buf[:c_len].decode("utf-8")
+
+    def get_option_bytes(self, key: str) -> bytes:
+        """Get the value of a binary option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef uint8_t* c_value = NULL
+        cdef size_t c_len = 0
+
+        buf = bytearray(1024)
+        while True:
+            c_value = buf
+            c_len = len(buf)
+            check_error(
+                AdbcConnectionGetOptionBytes(
+                    &self.connection, c_key, buf, &c_len, &c_error),
+                &c_error)
+            if c_len <= len(buf):
+                # Entire value read
+                break
+            else:
+                # Buffer too small
+                new_len = len(buf) * 2
+                if new_len > _MAX_OPTION_SIZE:
+                    raise RuntimeError(
+                        f"Could not read option {key}: "
+                        f"would need more than {len(buf)} bytes")
+                buf = bytearray(new_len)
+
+        return bytes(buf[:c_len])
+
+    def get_option_float(self, key: str) -> float:
+        """Get the value of a floating-point option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef double c_value = 0.0
+        check_error(
+            AdbcConnectionGetOptionDouble(
+                &self.connection, c_key, &c_value, &c_error),
+            &c_error)
+        return c_value
+
+    def get_option_int(self, key: str) -> int:
+        """Get the value of an integer option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef int64_t c_value = 0
+        check_error(
+            AdbcConnectionGetOptionInt(
+                &self.connection, c_key, &c_value, &c_error),
+            &c_error)
+        return c_value
+
     def get_table_schema(self, catalog, db_schema, table_name) -> ArrowSchemaHandle:
         """
         Get the Arrow schema of a table.
@@ -854,12 +1103,28 @@ cdef class AdbcConnection(_AdbcHandle):
 
             if value is None:
                 c_value = NULL
-            else:
+                status = AdbcConnectionSetOption(
+                    &self.connection, c_key, c_value, &c_error)
+            elif isinstance(value, str):
                 value = value.encode("utf-8")
                 c_value = value
+                status = AdbcConnectionSetOption(
+                    &self.connection, c_key, c_value, &c_error)
+            elif isinstance(value, bytes):
+                c_value = value
+                status = AdbcConnectionSetOptionBytes(
+                    &self.connection, c_key, <const uint8_t*> c_value, len(value), &c_error)
+            elif isinstance(value, float):
+                status = AdbcConnectionSetOptionDouble(
+                    &self.connection, c_key, value, &c_error)
+            elif isinstance(value, int):
+                status = AdbcConnectionSetOptionInt(
+                    &self.connection, c_key, value, &c_error)
+            else:
+                raise ValueError(
+                    f"Unsupported type {type(value)} for value {value!r} "
+                    f"of option {key}")
 
-            status = AdbcConnectionSetOption(
-                &self.connection, c_key, c_value, &c_error)
             check_error(status, &c_error)
 
     def close(self) -> None:
@@ -970,7 +1235,16 @@ cdef class AdbcStatement(_AdbcHandle):
                 &c_error)
         check_error(status, &c_error)
 
+    def cancel(self) -> None:
+        """Attempt to cancel any ongoing operations on the connection."""
+        cdef CAdbcError c_error = empty_error()
+        cdef CAdbcStatusCode status
+        with nogil:
+            status = AdbcStatementCancel(&self.statement, &c_error)
+        check_error(status, &c_error)
+
     def close(self) -> None:
+        """Release the handle to the statement."""
         cdef CAdbcError c_error = empty_error()
         cdef CAdbcStatusCode status
         self.connection._close_child()
@@ -1044,6 +1318,25 @@ cdef class AdbcStatement(_AdbcHandle):
 
         return (partitions, schema, rows_affected)
 
+    def execute_schema(self) -> ArrowSchemaHandle:
+        """
+        Get the schema of the result set without executing the query.
+
+        Returns
+        -------
+        ArrowSchemaHandle
+            The schema of the result set.
+        """
+        cdef CAdbcError c_error = empty_error()
+        cdef ArrowSchemaHandle schema = ArrowSchemaHandle()
+        with nogil:
+            status = AdbcStatementExecuteSchema(
+                &self.statement,
+                &schema.schema,
+                &c_error)
+        check_error(status, &c_error)
+        return schema
+
     def execute_update(self) -> int:
         """
         Execute the query without a result set.
@@ -1064,6 +1357,93 @@ cdef class AdbcStatement(_AdbcHandle):
         check_error(status, &c_error)
         return rows_affected
 
+    def get_option(self, key: str) -> str:
+        """Get the value of a string option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef uint8_t* c_value = NULL
+        cdef size_t c_len = 0
+
+        buf = bytearray(1024)
+        while True:
+            c_value = buf
+            c_len = len(buf)
+            check_error(
+                AdbcStatementGetOption(
+                    &self.statement, c_key, buf, &c_len, &c_error),
+                &c_error)
+            if c_len <= len(buf):
+                # Entire value read
+                break
+            else:
+                # Buffer too small
+                new_len = len(buf) * 2
+                if new_len > _MAX_OPTION_SIZE:
+                    raise RuntimeError(
+                        f"Could not read option {key}: "
+                        f"would need more than {len(buf)} bytes")
+                buf = bytearray(new_len)
+
+        # Remove trailing null terminator
+        if c_len > 0:
+            c_len -= 1
+        return buf[:c_len].decode("utf-8")
+
+    def get_option_bytes(self, key: str) -> bytes:
+        """Get the value of a binary option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef uint8_t* c_value = NULL
+        cdef size_t c_len = 0
+
+        buf = bytearray(1024)
+        while True:
+            c_value = buf
+            c_len = len(buf)
+            check_error(
+                AdbcStatementGetOptionBytes(
+                    &self.statement, c_key, buf, &c_len, &c_error),
+                &c_error)
+            if c_len <= len(buf):
+                # Entire value read
+                break
+            else:
+                # Buffer too small
+                new_len = len(buf) * 2
+                if new_len > _MAX_OPTION_SIZE:
+                    raise RuntimeError(
+                        f"Could not read option {key}: "
+                        f"would need more than {len(buf)} bytes")
+                buf = bytearray(new_len)
+
+        return bytes(buf[:c_len])
+
+    def get_option_float(self, key: str) -> float:
+        """Get the value of a floating-point option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef double c_value = 0.0
+        check_error(
+            AdbcStatementGetOptionDouble(
+                &self.statement, c_key, &c_value, &c_error),
+            &c_error)
+        return c_value
+
+    def get_option_int(self, key: str) -> int:
+        """Get the value of an integer option."""
+        cdef CAdbcError c_error = empty_error()
+        key_bytes = key.encode("utf-8")
+        cdef char* c_key = key_bytes
+        cdef int64_t c_value = 0
+        check_error(
+            AdbcStatementGetOptionInt(
+                &self.statement, c_key, &c_value, &c_error),
+            &c_error)
+        return c_value
+
     def get_parameter_schema(self) -> ArrowSchemaHandle:
         """Get the Arrow schema for bound parameters.
 
@@ -1104,7 +1484,8 @@ cdef class AdbcStatement(_AdbcHandle):
         check_error(status, &c_error)
 
     def set_options(self, **kwargs) -> None:
-        """Set arbitrary key-value options.
+        """
+        Set key-value options; values may be None, str, bytes, float, or int.
 
         See Also
         --------
@@ -1119,12 +1500,28 @@ cdef class AdbcStatement(_AdbcHandle):
 
             if value is None:
                 c_value = NULL
-            else:
+                status = AdbcStatementSetOption(
+                    &self.statement, c_key, c_value, &c_error)
+            elif isinstance(value, str):
                 value = value.encode("utf-8")
                 c_value = value
+                status = AdbcStatementSetOption(
+                    &self.statement, c_key, c_value, &c_error)
+            elif isinstance(value, bytes):
+                c_value = value
+                status = AdbcStatementSetOptionBytes(
+                    &self.statement, c_key, <const uint8_t*> c_value, len(value), &c_error)
+            elif isinstance(value, float):
+                status = AdbcStatementSetOptionDouble(
+                    &self.statement, c_key, value, &c_error)
+            elif isinstance(value, int):
+                status = AdbcStatementSetOptionInt(
+                    &self.statement, c_key, value, &c_error)
+            else:
+                raise ValueError(
+                    f"Unsupported type {type(value)} for value {value!r} "
+                    f"of option {key}")
 
-            status = AdbcStatementSetOption(
-                &self.statement, c_key, c_value, &c_error)
             check_error(status, &c_error)
 
     def set_sql_query(self, str query not None) -> None:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 31e4392a..d9b1f554 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -43,6 +43,8 @@ try:
 except ImportError as e:
     raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e
 
+import adbc_driver_manager
+
 from . import _lib
 
 if typing.TYPE_CHECKING:
@@ -78,6 +80,7 @@ _KNOWN_INFO_VALUES = {
     100: "driver_name",
     101: "driver_version",
     102: "driver_arrow_version",
+    103: "driver_adbc_version",
 }
 
 # ----------------------------------------------------------
@@ -344,6 +347,16 @@ class Connection(_Closeable):
     # API Extensions
     # ------------------------------------------------------------
 
+    def adbc_cancel(self) -> None:
+        """
+        Cancel any in-progress operations on this connection.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        self._conn.cancel()  # delegate to the wrapped low-level connection
+
     def adbc_clone(self) -> "Connection":
         """
         Create a new Connection sharing the same underlying database.
@@ -479,6 +492,40 @@ class Connection(_Closeable):
         """
         return self._conn
 
+    @property
+    def adbc_current_catalog(self) -> str:
+        """
+        Name of the catalog this connection currently uses (assignable).
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        option = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG
+        return self._conn.get_option(option.value)
+
+    @adbc_current_catalog.setter
+    def adbc_current_catalog(self, catalog: str) -> None:
+        option = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG
+        self._conn.set_options(**{option.value: catalog})
+
+    @property
+    def adbc_current_db_schema(self) -> str:
+        """
+        Name of the database schema this connection currently uses (assignable).
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        option = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA
+        return self._conn.get_option(option.value)
+
+    @adbc_current_db_schema.setter
+    def adbc_current_db_schema(self, db_schema: str) -> None:
+        option = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA
+        self._conn.set_options(**{option.value: db_schema})
+
     @property
     def adbc_database(self) -> _lib.AdbcDatabase:
         """
@@ -729,11 +776,21 @@ class Cursor(_Closeable):
     # API Extensions
     # ------------------------------------------------------------
 
+    def adbc_cancel(self) -> None:
+        """
+        Cancel any in-progress operation on this cursor's statement.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        self._stmt.cancel()  # delegate to the wrapped low-level statement
+
     def adbc_ingest(
         self,
         table_name: str,
         data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader],
-        mode: Literal["append", "create"] = "create",
+        mode: Literal["append", "create", "create_append", "replace"] = "create",
     ) -> int:
         """
         Ingest Arrow data into a database table.
@@ -748,7 +805,12 @@ class Cursor(_Closeable):
         data
             The Arrow data to insert.
         mode
-            Whether to append data to an existing table, or create a new table.
+            How to handle an existing (or missing) target table:
+
+            - 'append': append to a table (error if table does not exist)
+            - 'create': create a table and insert (error if table exists)
+            - 'create_append': create the table if needed, then insert
+            - 'replace': drop existing table (if any), then same as 'create'
 
         Returns
         -------
@@ -764,6 +826,10 @@ class Cursor(_Closeable):
             c_mode = _lib.INGEST_OPTION_MODE_APPEND
         elif mode == "create":
             c_mode = _lib.INGEST_OPTION_MODE_CREATE
+        elif mode == "create_append":
+            c_mode = _lib.INGEST_OPTION_MODE_CREATE_APPEND
+        elif mode == "replace":
+            c_mode = _lib.INGEST_OPTION_MODE_REPLACE
         else:
             raise ValueError(f"Invalid value for 'mode': {mode}")
         self._stmt.set_options(
@@ -810,6 +876,23 @@ class Cursor(_Closeable):
         partitions, schema, self._rowcount = self._stmt.execute_partitions()
         return partitions, pyarrow.Schema._import_from_c(schema.address)
 
+    def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema:
+        """
+        Get the schema of the result set of a query without executing it.
+
+        Returns
+        -------
+        pyarrow.Schema
+            The schema of the result set.
+
+        Notes
+        -----
+        This is an extension and not part of the DBAPI standard.
+        """
+        self._prepare_execute(operation, parameters)
+        schema = self._stmt.execute_schema()
+        return pyarrow.Schema._import_from_c(schema.address)
+
     def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]:
         """
         Prepare a query without executing it.
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index c50cad1e..e3f86a44 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -17,6 +17,7 @@
 
 from typing import Generator
 
+import pyarrow
 import pytest
 
 from adbc_driver_postgresql import StatementOptions, dbapi
@@ -28,6 +29,32 @@ def postgres(postgres_uri: str) -> Generator[dbapi.Connection, None, None]:
         yield conn
 
 
+def test_conn_current_catalog(postgres: dbapi.Connection) -> None:
+    assert postgres.adbc_current_catalog != ""  # name presumably set by postgres_uri
+
+
+def test_conn_current_db_schema(postgres: dbapi.Connection) -> None:
+    assert postgres.adbc_current_db_schema == "public"  # presumably server default
+
+
+def test_conn_change_db_schema(postgres: dbapi.Connection) -> None:
+    new_schema = "dbapischema"
+    assert postgres.adbc_current_db_schema == "public"
+    with postgres.cursor() as cur:
+        cur.execute(f"CREATE SCHEMA IF NOT EXISTS {new_schema}")
+    # Merely creating the schema must not switch the connection to it.
+    assert postgres.adbc_current_db_schema == "public"
+    postgres.adbc_current_db_schema = new_schema
+    assert postgres.adbc_current_db_schema == new_schema
+
+
+def test_conn_get_info(postgres: dbapi.Connection) -> None:
+    info = postgres.adbc_get_info()
+    expected = {"driver_name": "ADBC PostgreSQL Driver", "vendor_name": "PostgreSQL"}
+    assert {key: info[key] for key in expected} == expected
+    assert info["driver_adbc_version"] == 1_001_000
+
+
 def test_query_batch_size(postgres: dbapi.Connection):
     with postgres.cursor() as cur:
         cur.execute("DROP TABLE IF EXISTS test_batch_size")
@@ -47,6 +74,12 @@ def test_query_batch_size(postgres: dbapi.Connection):
         cur.adbc_statement.set_options(
             **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "1"}
         )
+        assert (
+            cur.adbc_statement.get_option_int(
+                StatementOptions.BATCH_SIZE_HINT_BYTES.value
+            )
+            == 1
+        )
         cur.execute("SELECT * FROM test_batch_size")
         table = cur.fetch_arrow_table()
         assert len(table.to_batches()) == 65536
@@ -54,17 +87,55 @@ def test_query_batch_size(postgres: dbapi.Connection):
         cur.adbc_statement.set_options(
             **{StatementOptions.BATCH_SIZE_HINT_BYTES.value: "4096"}
         )
+        assert (
+            cur.adbc_statement.get_option_int(
+                StatementOptions.BATCH_SIZE_HINT_BYTES.value
+            )
+            == 4096
+        )
         cur.execute("SELECT * FROM test_batch_size")
         table = cur.fetch_arrow_table()
         assert 64 <= len(table.to_batches()) <= 256
 
 
+def test_query_cancel(postgres: dbapi.Connection) -> None:
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_query_cancel")
+        cur.execute("CREATE TABLE test_query_cancel (ints INT)")
+        cur.execute(
+            """
+            INSERT INTO test_query_cancel (ints)
+            SELECT generated :: INT
+            FROM GENERATE_SERIES(1, 65536) temp(generated)
+        """
+        )
+
+        cur.execute("SELECT * FROM test_query_cancel")
+        cur.adbc_cancel()
+        # XXX(https://github.com/apache/arrow-adbc/issues/940):
+        # PyArrow swallows the errno and doesn't set it into the
+        # OSError, so we have no clue what happened here. (Though the
+        # driver does properly return ECANCELED.)
+        with pytest.raises(OSError, match="canceling statement"):
+            cur.fetchone()
+
+
+def test_query_execute_schema(postgres: dbapi.Connection) -> None:
+    expected = pyarrow.schema([("foo", "int32")])
+    with postgres.cursor() as cur:
+        assert cur.adbc_execute_schema("SELECT 1 AS foo") == expected
+
+
 def test_query_trivial(postgres: dbapi.Connection):
     with postgres.cursor() as cur:
         cur.execute("SELECT 1")
         assert cur.fetchone() == (1,)  # one row holding a single column
 
 
+def test_stmt_ingest(postgres: dbapi.Connection) -> None:
+    pytest.skip("TODO: test adbc_ingest modes against PostgreSQL")
+
+
 def test_ddl(postgres: dbapi.Connection):
     with postgres.cursor() as cur:
         cur.execute("DROP TABLE IF EXISTS test_ddl")


Reply via email to