This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new cb016366 fix(go/adbc/pkg): allow ConnectionSetOptions before Init (#789)
cb016366 is described below

commit cb0163663cd927385b76bd8e67b0ad0ae6675501
Author: David Li <[email protected]>
AuthorDate: Wed Jun 14 17:21:47 2023 -0400

    fix(go/adbc/pkg): allow ConnectionSetOptions before Init (#789)
    
    Fixes #713.
---
 go/adbc/pkg/_tmpl/driver.go.tmpl                 | 39 ++++++++++++++++++++++--
 go/adbc/pkg/flightsql/driver.go                  | 39 ++++++++++++++++++++++--
 go/adbc/pkg/panicdummy/driver.go                 | 39 ++++++++++++++++++++++--
 go/adbc/pkg/snowflake/driver.go                  | 39 ++++++++++++++++++++++--
 python/adbc_driver_flightsql/tests/conftest.py   | 36 ++++++++++++++--------
 python/adbc_driver_flightsql/tests/test_dbapi.py | 18 +++++++++++
 6 files changed, 186 insertions(+), 24 deletions(-)

diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index fbc80aab..03a94c02 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -243,7 +243,8 @@ func {{.Prefix}}DatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErr
 }
 
 type cConn struct {
-       cnxn adbc.Connection
+       cnxn     adbc.Connection
+       initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -308,7 +309,22 @@ func {{.Prefix}}ConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.c
        }
        conn := getFromHandle[cConn](cnxn.private_data)
 
-       rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+       if conn.cnxn == nil {
+               // not yet initialized
+               k, v := C.GoString(key), C.GoString(val)
+               if conn.initArgs == nil {
+                       conn.initArgs = map[string]string{}
+               }
+               conn.initArgs[k] = v
+               return C.ADBC_STATUS_OK
+       }
+
+       opts, ok := conn.cnxn.(adbc.PostInitOptions)
+       if !ok {
+               setErr(err, "AdbcConnectionSetOption: not supported post-init")
+               return C.ADBC_STATUS_NOT_IMPLEMENTED
+       }
+       rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
        return C.AdbcStatusCode(rawCode)
 }
 
@@ -336,8 +352,25 @@ func {{.Prefix}}ConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcD
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
-
        conn.cnxn = c
+
+       if len(conn.initArgs) > 0 {
+               // C allow SetOption before Init, Go doesn't allow options to Open so set them now
+               opts, ok := conn.cnxn.(adbc.PostInitOptions)
+               if !ok {
+                       setErr(err, "AdbcConnectionInit: options are not supported")
+                       return C.ADBC_STATUS_NOT_IMPLEMENTED
+               }
+
+               for k, v := range conn.initArgs {
+                       rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+                       if rawCode != adbc.StatusOK {
+                               return C.AdbcStatusCode(rawCode)
+                       }
+               }
+               conn.initArgs = nil
+       }
+
        return C.ADBC_STATUS_OK
 }
 
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 976c28e1..6d5cf75b 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -247,7 +247,8 @@ func FlightSQLDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError
 }
 
 type cConn struct {
-       cnxn adbc.Connection
+       cnxn     adbc.Connection
+       initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -312,7 +313,22 @@ func FlightSQLConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch
        }
        conn := getFromHandle[cConn](cnxn.private_data)
 
-       rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+       if conn.cnxn == nil {
+               // not yet initialized
+               k, v := C.GoString(key), C.GoString(val)
+               if conn.initArgs == nil {
+                       conn.initArgs = map[string]string{}
+               }
+               conn.initArgs[k] = v
+               return C.ADBC_STATUS_OK
+       }
+
+       opts, ok := conn.cnxn.(adbc.PostInitOptions)
+       if !ok {
+               setErr(err, "AdbcConnectionSetOption: not supported post-init")
+               return C.ADBC_STATUS_NOT_IMPLEMENTED
+       }
+       rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
        return C.AdbcStatusCode(rawCode)
 }
 
@@ -340,8 +356,25 @@ func FlightSQLConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcDat
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
-
        conn.cnxn = c
+
+       if len(conn.initArgs) > 0 {
+               // C allow SetOption before Init, Go doesn't allow options to Open so set them now
+               opts, ok := conn.cnxn.(adbc.PostInitOptions)
+               if !ok {
+                       setErr(err, "AdbcConnectionInit: options are not supported")
+                       return C.ADBC_STATUS_NOT_IMPLEMENTED
+               }
+
+               for k, v := range conn.initArgs {
+                       rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+                       if rawCode != adbc.StatusOK {
+                               return C.AdbcStatusCode(rawCode)
+                       }
+               }
+               conn.initArgs = nil
+       }
+
        return C.ADBC_STATUS_OK
 }
 
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 915e70d2..374c3cb8 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -247,7 +247,8 @@ func PanicDummyDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcErro
 }
 
 type cConn struct {
-       cnxn adbc.Connection
+       cnxn     adbc.Connection
+       initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -312,7 +313,22 @@ func PanicDummyConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cc
        }
        conn := getFromHandle[cConn](cnxn.private_data)
 
-       rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+       if conn.cnxn == nil {
+               // not yet initialized
+               k, v := C.GoString(key), C.GoString(val)
+               if conn.initArgs == nil {
+                       conn.initArgs = map[string]string{}
+               }
+               conn.initArgs[k] = v
+               return C.ADBC_STATUS_OK
+       }
+
+       opts, ok := conn.cnxn.(adbc.PostInitOptions)
+       if !ok {
+               setErr(err, "AdbcConnectionSetOption: not supported post-init")
+               return C.ADBC_STATUS_NOT_IMPLEMENTED
+       }
+       rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
        return C.AdbcStatusCode(rawCode)
 }
 
@@ -340,8 +356,25 @@ func PanicDummyConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcDa
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
-
        conn.cnxn = c
+
+       if len(conn.initArgs) > 0 {
+               // C allow SetOption before Init, Go doesn't allow options to Open so set them now
+               opts, ok := conn.cnxn.(adbc.PostInitOptions)
+               if !ok {
+                       setErr(err, "AdbcConnectionInit: options are not supported")
+                       return C.ADBC_STATUS_NOT_IMPLEMENTED
+               }
+
+               for k, v := range conn.initArgs {
+                       rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+                       if rawCode != adbc.StatusOK {
+                               return C.AdbcStatusCode(rawCode)
+                       }
+               }
+               conn.initArgs = nil
+       }
+
        return C.ADBC_STATUS_OK
 }
 
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index 4e7e5f6f..31e2f131 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -247,7 +247,8 @@ func SnowflakeDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError
 }
 
 type cConn struct {
-       cnxn adbc.Connection
+       cnxn     adbc.Connection
+       initArgs map[string]string
 }
 
 func checkConnAlloc(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcError, fname string) bool {
@@ -312,7 +313,22 @@ func SnowflakeConnectionSetOption(cnxn *C.struct_AdbcConnection, key, val *C.cch
        }
        conn := getFromHandle[cConn](cnxn.private_data)
 
-       rawCode := errToAdbcErr(err, conn.cnxn.(adbc.PostInitOptions).SetOption(C.GoString(key), C.GoString(val)))
+       if conn.cnxn == nil {
+               // not yet initialized
+               k, v := C.GoString(key), C.GoString(val)
+               if conn.initArgs == nil {
+                       conn.initArgs = map[string]string{}
+               }
+               conn.initArgs[k] = v
+               return C.ADBC_STATUS_OK
+       }
+
+       opts, ok := conn.cnxn.(adbc.PostInitOptions)
+       if !ok {
+               setErr(err, "AdbcConnectionSetOption: not supported post-init")
+               return C.ADBC_STATUS_NOT_IMPLEMENTED
+       }
+       rawCode := errToAdbcErr(err, opts.SetOption(C.GoString(key), C.GoString(val)))
        return C.AdbcStatusCode(rawCode)
 }
 
@@ -340,8 +356,25 @@ func SnowflakeConnectionInit(cnxn *C.struct_AdbcConnection, db *C.struct_AdbcDat
        if e != nil {
                return C.AdbcStatusCode(errToAdbcErr(err, e))
        }
-
        conn.cnxn = c
+
+       if len(conn.initArgs) > 0 {
+               // C allow SetOption before Init, Go doesn't allow options to Open so set them now
+               opts, ok := conn.cnxn.(adbc.PostInitOptions)
+               if !ok {
+                       setErr(err, "AdbcConnectionInit: options are not supported")
+                       return C.ADBC_STATUS_NOT_IMPLEMENTED
+               }
+
+               for k, v := range conn.initArgs {
+                       rawCode := errToAdbcErr(err, opts.SetOption(k, v))
+                       if rawCode != adbc.StatusOK {
+                               return C.AdbcStatusCode(rawCode)
+                       }
+               }
+               conn.initArgs = nil
+       }
+
        return C.ADBC_STATUS_OK
 }
 
diff --git a/python/adbc_driver_flightsql/tests/conftest.py b/python/adbc_driver_flightsql/tests/conftest.py
index b80f952e..4ca9508d 100644
--- a/python/adbc_driver_flightsql/tests/conftest.py
+++ b/python/adbc_driver_flightsql/tests/conftest.py
@@ -24,23 +24,37 @@ import adbc_driver_flightsql.dbapi
 import adbc_driver_manager
 
 
-@pytest.fixture
-def dremio_uri():
+@pytest.fixture(scope="session")
+def dremio_uri() -> str:
     dremio_uri = os.environ.get("ADBC_DREMIO_FLIGHTSQL_URI")
     if not dremio_uri:
         pytest.skip("Set ADBC_DREMIO_FLIGHTSQL_URI to run tests")
-    yield dremio_uri
+    return dremio_uri
 
 
-@pytest.fixture
-def dremio(dremio_uri):
+@pytest.fixture(scope="session")
+def dremio_user() -> str:
     username = os.environ.get("ADBC_DREMIO_FLIGHTSQL_USER")
+    if not username:
+        pytest.skip("Set ADBC_DREMIO_FLIGHTSQL_USER to run tests")
+    return username
+
+
+@pytest.fixture(scope="session")
+def dremio_pass() -> str:
     password = os.environ.get("ADBC_DREMIO_FLIGHTSQL_PASS")
+    if not password:
+        pytest.skip("Set ADBC_DREMIO_FLIGHTSQL_PASS to run tests")
+    return password
+
+
+@pytest.fixture
+def dremio(dremio_uri, dremio_user, dremio_pass):
     with adbc_driver_flightsql.connect(
         dremio_uri,
         db_kwargs={
-            adbc_driver_manager.DatabaseOptions.USERNAME.value: username,
-            adbc_driver_manager.DatabaseOptions.PASSWORD.value: password,
+            adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
+            adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
         },
     ) as db:
         with adbc_driver_manager.AdbcConnection(db) as conn:
@@ -48,14 +62,12 @@ def dremio(dremio_uri):
 
 
 @pytest.fixture
-def dremio_dbapi(dremio_uri):
-    username = os.environ.get("ADBC_DREMIO_FLIGHTSQL_USER")
-    password = os.environ.get("ADBC_DREMIO_FLIGHTSQL_PASS")
+def dremio_dbapi(dremio_uri, dremio_user, dremio_pass):
     with adbc_driver_flightsql.dbapi.connect(
         dremio_uri,
         db_kwargs={
-            adbc_driver_manager.DatabaseOptions.USERNAME.value: username,
-            adbc_driver_manager.DatabaseOptions.PASSWORD.value: password,
+            adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
+            adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
         },
     ) as conn:
         yield conn
diff --git a/python/adbc_driver_flightsql/tests/test_dbapi.py b/python/adbc_driver_flightsql/tests/test_dbapi.py
index 115cd5a4..cf72ab0e 100644
--- a/python/adbc_driver_flightsql/tests/test_dbapi.py
+++ b/python/adbc_driver_flightsql/tests/test_dbapi.py
@@ -17,6 +17,9 @@
 
 import pyarrow
 
+import adbc_driver_flightsql.dbapi
+import adbc_driver_manager
+
 
 def test_query_trivial(dremio_dbapi):
     with dremio_dbapi.cursor() as cur:
@@ -32,3 +35,18 @@ def test_query_partitioned(dremio_dbapi):
 
         cur.adbc_read_partition(partitions[0])
         assert cur.fetchone() == (1,)
+
+
+def test_set_options(dremio_uri, dremio_user, dremio_pass):
+    # Regression test for apache/arrow-adbc#713
+    with adbc_driver_flightsql.dbapi.connect(
+        dremio_uri,
+        db_kwargs={
+            adbc_driver_manager.DatabaseOptions.USERNAME.value: dremio_user,
+            adbc_driver_manager.DatabaseOptions.PASSWORD.value: dremio_pass,
+        },
+        conn_kwargs={
+            "adbc.flight.sql.rpc.call_header.x-foo": "1",
+        },
+    ):
+        pass

Reply via email to