This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 796568454 feat(python/adbc_driver_manager): simplify connect (#3537)
796568454 is described below

commit 7965684548603b79eef0ade574943340348569bc
Author: David Li <[email protected]>
AuthorDate: Fri Oct 17 19:35:14 2025 +0900

    feat(python/adbc_driver_manager): simplify connect (#3537)
    
    - Allow positional driver argument.
    - Allow URI as a toplevel argument given that it is fairly common.
    - Infer driver/URI argument if given a URI-like string as the driver
    argument.
    
    Closes #3517.
---
 c/driver_manager/adbc_driver_manager.cc            | 10 +++++-
 c/driver_manager/adbc_driver_manager_test.cc       | 25 +++++++++++++
 go/adbc/drivermgr/adbc_driver_manager.cc           | 10 +++++-
 .../adbc_driver_manager/dbapi.py                   | 36 +++++++++++++++----
 python/adbc_driver_manager/tests/test_dbapi.py     | 41 ++++++++++++++++++++++
 5 files changed, 114 insertions(+), 8 deletions(-)

diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc
index 3943fd81b..a8bac46b5 100644
--- a/c/driver_manager/adbc_driver_manager.cc
+++ b/c/driver_manager/adbc_driver_manager.cc
@@ -1618,7 +1618,15 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char*
 
   TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
   if (std::strcmp(key, "driver") == 0) {
-    args->driver = value;
+    std::string_view v{value};
+    std::string::size_type pos = v.find("://");
+    if (pos != std::string::npos) {
+      std::string_view d = v.substr(0, pos);
+      args->driver = std::string{d};
+      args->options["uri"] = std::string{v};
+    } else {
+      args->driver = value;
+    }
   } else if (std::strcmp(key, "entrypoint") == 0) {
     args->entrypoint = value;
   } else {
diff --git a/c/driver_manager/adbc_driver_manager_test.cc b/c/driver_manager/adbc_driver_manager_test.cc
index 5ce5a406a..18f5f7489 100644
--- a/c/driver_manager/adbc_driver_manager_test.cc
+++ b/c/driver_manager/adbc_driver_manager_test.cc
@@ -1036,4 +1036,29 @@ TEST_F(DriverManifest, CondaPrefix) {
   }
 }
 
+TEST_F(DriverManifest, ImplicitUri) {
+  auto filepath = temp_dir / "postgresql.toml";
+  std::ofstream test_manifest_file(filepath);
+  ASSERT_TRUE(test_manifest_file.is_open());
+  test_manifest_file << R"([Driver]
+shared = "adbc_driver_postgresql")";
+  test_manifest_file.close();
+
+  // Should attempt to load the "postgresql" driver by inferring from the URI
+  std::string uri = "postgresql://a:b@localhost:9999/nonexistent";
+  adbc_validation::Handle<struct AdbcDatabase> database;
+  ASSERT_THAT(AdbcDatabaseNew(&database.value, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcDatabaseSetOption(&database.value, "driver", uri.c_str(), &error),
+              IsOkStatus(&error));
+  std::string search_path = temp_dir.string();
+  ASSERT_THAT(AdbcDriverManagerDatabaseSetAdditionalSearchPathList(
+                  &database.value, search_path.data(), &error),
+              IsOkStatus(&error));
+  ASSERT_THAT(AdbcDatabaseInit(&database.value, &error),
+              IsStatus(ADBC_STATUS_IO, &error));
+  ASSERT_THAT(error.message, ::testing::HasSubstr("Failed to connect"));
+
+  ASSERT_TRUE(std::filesystem::remove(filepath));
+}
+
 }  // namespace adbc
diff --git a/go/adbc/drivermgr/adbc_driver_manager.cc b/go/adbc/drivermgr/adbc_driver_manager.cc
index 3943fd81b..a8bac46b5 100644
--- a/go/adbc/drivermgr/adbc_driver_manager.cc
+++ b/go/adbc/drivermgr/adbc_driver_manager.cc
@@ -1618,7 +1618,15 @@ AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char*
 
   TempDatabase* args = reinterpret_cast<TempDatabase*>(database->private_data);
   if (std::strcmp(key, "driver") == 0) {
-    args->driver = value;
+    std::string_view v{value};
+    std::string::size_type pos = v.find("://");
+    if (pos != std::string::npos) {
+      std::string_view d = v.substr(0, pos);
+      args->driver = std::string{d};
+      args->options["uri"] = std::string{v};
+    } else {
+      args->driver = value;
+    }
   } else if (std::strcmp(key, "entrypoint") == 0) {
     args->entrypoint = value;
   } else {
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index a0d3dfa0b..91072768e 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -182,8 +182,9 @@ else:
 
 
 def connect(
-    *,
     driver: Union[str, pathlib.Path],
+    uri: Optional[str] = None,
+    *,
     entrypoint: Optional[str] = None,
     db_kwargs: Optional[Dict[str, Union[str, pathlib.Path]]] = None,
     conn_kwargs: Optional[Dict[str, str]] = None,
@@ -195,11 +196,30 @@ def connect(
     Parameters
     ----------
     driver
-        The driver name. For example, "adbc_driver_sqlite" will
-        attempt to load libadbc_driver_sqlite.so on Linux systems,
-        libadbc_driver_sqlite.dylib on MacOS, and
-        adbc_driver_sqlite.dll on Windows. This may also be a path to
-        the library to load.
+        The driver to use.  This can be one of several values:
+
+        - A driver name or manifest name.
+
+          For example, "adbc_driver_sqlite" will first attempt to load
+          adbc_driver_sqlite.toml from the various search paths.  Then, it
+          will try to load libadbc_driver_sqlite.so on Linux,
+          libadbc_driver_sqlite.dylib on macOS, or adbc_driver_sqlite.dll on
+          Windows.  See :doc:`/format/driver_manifests`.
+
+        - A relative or absolute path to a shared library to load.
+
+        - Only a URI, in which case the URI scheme will be assumed to be the
+          driver name and will be loaded as above.  This will happen when
+          "://" is detected in the driver name.  (It is not assumed that the
+          URI is actually a valid URI.)  The driver manager will pass the URI
+          on unchanged, so this is only useful if the driver supports URIs
+          where the scheme happens to be the same as the driver name (so
+          PostgreSQL works, but not SQLite, for example, as SQLite uses
+          ``file:`` URIs).
+    uri
+        The "uri" parameter to the database (if applicable).  This is
+        equivalent to passing it in ``db_kwargs`` but is slightly cleaner.
+        If given, takes precedence over any value in ``db_kwargs``.
     entrypoint
         The driver-specific entrypoint, if different than the default.
     db_kwargs
@@ -218,10 +238,14 @@ def connect(
 
     db_kwargs = dict(db_kwargs or {})
     db_kwargs["driver"] = driver
+    if uri:
+        db_kwargs["uri"] = uri
     if entrypoint:
         db_kwargs["entrypoint"] = entrypoint
     if conn_kwargs is None:
         conn_kwargs = {}
+    # N.B. treating driver = "postgresql://..." as driver = "postgresql",
+    # uri = "postgresql://..." is handled at the C driver manager layer
 
     try:
         db = _lib.AdbcDatabase(**db_kwargs)
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 4497bbbe8..38c836a58 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -567,3 +567,44 @@ def test_dbapi_extensions(sqlite):
     with sqlite.cursor() as cur:
         assert cur.execute("SELECT 1").fetchall() == [(1,)]
         assert cur.execute("SELECT 42").fetchall() == [(42,)]
+
+
+@pytest.mark.sqlite
+def test_connect(tmp_path: pathlib.Path, monkeypatch) -> None:
+    with dbapi.connect(driver="adbc_driver_sqlite") as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT 1")
+            assert cur.fetchone() == (1,)
+
+    # https://github.com/apache/arrow-adbc/issues/3517: allow positional
+    # argument
+    with dbapi.connect("adbc_driver_sqlite") as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT 1")
+            assert cur.fetchone() == (1,)
+
+    # https://github.com/apache/arrow-adbc/issues/3517: allow URI argument
+    db = tmp_path / "test.db"
+    with dbapi.connect("adbc_driver_sqlite", db.as_uri()) as conn:
+        with conn.cursor() as cur:
+            cur.execute("CREATE TABLE foo (a)")
+            cur.execute("INSERT INTO foo VALUES (1)")
+        conn.commit()
+
+    with dbapi.connect(driver="adbc_driver_sqlite", uri=db.as_uri()) as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT * FROM foo")
+            assert cur.fetchone() == (1,)
+
+    monkeypatch.setenv("ADBC_DRIVER_PATH", tmp_path)
+    with (tmp_path / "foobar.toml").open("w") as f:
+        f.write(
+            """
+[Driver]
+shared = "adbc_driver_foobar"
+        """
+        )
+    # Just check that the driver gets detected and loaded (should fail)
+    with pytest.raises(dbapi.ProgrammingError, match="NOT_FOUND"):
+        with dbapi.connect("foobar://localhost:5439"):
+            pass

Reply via email to