This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 06308e431 feat(go/adbc/driver/snowflake): implement WithTransporter 
driver option (#2558)
06308e431 is described below

commit 06308e431b6804d69e5ea73dd7910eff21cf8015
Author: Felipe Vianna <[email protected]>
AuthorDate: Thu Feb 27 17:39:14 2025 -0300

    feat(go/adbc/driver/snowflake): implement WithTransporter driver option 
(#2558)
    
    Relates to #2547
    
    ---------
    
    Co-authored-by: Felipe Vianna <[email protected]>
---
 go/adbc/driver/snowflake/driver.go      | 48 +++++++++++++++++++++++++++++++--
 go/adbc/driver/snowflake/driver_test.go | 30 +++++++++++++++++++++
 2 files changed, 76 insertions(+), 2 deletions(-)

diff --git a/go/adbc/driver/snowflake/driver.go 
b/go/adbc/driver/snowflake/driver.go
index b7f39dd92..c2d7e8dba 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -20,6 +20,7 @@ package snowflake
 import (
        "errors"
        "maps"
+       "net/http"
        "runtime/debug"
        "strings"
 
@@ -170,22 +171,58 @@ func quoteTblName(name string) string {
        return "\"" + strings.ReplaceAll(name, "\"", "\"\"") + "\""
 }
 
+type config struct {
+       *gosnowflake.Config // shared with the database being built; each Option mutates it directly
+}
+
+// Option customizes a Snowflake database created by NewDatabaseWithOptions.
+//
+// It covers settings that cannot be expressed through the string options map,
+// such as the underlying HTTP transport (see WithTransporter).
+type Option func(*config) error
+
+// WithTransporter returns an Option that makes the Snowflake connection send its HTTP
+// requests through transporter, allowing callers to intercept requests and responses.
+func WithTransporter(transporter http.RoundTripper) Option {
+       return func(cfg *config) error {
+               cfg.Transporter = transporter
+               return nil
+       }
+}
+
+// Driver is the adbc.Driver returned by NewDriver. Beyond the standard
+// NewDatabase, it can build a database from Option values that the string
+// options map cannot carry.
+type Driver interface {
+       adbc.Driver
+       // NewDatabaseWithOptions is NewDatabase plus extra Option funcs (for example
+       // WithTransporter) applied to the database configuration.
+       NewDatabaseWithOptions(opts map[string]string, optFuncs ...Option) (adbc.Database, error)
+}
+
+// Compile-time check that *driverImpl implements Driver.
+var _ Driver = (*driverImpl)(nil)
+
 type driverImpl struct {
        driverbase.DriverImplBase
 }
 
 // NewDriver creates a new Snowflake driver using the given Arrow allocator.
-func NewDriver(alloc memory.Allocator) adbc.Driver {
+func NewDriver(alloc memory.Allocator) Driver {
        info := driverbase.DefaultDriverInfo("Snowflake")
        if infoVendorVersion != "" {
                if err := info.RegisterInfoCode(adbc.InfoVendorVersion, 
infoVendorVersion); err != nil {
                        panic(err)
                }
        }
-       return driverbase.NewDriver(&driverImpl{DriverImplBase: 
driverbase.NewDriverImplBase(info, alloc)})
+       return &driverImpl{DriverImplBase: driverbase.NewDriverImplBase(info, 
alloc)} // NOTE(review): concrete *driverImpl (no driverbase.NewDriver wrapper) so NewDatabaseWithOptions is reachable; confirm the dropped wrapper added no other behavior
 }
 
 func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, 
error) {
+       return d.NewDatabaseWithOptions(opts) // no Option funcs: only the string options are applied
+}
+
+func (d *driverImpl) NewDatabaseWithOptions(opts map[string]string, optFuncs 
...Option) (adbc.Database, error) {
        opts = maps.Clone(opts)
        db := &databaseImpl{
                DatabaseImplBase: 
driverbase.NewDatabaseImplBase(&d.DriverImplBase),
@@ -195,5 +232,12 @@ func (d *driverImpl) NewDatabase(opts map[string]string) 
(adbc.Database, error)
                return nil, err
        }
 
+       cfg := &config{Config: db.cfg}
+       for _, opt := range optFuncs {
+               if err := opt(cfg); err != nil {
+                       return nil, err
+               }
+       }
+
        return driverbase.NewDatabase(db), nil
 }
diff --git a/go/adbc/driver/snowflake/driver_test.go 
b/go/adbc/driver/snowflake/driver_test.go
index 9c8185b5b..b4b3ba5da 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -30,6 +30,7 @@ import (
        "encoding/pem"
        "fmt"
        "math"
+       "net/http"
        "os"
        "runtime"
        "strconv"
@@ -350,6 +351,46 @@ func (suite *SnowflakeTests) TearDownTest() {
        suite.driver = nil
 }
 
+// customTransport is an http.RoundTripper that records whether it was used
+// and forwards every request to base.
+type customTransport struct {
+       base   *http.Transport
+       called bool
+}
+
+// RoundTrip marks the transport as used, then delegates r to the base transport.
+func (t *customTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+       t.called = true
+       return t.base.RoundTrip(r)
+}
+
+// TestNewDatabaseWithOptions verifies that Option values passed to
+// NewDatabaseWithOptions take effect: the WithTransporter option must route
+// the connection's HTTP traffic through the supplied transport.
+func (suite *SnowflakeTests) TestNewDatabaseWithOptions() {
+       t := suite.T()
+
+       drv := suite.Quirks.SetupDriver(t).(driver.Driver)
+
+       // suite.Run (not t.Run) so suite assertions report against the subtest
+       // and Require's FailNow runs on the subtest's own goroutine.
+       suite.Run("WithTransporter", func() {
+               transport := &customTransport{base: gosnowflake.SnowflakeTransport}
+               db, err := drv.NewDatabaseWithOptions(suite.Quirks.DatabaseOptions(),
+                       driver.WithTransporter(transport))
+               suite.Require().NoError(err)
+               suite.Require().NotNil(db)
+               // Deferred so the database is released even if Open fails, and
+               // only after the connection opened from it has been closed.
+               defer func() { suite.NoError(db.Close()) }()
+
+               cnxn, err := db.Open(suite.ctx)
+               suite.Require().NoError(err)
+               suite.NoError(cnxn.Close())
+               suite.True(transport.called)
+       })
+}
+
 func (suite *SnowflakeTests) TestSqlIngestTimestamp() {
        suite.Require().NoError(suite.Quirks.DropTable(suite.cnxn, 
"bulk_ingest"))
 

Reply via email to