This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 796568454 feat(python/adbc_driver_manager): simplify connect (#3537)
796568454 is described below
commit 7965684548603b79eef0ade574943340348569bc
Author: David Li <[email protected]>
AuthorDate: Fri Oct 17 19:35:14 2025 +0900
feat(python/adbc_driver_manager): simplify connect (#3537)
- Allow positional driver argument.
- Allow URI as a toplevel argument given that it is fairly common.
- Infer driver/URI argument if given a URI-like string as the driver
argument.
Closes #3517.
---
c/driver_manager/adbc_driver_manager.cc | 10 +++++-
c/driver_manager/adbc_driver_manager_test.cc | 25 +++++++++++++
go/adbc/drivermgr/adbc_driver_manager.cc | 10 +++++-
.../adbc_driver_manager/dbapi.py | 36 +++++++++++++++----
python/adbc_driver_manager/tests/test_dbapi.py | 41 ++++++++++++++++++++++
5 files changed, 114 insertions(+), 8 deletions(-)
diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc
index 3943fd81b..a8bac46b5 100644
--- a/c/driver_manager/adbc_driver_manager.cc
+++ b/c/driver_manager/adbc_driver_manager.cc
@@ -1618,7 +1618,15 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char*
TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
if (std::strcmp(key, "driver") == 0) {
- args->driver = value;
+ std::string_view v{value};
+ std::string::size_type pos = v.find("://");
+ if (pos != std::string::npos) {
+ std::string_view d = v.substr(0, pos);
+ args->driver = std::string{d};
+ args->options["uri"] = std::string{v};
+ } else {
+ args->driver = value;
+ }
} else if (std::strcmp(key, "entrypoint") == 0) {
args->entrypoint = value;
} else {
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 5ce5a406a..18f5f7489 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -1036,4 +1036,29 @@ TEST_F(DriverManifest, CondaPrefix) {
}
}
+TEST_F(DriverManifest, ImplicitUri) {
+ auto filepath = temp_dir / "postgresql.toml";
+ std::ofstream test_manifest_file(filepath);
+ ASSERT_TRUE(test_manifest_file.is_open());
+ test_manifest_file << R"([Driver]
+shared = "adbc_driver_postgresql")";
+ test_manifest_file.close();
+
+ // Should attempt to load the "postgresql" driver by inferring from the URI
+ std::string uri = "postgresql://a:b@localhost:9999/nonexistent";
+ adbc_validation::Handle<struct AdbcDatabase> database;
+ ASSERT_THAT(AdbcDatabaseNew(&database.value, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcDatabaseSetOption(&database.value, "driver", uri.c_str(), &error),
+ IsOkStatus(&error));
+ std::string search_path = temp_dir.string();
+ ASSERT_THAT(AdbcDriverManagerDatabaseSetAdditionalSearchPathList(
+ &database.value, search_path.data(), &error),
+ IsOkStatus(&error));
+ ASSERT_THAT(AdbcDatabaseInit(&database.value, &error),
+ IsStatus(ADBC_STATUS_IO, &error));
+ ASSERT_THAT(error.message, ::testing::HasSubstr("Failed to connect"));
+
+ ASSERT_TRUE(std::filesystem::remove(filepath));
+}
+
} // namespace adbc
diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc
index 3943fd81b..a8bac46b5 100644
--- a/go/adbc/drivermgr/adbc_driver_manager.cc
+++ b/go/adbc/drivermgr/adbc_driver_manager.cc
@@ -1618,7 +1618,15 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char*
TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
if (std::strcmp(key, "driver") == 0) {
- args->driver = value;
+ std::string_view v{value};
+ std::string::size_type pos = v.find("://");
+ if (pos != std::string::npos) {
+ std::string_view d = v.substr(0, pos);
+ args->driver = std::string{d};
+ args->options["uri"] = std::string{v};
+ } else {
+ args->driver = value;
+ }
} else if (std::strcmp(key, "entrypoint") == 0) {
args->entrypoint = value;
} else {
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index a0d3dfa0b..91072768e 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -182,8 +182,9 @@ else:
def connect(
- *,
driver: Union[str, pathlib.Path],
+ uri: Optional[str] = None,
+ *,
entrypoint: Optional[str] = None,
db_kwargs: Optional[Dict[str, Union[str, pathlib.Path]]] = None,
conn_kwargs: Optional[Dict[str, str]] = None,
@@ -195,11 +196,30 @@ def connect(
Parameters
----------
driver
- The driver name. For example, "adbc_driver_sqlite" will
- attempt to load libadbc_driver_sqlite.so on Linux systems,
- libadbc_driver_sqlite.dylib on MacOS, and
- adbc_driver_sqlite.dll on Windows. This may also be a path to
- the library to load.
+ The driver to use. This can be one of several values:
+
+ - A driver name or manifest name.
+
+ For example, "adbc_driver_sqlite" will first attempt to load
+ adbc_driver_sqlite.toml from the various search paths. Then, it
+ will try to load libadbc_driver_sqlite.so on Linux,
+ libadbc_driver_sqlite.dylib on macOS, or adbc_driver_sqlite.dll on
+ Windows. See :doc:`/format/driver_manifests`.
+
+ - A relative or absolute path to a shared library to load.
+
+ - Only a URI, in which case the URI scheme will be assumed to be the
+ driver name and will be loaded as above. This will happen when
+ "://" is detected in the driver name. (It is not assumed that the
+ URI is actually a valid URI.) The driver manager will pass the URI
+ on unchanged, so this is only useful if the driver supports URIs
+ where the scheme happens to be the same as the driver name (so
+ PostgreSQL works, but not SQLite, for example, as SQLite uses
+ ``file:`` URIs).
+ uri
+ The "uri" parameter to the database (if applicable). This is
+ equivalent to passing it in ``db_kwargs`` but is slightly cleaner.
+ If given, takes precedence over any value in ``db_kwargs``.
entrypoint
The driver-specific entrypoint, if different than the default.
db_kwargs
@@ -218,10 +238,14 @@ def connect(
db_kwargs = dict(db_kwargs or {})
db_kwargs["driver"] = driver
+ if uri:
+ db_kwargs["uri"] = uri
if entrypoint:
db_kwargs["entrypoint"] = entrypoint
if conn_kwargs is None:
conn_kwargs = {}
+ # N.B. treating driver = "postgresql://..." as driver = "postgresql", uri =
+ # "postgresql://..." (unchanged) is handled at the C driver manager layer
try:
db = _lib.AdbcDatabase(**db_kwargs)
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 4497bbbe8..38c836a58 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -567,3 +567,44 @@ def test_dbapi_extensions(sqlite):
with sqlite.cursor() as cur:
assert cur.execute("SELECT 1").fetchall() == [(1,)]
assert cur.execute("SELECT 42").fetchall() == [(42,)]
+
+
+@pytest.mark.sqlite
+def test_connect(tmp_path: pathlib.Path, monkeypatch) -> None:
+ with dbapi.connect(driver="adbc_driver_sqlite") as conn:
+ with conn.cursor() as cur:
+ cur.execute("SELECT 1")
+ assert cur.fetchone() == (1,)
+
+ # https://github.com/apache/arrow-adbc/issues/3517: allow positional
+ # argument
+ with dbapi.connect("adbc_driver_sqlite") as conn:
+ with conn.cursor() as cur:
+ cur.execute("SELECT 1")
+ assert cur.fetchone() == (1,)
+
+ # https://github.com/apache/arrow-adbc/issues/3517: allow URI argument
+ db = tmp_path / "test.db"
+ with dbapi.connect("adbc_driver_sqlite", db.as_uri()) as conn:
+ with conn.cursor() as cur:
+ cur.execute("CREATE TABLE foo (a)")
+ cur.execute("INSERT INTO foo VALUES (1)")
+ conn.commit()
+
+ with dbapi.connect(driver="adbc_driver_sqlite", uri=db.as_uri()) as conn:
+ with conn.cursor() as cur:
+ cur.execute("SELECT * FROM foo")
+ assert cur.fetchone() == (1,)
+
+ monkeypatch.setenv("ADBC_DRIVER_PATH", tmp_path)
+ with (tmp_path / "foobar.toml").open("w") as f:
+ f.write(
+ """
+[Driver]
+shared = "adbc_driver_foobar"
+ """
+ )
+ # Just check that the driver gets detected and loaded (should fail)
+ with pytest.raises(dbapi.ProgrammingError, match="NOT_FOUND"):
+ with dbapi.connect("foobar://localhost:5439"):
+ pass