This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new ab9be7935 feat(python/adbc_driver_manager): allow more types in init (#4088)
ab9be7935 is described below

commit ab9be79352bf65f3452effe2871a3cd111c9bbeb
Author: David Li <[email protected]>
AuthorDate: Wed Mar 18 09:29:53 2026 +0900

    feat(python/adbc_driver_manager): allow more types in init (#4088)
    
    - Test overriding options from profiles
    - Allow bool, bytes, int, float option values when
      constructing a database or connection
    - Document option overriding behavior in profiles when
      the option values are of different types.
    
    We could improve the behavior: track all options of all types
    set outside a profile, and skip profile options where they match.
    
    Closes #4086.
---
 c/driver/sqlite/sqlite.cc                          |  30 ++++-
 docs/source/format/connection_profiles.rst         |   7 +
 .../adbc_driver_manager/_lib.pyi                   |  19 ++-
 .../adbc_driver_manager/_lib.pyx                   | 102 +++++++++++++--
 python/adbc_driver_manager/tests/test_lowlevel.py  | 145 ++++++++++++++++++++-
 python/adbc_driver_manager/tests/test_profile.py   |   5 +-
 6 files changed, 284 insertions(+), 24 deletions(-)

diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc
index e91263b18..d970d76a0 100644
--- a/c/driver/sqlite/sqlite.cc
+++ b/c/driver/sqlite/sqlite.cc
@@ -701,7 +701,9 @@ class SqliteConnection : public driver::Connection<SqliteConnection> {
   Status InitImpl(void* parent) {
     auto& db = *reinterpret_cast<SqliteDatabase*>(parent);
     UNWRAP_RESULT(conn_, db.OpenConnection());
-    batch_size_ = db.batch_size_;
+    if (!batch_size_.has_value()) {
+      batch_size_ = db.batch_size_;
+    }
     return status::Ok();
   }
 
@@ -725,7 +727,8 @@ class SqliteConnection : public driver::Connection<SqliteConnection> {
 
   Result<driver::Option> GetOption(std::string_view key) override {
     if (key == kStatementOptionBatchRows) {
-      return driver::Option(static_cast<int64_t>(batch_size_));
+      return driver::Option(
+          static_cast<int64_t>(batch_size_.value_or(kDefaultBatchSize)));
     }
     return Base::GetOption(key);
   }
@@ -783,6 +786,23 @@ class SqliteConnection : public driver::Connection<SqliteConnection> {
       return status::NotImplemented(
           "this driver build does not support extension loading");
 #endif
+    } else if (key == kStatementOptionBatchRows) {
+      if (lifecycle_state_ != driver::LifecycleState::kUninitialized) {
+        return status::fmt::InvalidState(
+            "{} cannot set {} after AdbcConnectionInit, set it directly on the statement "
+            "instead",
+            kErrorPrefix, key);
+      }
+      int64_t batch_size;
+      UNWRAP_RESULT(batch_size, value.AsInt());
+      if (batch_size <= 0 || batch_size > std::numeric_limits<int>::max()) {
+        return status::fmt::InvalidArgument(
+            "{} Invalid statement option value {}={} (value is non-positive or out of "
+            "range of int)",
+            kErrorPrefix, key, value.Format());
+      }
+      batch_size_ = static_cast<int>(batch_size);
+      return status::Ok();
     }
     return Base::SetOptionImpl(key, value);
   }
@@ -811,7 +831,7 @@ class SqliteConnection : public driver::Connection<SqliteConnection> {
   // Temporarily hold the extension path (since the path and entrypoint need
   // to be set separately)
   std::string extension_path_;
-  int batch_size_ = kDefaultBatchSize;
+  std::optional<int> batch_size_;
 };
 
 class SqliteStatement : public driver::Statement<SqliteStatement> {
@@ -1153,7 +1173,9 @@ class SqliteStatement : public driver::Statement<SqliteStatement> {
   Status InitImpl(void* parent) {
     auto& conn = *reinterpret_cast<SqliteConnection*>(parent);
     conn_ = conn.conn();
-    batch_size_ = conn.batch_size_;
+    if (conn.batch_size_) {
+      batch_size_ = conn.batch_size_.value();
+    }
     return Statement::InitImpl(parent);
   }
 
diff --git a/docs/source/format/connection_profiles.rst b/docs/source/format/connection_profiles.rst
index c9336b8c9..87d8ac142 100644
--- a/docs/source/format/connection_profiles.rst
+++ b/docs/source/format/connection_profiles.rst
@@ -322,6 +322,13 @@ Example:
    AdbcDatabaseInit(&database, &error);
    // Result: warehouse = "ANALYTICS_WH"
 
+.. note:: Options of different types are set separately. For example, if the
+          profile defines an option with an integer value, and the application
+          sets the same option but with a string value, it is
+          implementation-defined as to which value will take precedence.  If
+          the application were to use an integer value instead, then the
+          application value would take precedence as expected.
+
 Custom Profile Providers
 =========================
 
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 628f43c36..218006431 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -33,7 +33,11 @@ INGEST_OPTION_MODE_REPLACE: str
 INGEST_OPTION_TARGET_TABLE: str
 
 class AdbcConnection(_AdbcHandle):
-    def __init__(self, database: "AdbcDatabase", **kwargs: str) -> None: ...
+    def __init__(
+        self,
+        database: "AdbcDatabase",
+        **kwargs: Union[bytes, float, int, str, bool, enum.Enum, pathlib.Path],
+    ) -> None: ...
     def cancel(self) -> None: ...
     def close(self) -> None: ...
     def commit(self) -> None: ...
@@ -69,10 +73,15 @@ class AdbcConnection(_AdbcHandle):
     def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ...
     def rollback(self) -> None: ...
     def set_autocommit(self, enabled: bool) -> None: ...
-    def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...
+    def set_options(
+        self, **kwargs: Union[bytes, float, int, str, bool, None]
+    ) -> None: ...
 
 class AdbcDatabase(_AdbcHandle):
-    def __init__(self, **kwargs: Union[str, pathlib.Path]) -> None: ...
+    def __init__(
+        self,
+        **kwargs: Union[bytes, float, int, str, bool, enum.Enum, pathlib.Path],
+    ) -> None: ...
     def close(self) -> None: ...
     def get_option(
         self,
@@ -84,7 +93,9 @@ class AdbcDatabase(_AdbcHandle):
     def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ...
     def get_option_float(self, key: Union[bytes, str]) -> float: ...
     def get_option_int(self, key: Union[bytes, str]) -> int: ...
-    def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...
+    def set_options(
+        self, **kwargs: Union[bytes, float, int, str, bool, None]
+    ) -> None: ...
 
 class AdbcInfoCode(enum.IntEnum):
     DRIVER_ARROW_VERSION = ...
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index a58900785..95d2c299b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -554,6 +554,9 @@ cdef class AdbcDatabase(_AdbcHandle):
         cdef CAdbcStatusCode status
         cdef const char* c_key
         cdef const char* c_value
+        cdef int64_t c_value_int
+        cdef double c_value_double
+        cdef size_t c_value_len = 0
         memset(&self.database, 0, cython.sizeof(CAdbcDatabase))
 
         with nogil:
@@ -582,13 +585,41 @@ cdef class AdbcDatabase(_AdbcHandle):
                 raise ValueError(f"value for key '{key}' cannot be None")
             else:
                 key = _to_bytes(key, "key")
+                c_key = key
+
                 if isinstance(value, pathlib.Path):
                     value = str(value)
-                value = _to_bytes(value, "value")
-                c_key = key
-                c_value = value
-                status = AdbcDatabaseSetOption(
-                    &self.database, c_key, c_value, &c_error)
+                elif isinstance(value, enum.Enum):
+                    value = value.value
+
+                if isinstance(value, bool):
+                    if value:
+                        value = ADBC_OPTION_VALUE_ENABLED
+                    else:
+                        value = ADBC_OPTION_VALUE_DISABLED
+                    value = _to_bytes(value, "option value")
+                    c_value = value
+                    status = AdbcDatabaseSetOption(
+                        &self.database, c_key, c_value, &c_error)
+                elif isinstance(value, int):
+                    c_value_int = value
+                    status = AdbcDatabaseSetOptionInt(
+                        &self.database, c_key, c_value_int, &c_error)
+                elif isinstance(value, float):
+                    c_value_double = value
+                    status = AdbcDatabaseSetOptionDouble(
+                        &self.database, c_key, c_value_double, &c_error)
+                elif isinstance(value, bytes):
+                    c_value = value
+                    c_value_len = len(value)
+                    status = AdbcDatabaseSetOptionBytes(
+                        &self.database, c_key, <const uint8_t*> c_value,
+                        c_value_len, &c_error)
+                else:
+                    value = _to_bytes(value, "value")
+                    c_value = value
+                    status = AdbcDatabaseSetOption(
+                        &self.database, c_key, c_value, &c_error)
             check_error(status, &c_error)
 
         # check if we're running in a venv
@@ -816,6 +847,9 @@ cdef class AdbcConnection(_AdbcHandle):
         cdef CAdbcStatusCode status
         cdef const char* c_key
         cdef const char* c_value
+        cdef int64_t c_value_int
+        cdef double c_value_double
+        cdef size_t c_value_len = 0
 
         self.database = database
         memset(&self.connection, 0, cython.sizeof(CAdbcConnection))
@@ -825,15 +859,57 @@ cdef class AdbcConnection(_AdbcHandle):
         check_error(status, &c_error)
 
         for key, value in kwargs.items():
-            key = key.encode("utf-8")
-            value = value.encode("utf-8")
+            key = _to_bytes(key, "key")
             c_key = key
-            c_value = value
-            with nogil:
-                status = AdbcConnectionSetOption(
-                    &self.connection, c_key, c_value, &c_error)
-                if status != ADBC_STATUS_OK:
-                    AdbcConnectionRelease(&self.connection, NULL)
+
+            if isinstance(value, pathlib.Path):
+                value = str(value)
+            elif isinstance(value, enum.Enum):
+                value = value.value
+
+            if isinstance(value, bool):
+                if value:
+                    value = ADBC_OPTION_VALUE_ENABLED
+                else:
+                    value = ADBC_OPTION_VALUE_DISABLED
+                value = _to_bytes(value, "option value")
+                c_value = value
+                with nogil:
+                    status = AdbcConnectionSetOption(
+                        &self.connection, c_key, c_value, &c_error)
+                    if status != ADBC_STATUS_OK:
+                        AdbcConnectionRelease(&self.connection, NULL)
+            elif isinstance(value, int):
+                c_value_int = value
+                with nogil:
+                    status = AdbcConnectionSetOptionInt(
+                        &self.connection, c_key, c_value_int, &c_error)
+                    if status != ADBC_STATUS_OK:
+                        AdbcConnectionRelease(&self.connection, NULL)
+            elif isinstance(value, float):
+                c_value_double = value
+                with nogil:
+                    status = AdbcConnectionSetOptionDouble(
+                        &self.connection, c_key, c_value_double, &c_error)
+                    if status != ADBC_STATUS_OK:
+                        AdbcConnectionRelease(&self.connection, NULL)
+            elif isinstance(value, bytes):
+                c_value = value
+                c_value_len = len(value)
+                with nogil:
+                    status = AdbcConnectionSetOptionBytes(
+                        &self.connection, c_key, <const uint8_t*> c_value,
+                        c_value_len, &c_error)
+                    if status != ADBC_STATUS_OK:
+                        AdbcConnectionRelease(&self.connection, NULL)
+            else:
+                value = _to_bytes(value, "value")
+                c_value = value
+                with nogil:
+                    status = AdbcConnectionSetOption(
+                        &self.connection, c_key, c_value, &c_error)
+                    if status != ADBC_STATUS_OK:
+                        AdbcConnectionRelease(&self.connection, NULL)
             check_error(status, &c_error)
 
         with nogil:
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index dac1955b6..ffd77b292 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import enum
 import pathlib
+import re
 
 import pyarrow
 import pytest
@@ -51,7 +53,7 @@ def test_version() -> None:
     assert adbc_driver_manager.__version__  # type:ignore
 
 
-def test_database_init() -> None:
+def test_database_init_no_option() -> None:
     with pytest.raises(
         adbc_driver_manager.ProgrammingError,
         match=".*Must set 'driver' option.*",
@@ -98,6 +100,81 @@ def test_error_mapping() -> None:
         assert exc_info.value.sqlstate == "X0000"
 
 
+class ExampleEnum(enum.Enum):
+    BAR = "BAR"
+
+
+@pytest.mark.sqlite
+def test_database_init(tmp_path) -> None:
+    option = "adbc.sqlite.query.batch_rows"
+    path = (tmp_path / "db.sqlite").as_uri()
+    with adbc_driver_manager.AdbcDatabase(
+        driver="adbc_driver_sqlite", uri=path, **{option: 42}
+    ) as db:
+        assert db.get_option_int(option) == 42
+
+    # test different data types
+    with pytest.raises(
+        adbc_driver_manager.NotSupportedError,
+        match=re.escape("Unknown database option foo=1.0"),
+    ):
+        with adbc_driver_manager.AdbcDatabase(
+            driver="adbc_driver_sqlite",
+            **{"foo": 1.0},
+        ):
+            pass
+
+    with pytest.raises(
+        adbc_driver_manager.NotSupportedError,
+        match=re.escape("Unknown database option foo=(3 bytes)"),
+    ):
+        with adbc_driver_manager.AdbcDatabase(
+            driver="adbc_driver_sqlite",
+            **{"foo": b"b\0r"},
+        ):
+            pass
+
+    with pytest.raises(
+        adbc_driver_manager.NotSupportedError,
+        match=re.escape("Unknown database option foo='true'"),
+    ):
+        with adbc_driver_manager.AdbcDatabase(
+            driver="adbc_driver_sqlite",
+            **{"foo": True},
+        ):
+            pass
+
+    with pytest.raises(
+        adbc_driver_manager.NotSupportedError,
+        match=re.escape("Unknown database option foo='false'"),
+    ):
+        with adbc_driver_manager.AdbcDatabase(
+            driver="adbc_driver_sqlite",
+            **{"foo": False},
+        ):
+            pass
+
+    with pytest.raises(
+        adbc_driver_manager.NotSupportedError,
+        match=re.escape("Unknown database option foo='"),
+    ):
+        with adbc_driver_manager.AdbcDatabase(
+            driver="adbc_driver_sqlite",
+            **{"foo": pathlib.Path("/tmp")},
+        ):
+            pass
+
+    with pytest.raises(
+        adbc_driver_manager.NotSupportedError,
+        match=re.escape("Unknown database option foo='BAR'"),
+    ):
+        with adbc_driver_manager.AdbcDatabase(
+            driver="adbc_driver_sqlite",
+            **{"foo": ExampleEnum.BAR},
+        ):
+            pass
+
+
 @pytest.mark.sqlite
 def test_database_set_options(sqlite_raw) -> None:
     db, _ = sqlite_raw
@@ -114,6 +191,72 @@ def test_database_set_options(sqlite_raw) -> None:
         db.set_options(foo=None)
 
 
+@pytest.mark.sqlite
+def test_connection_init() -> None:
+    option = "adbc.sqlite.query.batch_rows"
+    with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
+        with adbc_driver_manager.AdbcConnection(db, **{option: 42}) as conn:
+            assert conn.get_option_int(option) == 42
+
+        # test different data types
+        with pytest.raises(
+            adbc_driver_manager.NotSupportedError,
+            match=re.escape("Unknown connection option foo=1.0"),
+        ):
+            with adbc_driver_manager.AdbcConnection(db, **{"foo": 1.0}):
+                pass
+
+        with pytest.raises(
+            adbc_driver_manager.NotSupportedError,
+            match=re.escape("Unknown connection option foo=(3 bytes)"),
+        ):
+            with adbc_driver_manager.AdbcConnection(
+                db,
+                **{"foo": b"b\0r"},
+            ):
+                pass
+
+        with pytest.raises(
+            adbc_driver_manager.NotSupportedError,
+            match=re.escape("Unknown connection option foo='true'"),
+        ):
+            with adbc_driver_manager.AdbcConnection(
+                db,
+                **{"foo": True},
+            ):
+                pass
+
+        with pytest.raises(
+            adbc_driver_manager.NotSupportedError,
+            match=re.escape("Unknown connection option foo='false'"),
+        ):
+            with adbc_driver_manager.AdbcConnection(
+                db,
+                **{"foo": False},
+            ):
+                pass
+
+        with pytest.raises(
+            adbc_driver_manager.NotSupportedError,
+            match=re.escape("Unknown connection option foo='"),
+        ):
+            with adbc_driver_manager.AdbcConnection(
+                db,
+                **{"foo": pathlib.Path("/tmp")},
+            ):
+                pass
+
+        with pytest.raises(
+            adbc_driver_manager.NotSupportedError,
+            match=re.escape("Unknown connection option foo='BAR'"),
+        ):
+            with adbc_driver_manager.AdbcConnection(
+                db,
+                **{"foo": ExampleEnum.BAR},
+            ):
+                pass
+
+
 @pytest.mark.sqlite
 def test_connection_get_info(sqlite_raw) -> None:
     _, conn = sqlite_raw
diff --git a/python/adbc_driver_manager/tests/test_profile.py b/python/adbc_driver_manager/tests/test_profile.py
index a49f3afe5..f8eda8027 100644
--- a/python/adbc_driver_manager/tests/test_profile.py
+++ b/python/adbc_driver_manager/tests/test_profile.py
@@ -185,17 +185,18 @@ uri = "{contents}"
                     pass
 
 
-@pytest.mark.xfail(reason="https://github.com/apache/arrow-adbc/issues/4086")
 def test_option_override(tmp_path, monkeypatch) -> None:
     # Test that the driver is optional
     monkeypatch.setenv("ADBC_PROFILE_PATH", str(tmp_path))
 
+    # NOTE: if we use an int value here, the override won't appear to work,
+    # because the options are set separately and do not override each other.
     with (tmp_path / "dev.toml").open("w") as sink:
         sink.write("""
 profile_version = 1
 driver = "adbc_driver_sqlite"
 [Options]
-adbc.sqlite.query.batch_rows = 7
+adbc.sqlite.query.batch_rows = "7"
 """)
 
     key = "adbc.sqlite.query.batch_rows"

Reply via email to