This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new ab9be7935 feat(python/adbc_driver_manager): allow more types in init
(#4088)
ab9be7935 is described below
commit ab9be79352bf65f3452effe2871a3cd111c9bbeb
Author: David Li <[email protected]>
AuthorDate: Wed Mar 18 09:29:53 2026 +0900
feat(python/adbc_driver_manager): allow more types in init (#4088)
- Test overriding options from profiles
- Allow bool, bytes, int, float option values when
constructing a database or connection
- Document option overriding behavior in profiles when
the option values are of different types.
We could improve the behavior: track all options of all types
set outside a profile, and skip profile options where they match.
Closes #4086.
---
c/driver/sqlite/sqlite.cc | 30 ++++-
docs/source/format/connection_profiles.rst | 7 +
.../adbc_driver_manager/_lib.pyi | 19 ++-
.../adbc_driver_manager/_lib.pyx | 102 +++++++++++++--
python/adbc_driver_manager/tests/test_lowlevel.py | 145 ++++++++++++++++++++-
python/adbc_driver_manager/tests/test_profile.py | 5 +-
6 files changed, 284 insertions(+), 24 deletions(-)
diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc
index e91263b18..d970d76a0 100644
--- a/c/driver/sqlite/sqlite.cc
+++ b/c/driver/sqlite/sqlite.cc
@@ -701,7 +701,9 @@ class SqliteConnection : public
driver::Connection<SqliteConnection> {
Status InitImpl(void* parent) {
auto& db = *reinterpret_cast<SqliteDatabase*>(parent);
UNWRAP_RESULT(conn_, db.OpenConnection());
- batch_size_ = db.batch_size_;
+ if (!batch_size_.has_value()) {
+ batch_size_ = db.batch_size_;
+ }
return status::Ok();
}
@@ -725,7 +727,8 @@ class SqliteConnection : public
driver::Connection<SqliteConnection> {
Result<driver::Option> GetOption(std::string_view key) override {
if (key == kStatementOptionBatchRows) {
- return driver::Option(static_cast<int64_t>(batch_size_));
+ return driver::Option(
+ static_cast<int64_t>(batch_size_.value_or(kDefaultBatchSize)));
}
return Base::GetOption(key);
}
@@ -783,6 +786,23 @@ class SqliteConnection : public
driver::Connection<SqliteConnection> {
return status::NotImplemented(
"this driver build does not support extension loading");
#endif
+ } else if (key == kStatementOptionBatchRows) {
+ if (lifecycle_state_ != driver::LifecycleState::kUninitialized) {
+ return status::fmt::InvalidState(
+ "{} cannot set {} after AdbcConnectionInit, set it directly on the statement "
+ "instead",
+ kErrorPrefix, key);
+ }
+ int64_t batch_size;
+ UNWRAP_RESULT(batch_size, value.AsInt());
+ if (batch_size <= 0 || batch_size > std::numeric_limits<int>::max()) {
+ return status::fmt::InvalidArgument(
+ "{} Invalid statement option value {}={} (value is non-positive or out of "
+ "range of int)",
+ kErrorPrefix, key, value.Format());
+ }
+ batch_size_ = static_cast<int>(batch_size);
+ return status::Ok();
}
return Base::SetOptionImpl(key, value);
}
@@ -811,7 +831,7 @@ class SqliteConnection : public
driver::Connection<SqliteConnection> {
// Temporarily hold the extension path (since the path and entrypoint need
// to be set separately)
std::string extension_path_;
- int batch_size_ = kDefaultBatchSize;
+ std::optional<int> batch_size_;
};
class SqliteStatement : public driver::Statement<SqliteStatement> {
@@ -1153,7 +1173,9 @@ class SqliteStatement : public
driver::Statement<SqliteStatement> {
Status InitImpl(void* parent) {
auto& conn = *reinterpret_cast<SqliteConnection*>(parent);
conn_ = conn.conn();
- batch_size_ = conn.batch_size_;
+ if (conn.batch_size_) {
+ batch_size_ = conn.batch_size_.value();
+ }
return Statement::InitImpl(parent);
}
diff --git a/docs/source/format/connection_profiles.rst
b/docs/source/format/connection_profiles.rst
index c9336b8c9..87d8ac142 100644
--- a/docs/source/format/connection_profiles.rst
+++ b/docs/source/format/connection_profiles.rst
@@ -322,6 +322,13 @@ Example:
AdbcDatabaseInit(&database, &error);
// Result: warehouse = "ANALYTICS_WH"
+.. note:: Options of different types are set separately. For example, if the
+ profile defines an option with an integer value, and the application
+ sets the same option but with a string value, it is
+ implementation-defined as to which value will take precedence. If
+ the application were to use an integer value instead, then the
+ application value would take precedence as expected.
+
Custom Profile Providers
=========================
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 628f43c36..218006431 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -33,7 +33,11 @@ INGEST_OPTION_MODE_REPLACE: str
INGEST_OPTION_TARGET_TABLE: str
class AdbcConnection(_AdbcHandle):
- def __init__(self, database: "AdbcDatabase", **kwargs: str) -> None: ...
+ def __init__(
+ self,
+ database: "AdbcDatabase",
+ **kwargs: Union[bytes, float, int, str, bool, enum.Enum, pathlib.Path],
+ ) -> None: ...
def cancel(self) -> None: ...
def close(self) -> None: ...
def commit(self) -> None: ...
@@ -69,10 +73,15 @@ class AdbcConnection(_AdbcHandle):
def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ...
def rollback(self) -> None: ...
def set_autocommit(self, enabled: bool) -> None: ...
- def set_options(self, **kwargs: Union[bytes, float, int, str, None]) ->
None: ...
+ def set_options(
+ self, **kwargs: Union[bytes, float, int, str, bool, None]
+ ) -> None: ...
class AdbcDatabase(_AdbcHandle):
- def __init__(self, **kwargs: Union[str, pathlib.Path]) -> None: ...
+ def __init__(
+ self,
+ **kwargs: Union[bytes, float, int, str, bool, enum.Enum, pathlib.Path],
+ ) -> None: ...
def close(self) -> None: ...
def get_option(
self,
@@ -84,7 +93,9 @@ class AdbcDatabase(_AdbcHandle):
def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ...
def get_option_float(self, key: Union[bytes, str]) -> float: ...
def get_option_int(self, key: Union[bytes, str]) -> int: ...
- def set_options(self, **kwargs: Union[bytes, float, int, str, None]) ->
None: ...
+ def set_options(
+ self, **kwargs: Union[bytes, float, int, str, bool, None]
+ ) -> None: ...
class AdbcInfoCode(enum.IntEnum):
DRIVER_ARROW_VERSION = ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index a58900785..95d2c299b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -554,6 +554,9 @@ cdef class AdbcDatabase(_AdbcHandle):
cdef CAdbcStatusCode status
cdef const char* c_key
cdef const char* c_value
+ cdef int64_t c_value_int
+ cdef double c_value_double
+ cdef size_t c_value_len = 0
memset(&self.database, 0, cython.sizeof(CAdbcDatabase))
with nogil:
@@ -582,13 +585,41 @@ cdef class AdbcDatabase(_AdbcHandle):
raise ValueError(f"value for key '{key}' cannot be None")
else:
key = _to_bytes(key, "key")
+ c_key = key
+
if isinstance(value, pathlib.Path):
value = str(value)
- value = _to_bytes(value, "value")
- c_key = key
- c_value = value
- status = AdbcDatabaseSetOption(
- &self.database, c_key, c_value, &c_error)
+ elif isinstance(value, enum.Enum):
+ value = value.value
+
+ if isinstance(value, bool):
+ if value:
+ value = ADBC_OPTION_VALUE_ENABLED
+ else:
+ value = ADBC_OPTION_VALUE_DISABLED
+ value = _to_bytes(value, "option value")
+ c_value = value
+ status = AdbcDatabaseSetOption(
+ &self.database, c_key, c_value, &c_error)
+ elif isinstance(value, int):
+ c_value_int = value
+ status = AdbcDatabaseSetOptionInt(
+ &self.database, c_key, c_value_int, &c_error)
+ elif isinstance(value, float):
+ c_value_double = value
+ status = AdbcDatabaseSetOptionDouble(
+ &self.database, c_key, c_value_double, &c_error)
+ elif isinstance(value, bytes):
+ c_value = value
+ c_value_len = len(value)
+ status = AdbcDatabaseSetOptionBytes(
+ &self.database, c_key, <const uint8_t*> c_value,
+ c_value_len, &c_error)
+ else:
+ value = _to_bytes(value, "value")
+ c_value = value
+ status = AdbcDatabaseSetOption(
+ &self.database, c_key, c_value, &c_error)
check_error(status, &c_error)
# check if we're running in a venv
@@ -816,6 +847,9 @@ cdef class AdbcConnection(_AdbcHandle):
cdef CAdbcStatusCode status
cdef const char* c_key
cdef const char* c_value
+ cdef int64_t c_value_int
+ cdef double c_value_double
+ cdef size_t c_value_len = 0
self.database = database
memset(&self.connection, 0, cython.sizeof(CAdbcConnection))
@@ -825,15 +859,57 @@ cdef class AdbcConnection(_AdbcHandle):
check_error(status, &c_error)
for key, value in kwargs.items():
- key = key.encode("utf-8")
- value = value.encode("utf-8")
+ key = _to_bytes(key, "key")
c_key = key
- c_value = value
- with nogil:
- status = AdbcConnectionSetOption(
- &self.connection, c_key, c_value, &c_error)
- if status != ADBC_STATUS_OK:
- AdbcConnectionRelease(&self.connection, NULL)
+
+ if isinstance(value, pathlib.Path):
+ value = str(value)
+ elif isinstance(value, enum.Enum):
+ value = value.value
+
+ if isinstance(value, bool):
+ if value:
+ value = ADBC_OPTION_VALUE_ENABLED
+ else:
+ value = ADBC_OPTION_VALUE_DISABLED
+ value = _to_bytes(value, "option value")
+ c_value = value
+ with nogil:
+ status = AdbcConnectionSetOption(
+ &self.connection, c_key, c_value, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
+ elif isinstance(value, int):  # elif: keeps a bool (now bytes) from being set again below
+ c_value_int = value
+ with nogil:
+ status = AdbcConnectionSetOptionInt(
+ &self.connection, c_key, c_value_int, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
+ elif isinstance(value, float):
+ c_value_double = value
+ with nogil:
+ status = AdbcConnectionSetOptionDouble(
+ &self.connection, c_key, c_value_double, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
+ elif isinstance(value, bytes):
+ c_value = value
+ c_value_len = len(value)
+ with nogil:
+ status = AdbcConnectionSetOptionBytes(
+ &self.connection, c_key, <const uint8_t*> c_value,
+ c_value_len, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
+ else:
+ value = _to_bytes(value, "value")
+ c_value = value
+ with nogil:
+ status = AdbcConnectionSetOption(
+ &self.connection, c_key, c_value, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
check_error(status, &c_error)
with nogil:
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py
b/python/adbc_driver_manager/tests/test_lowlevel.py
index dac1955b6..ffd77b292 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+import enum
import pathlib
+import re
import pyarrow
import pytest
@@ -51,7 +53,7 @@ def test_version() -> None:
assert adbc_driver_manager.__version__ # type:ignore
-def test_database_init() -> None:
+def test_database_init_no_option() -> None:
with pytest.raises(
adbc_driver_manager.ProgrammingError,
match=".*Must set 'driver' option.*",
@@ -98,6 +100,81 @@ def test_error_mapping() -> None:
assert exc_info.value.sqlstate == "X0000"
+class ExampleEnum(enum.Enum):
+ BAR = "BAR"
+
+
[email protected]
+def test_database_init(tmp_path) -> None:
+ option = "adbc.sqlite.query.batch_rows"
+ path = (tmp_path / "db.sqlite").as_uri()
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite", uri=path, **{option: 42}
+ ) as db:
+ assert db.get_option_int(option) == 42
+
+ # test different data types
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown database option foo=1.0"),
+ ):
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ **{"foo": 1.0},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown database option foo=(3 bytes)"),
+ ):
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ **{"foo": b"b\0r"},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown database option foo='true'"),
+ ):
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ **{"foo": True},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown database option foo='false'"),
+ ):
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ **{"foo": False},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown database option foo='"),
+ ):
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ **{"foo": pathlib.Path("/tmp")},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown database option foo='BAR'"),
+ ):
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ **{"foo": ExampleEnum.BAR},
+ ):
+ pass
+
+
@pytest.mark.sqlite
def test_database_set_options(sqlite_raw) -> None:
db, _ = sqlite_raw
@@ -114,6 +191,72 @@ def test_database_set_options(sqlite_raw) -> None:
db.set_options(foo=None)
[email protected]
+def test_connection_init() -> None:
+ option = "adbc.sqlite.query.batch_rows"
+ with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
+ with adbc_driver_manager.AdbcConnection(db, **{option: 42}) as conn:
+ assert conn.get_option_int(option) == 42
+
+ # test different data types
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown connection option foo=1.0"),
+ ):
+ with adbc_driver_manager.AdbcConnection(db, **{"foo": 1.0}):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown connection option foo=(3 bytes)"),
+ ):
+ with adbc_driver_manager.AdbcConnection(
+ db,
+ **{"foo": b"b\0r"},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown connection option foo='true'"),
+ ):
+ with adbc_driver_manager.AdbcConnection(
+ db,
+ **{"foo": True},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown connection option foo='false'"),
+ ):
+ with adbc_driver_manager.AdbcConnection(
+ db,
+ **{"foo": False},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown connection option foo='"),
+ ):
+ with adbc_driver_manager.AdbcConnection(
+ db,
+ **{"foo": pathlib.Path("/tmp")},
+ ):
+ pass
+
+ with pytest.raises(
+ adbc_driver_manager.NotSupportedError,
+ match=re.escape("Unknown connection option foo='BAR'"),
+ ):
+ with adbc_driver_manager.AdbcConnection(
+ db,
+ **{"foo": ExampleEnum.BAR},
+ ):
+ pass
+
+
@pytest.mark.sqlite
def test_connection_get_info(sqlite_raw) -> None:
_, conn = sqlite_raw
diff --git a/python/adbc_driver_manager/tests/test_profile.py
b/python/adbc_driver_manager/tests/test_profile.py
index a49f3afe5..f8eda8027 100644
--- a/python/adbc_driver_manager/tests/test_profile.py
+++ b/python/adbc_driver_manager/tests/test_profile.py
@@ -185,17 +185,18 @@ uri = "{contents}"
pass
[email protected](reason="https://github.com/apache/arrow-adbc/issues/4086")
def test_option_override(tmp_path, monkeypatch) -> None:
# Test that the driver is optional
monkeypatch.setenv("ADBC_PROFILE_PATH", str(tmp_path))
+ # NOTE: if we use an int value here, the override won't appear to work,
+ # because the options are set separately and do not override each other.
with (tmp_path / "dev.toml").open("w") as sink:
sink.write("""
profile_version = 1
driver = "adbc_driver_sqlite"
[Options]
-adbc.sqlite.query.batch_rows = 7
+adbc.sqlite.query.batch_rows = "7"
""")
key = "adbc.sqlite.query.batch_rows"