This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 7a7a5f68 refactor(go/adbc/driver): add driver framework (#1081)
7a7a5f68 is described below
commit 7a7a5f685ebc0a9d00316ecb0c5da02993cc88cc
Author: David Li <[email protected]>
AuthorDate: Fri Sep 29 12:16:24 2023 -0400
refactor(go/adbc/driver): add driver framework (#1081)
Fixes #996.
---
go/adbc/driver/driverbase/database.go | 147 ++++
go/adbc/driver/driverbase/driver.go | 66 ++
.../flightsql.go => driver/driverbase/error.go} | 23 +-
.../flightsql.go => driver/driverbase/logging.go} | 16 +-
.../driver/flightsql/flightsql_adbc_server_test.go | 2 +-
go/adbc/driver/flightsql/flightsql_adbc_test.go | 6 +-
.../{flightsql_adbc.go => flightsql_connection.go} | 816 +--------------------
go/adbc/driver/flightsql/flightsql_database.go | 491 +++++++++++++
go/adbc/driver/flightsql/flightsql_driver.go | 149 ++++
go/adbc/driver/flightsql/logging.go | 9 -
go/adbc/driver/flightsql/timeouts.go | 223 ++++++
go/adbc/driver/panicdummy/panicdummy_adbc.go | 5 +
go/adbc/driver/snowflake/connection.go | 10 +-
go/adbc/driver/snowflake/driver.go | 400 +---------
go/adbc/driver/snowflake/driver_test.go | 2 +-
.../snowflake/{driver.go => snowflake_database.go} | 215 +-----
go/adbc/pkg/_tmpl/driver.go.tmpl | 2 +-
go/adbc/pkg/flightsql/driver.go | 2 +-
go/adbc/pkg/gen/main.go | 4 +-
go/adbc/pkg/panicdummy/driver.go | 2 +-
go/adbc/pkg/snowflake/driver.go | 2 +-
go/adbc/sqldriver/flightsql/flightsql.go | 4 +-
22 files changed, 1153 insertions(+), 1443 deletions(-)
diff --git a/go/adbc/driver/driverbase/database.go
b/go/adbc/driver/driverbase/database.go
new file mode 100644
index 00000000..51e3712c
--- /dev/null
+++ b/go/adbc/driver/driverbase/database.go
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package driverbase
+
+import (
+ "context"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow/go/v13/arrow/memory"
+ "golang.org/x/exp/slog"
+)
+
+// DatabaseImpl is an interface that drivers implement to provide
+// vendor-specific functionality.
+type DatabaseImpl interface {
+ adbc.GetSetOptions
+ Base() *DatabaseImplBase
+ Open(context.Context) (adbc.Connection, error)
+ SetOptions(map[string]string) error
+}
+
+// DatabaseImplBase is a struct that provides default implementations of the
+// DatabaseImpl interface. It is meant to be used as a composite struct for a
+// driver's DatabaseImpl implementation.
+type DatabaseImplBase struct {
+ Alloc memory.Allocator
+ ErrorHelper ErrorHelper
+ Logger *slog.Logger
+}
+
+// NewDatabaseImplBase instantiates DatabaseImplBase. name is the driver's
+// name and is used to construct error messages. alloc is an Arrow allocator
+// to use.
+func NewDatabaseImplBase(driver *DriverImplBase) DatabaseImplBase {
+ return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper:
driver.ErrorHelper, Logger: nilLogger()}
+}
+
+func (base *DatabaseImplBase) Base() *DatabaseImplBase {
+ return base
+}
+
+func (base *DatabaseImplBase) GetOption(key string) (string, error) {
+ return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) GetOptionBytes(key string) ([]byte, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) GetOptionDouble(key string) (float64, error) {
+ return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) GetOptionInt(key string) (int64, error) {
+ return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) SetOption(key string, val string) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) SetOptionBytes(key string, val []byte) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) SetOptionDouble(key string, val float64) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+}
+
+func (base *DatabaseImplBase) SetOptionInt(key string, val int64) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+}
+
+// database is the implementation of adbc.Database.
+type database struct {
+ impl DatabaseImpl
+}
+
+// NewDatabase wraps a DatabaseImpl to create an adbc.Database.
+func NewDatabase(impl DatabaseImpl) adbc.Database {
+ return &database{
+ impl: impl,
+ }
+}
+
+func (db *database) GetOption(key string) (string, error) {
+ return db.impl.GetOption(key)
+}
+
+func (db *database) GetOptionBytes(key string) ([]byte, error) {
+ return db.impl.GetOptionBytes(key)
+}
+
+func (db *database) GetOptionDouble(key string) (float64, error) {
+ return db.impl.GetOptionDouble(key)
+}
+
+func (db *database) GetOptionInt(key string) (int64, error) {
+ return db.impl.GetOptionInt(key)
+}
+
+func (db *database) SetOption(key string, val string) error {
+ return db.impl.SetOption(key, val)
+}
+
+func (db *database) SetOptionBytes(key string, val []byte) error {
+ return db.impl.SetOptionBytes(key, val)
+}
+
+func (db *database) SetOptionDouble(key string, val float64) error {
+ return db.impl.SetOptionDouble(key, val)
+}
+
+func (db *database) SetOptionInt(key string, val int64) error {
+ return db.impl.SetOptionInt(key, val)
+}
+
+func (db *database) Open(ctx context.Context) (adbc.Connection, error) {
+ return db.impl.Open(ctx)
+}
+
+func (db *database) SetLogger(logger *slog.Logger) {
+ if logger != nil {
+ db.impl.Base().Logger = logger
+ } else {
+ db.impl.Base().Logger = nilLogger()
+ }
+}
+
+func (db *database) SetOptions(opts map[string]string) error {
+ return db.impl.SetOptions(opts)
+}
diff --git a/go/adbc/driver/driverbase/driver.go
b/go/adbc/driver/driverbase/driver.go
new file mode 100644
index 00000000..905965f4
--- /dev/null
+++ b/go/adbc/driver/driverbase/driver.go
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package driverbase provides a framework for implementing ADBC drivers in
+// Go. It intends to reduce boilerplate for common functionality and managing
+// state transitions.
+package driverbase
+
+import (
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow/go/v13/arrow/memory"
+)
+
+// DriverImpl is an interface that drivers implement to provide
+// vendor-specific functionality.
+type DriverImpl interface {
+ Base() *DriverImplBase
+ NewDatabase(opts map[string]string) (adbc.Database, error)
+}
+
+// DatabaseImplBase is a struct that provides default implementations of the
+// DriverImpl interface. It is meant to be used as a composite struct for a
+// driver's DriverImpl implementation.
+type DriverImplBase struct {
+ Alloc memory.Allocator
+ ErrorHelper ErrorHelper
+}
+
+func NewDriverImplBase(name string, alloc memory.Allocator) DriverImplBase {
+ if alloc == nil {
+ alloc = memory.DefaultAllocator
+ }
+ return DriverImplBase{Alloc: alloc, ErrorHelper:
ErrorHelper{DriverName: name}}
+}
+
+func (base *DriverImplBase) Base() *DriverImplBase {
+ return base
+}
+
+// driver is the actual implementation of adbc.Driver.
+type driver struct {
+ impl DriverImpl
+}
+
+// NewDatabase wraps a DriverImpl to create an adbc.Driver.
+func NewDriver(impl DriverImpl) adbc.Driver {
+ return &driver{impl}
+}
+
+func (drv *driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
+ return drv.impl.NewDatabase(opts)
+}
diff --git a/go/adbc/sqldriver/flightsql/flightsql.go
b/go/adbc/driver/driverbase/error.go
similarity index 67%
copy from go/adbc/sqldriver/flightsql/flightsql.go
copy to go/adbc/driver/driverbase/error.go
index f318cbb6..cb60c17e 100644
--- a/go/adbc/sqldriver/flightsql/flightsql.go
+++ b/go/adbc/driver/driverbase/error.go
@@ -15,16 +15,23 @@
// specific language governing permissions and limitations
// under the License.
-package flightsql
+package driverbase
import (
- "database/sql"
- "github.com/apache/arrow-adbc/go/adbc/driver/flightsql"
- "github.com/apache/arrow-adbc/go/adbc/sqldriver"
+ "fmt"
+
+ "github.com/apache/arrow-adbc/go/adbc"
)
-func init() {
- sql.Register("flightsql", sqldriver.Driver{
- Driver: flightsql.Driver{},
- })
+// ErrorHelper helps format errors for ADBC drivers.
+type ErrorHelper struct {
+ DriverName string
+}
+
+func (helper *ErrorHelper) Errorf(code adbc.Status, message string, format
...interface{}) error {
+ msg := fmt.Sprintf(message, format...)
+ return adbc.Error{
+ Code: code,
+ Msg: fmt.Sprintf("[%s] %s", helper.DriverName, msg),
+ }
}
diff --git a/go/adbc/sqldriver/flightsql/flightsql.go
b/go/adbc/driver/driverbase/logging.go
similarity index 78%
copy from go/adbc/sqldriver/flightsql/flightsql.go
copy to go/adbc/driver/driverbase/logging.go
index f318cbb6..2660b4a9 100644
--- a/go/adbc/sqldriver/flightsql/flightsql.go
+++ b/go/adbc/driver/driverbase/logging.go
@@ -15,16 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-package flightsql
+package driverbase
import (
- "database/sql"
- "github.com/apache/arrow-adbc/go/adbc/driver/flightsql"
- "github.com/apache/arrow-adbc/go/adbc/sqldriver"
+ "os"
+
+ "golang.org/x/exp/slog"
)
-func init() {
- sql.Register("flightsql", sqldriver.Driver{
- Driver: flightsql.Driver{},
+func nilLogger() *slog.Logger {
+ h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
+ AddSource: false,
+ Level: slog.LevelError,
})
+ return slog.New(h)
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index e6adaae1..4419ccfb 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -75,7 +75,7 @@ func (suite *ServerBasedTests) DoSetupSuite(srv
flightsql.Server, srvMiddleware
"uri": uri,
}
maps.Copy(args, dbArgs)
- suite.db, err = (driver.Driver{}).NewDatabase(args)
+ suite.db, err =
(driver.NewDriver(memory.DefaultAllocator)).NewDatabase(args)
suite.Require().NoError(err)
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 057fd08c..55ef0f88 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -100,7 +100,7 @@ func (s *FlightSQLQuirks) SetupDriver(t *testing.T)
adbc.Driver {
_ = s.s.Serve()
}()
- return driver.Driver{Alloc: s.mem}
+ return driver.NewDriver(s.mem)
}
func (s *FlightSQLQuirks) TearDownDriver(t *testing.T, _ adbc.Driver) {
@@ -902,7 +902,7 @@ func (suite *ConnectionTests) SetupSuite() {
var err error
suite.ctx = context.Background()
- suite.Driver = driver.Driver{Alloc: suite.alloc}
+ suite.Driver = driver.NewDriver(suite.alloc)
suite.DB, err = suite.Driver.NewDatabase(map[string]string{
adbc.OptionKeyURI: "grpc+tcp://" + suite.server.Addr().String(),
})
@@ -995,7 +995,7 @@ func (suite *DomainSocketTests) SetupSuite() {
}()
suite.ctx = context.Background()
- suite.Driver = driver.Driver{Alloc: suite.alloc}
+ suite.Driver = driver.NewDriver(suite.alloc)
suite.DB, err = suite.Driver.NewDatabase(map[string]string{
adbc.OptionKeyURI: "grpc+unix://" + listenSocket,
})
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go
b/go/adbc/driver/flightsql/flightsql_connection.go
similarity index 57%
rename from go/adbc/driver/flightsql/flightsql_adbc.go
rename to go/adbc/driver/flightsql/flightsql_connection.go
index d4583ac9..cce2d35d 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_connection.go
@@ -15,36 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-// Package flightsql is an ADBC Driver Implementation for Flight SQL
-// natively in go.
-//
-// It can be used to register a driver for database/sql by importing
-// github.com/apache/arrow-adbc/go/adbc/sqldriver and running:
-//
-// sql.Register("flightsql", sqldriver.Driver{flightsql.Driver{}})
-//
-// You can then open a flightsql connection with the database/sql
-// standard package by using:
-//
-// db, err := sql.Open("flightsql", "uri=<flight sql db url>")
-//
-// The URI passed *must* contain a scheme, most likely "grpc+tcp://"
package flightsql
import (
"bytes"
"context"
- "crypto/tls"
- "crypto/x509"
"fmt"
"io"
"math"
- "net/url"
- "runtime/debug"
- "strconv"
"strings"
- "sync"
- "time"
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal"
@@ -54,801 +33,18 @@ import (
"github.com/apache/arrow/go/v13/arrow/flight/flightsql"
"github.com/apache/arrow/go/v13/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v13/arrow/ipc"
- "github.com/apache/arrow/go/v13/arrow/memory"
"github.com/bluele/gcache"
- "golang.org/x/exp/maps"
- "golang.org/x/exp/slog"
"google.golang.org/grpc"
grpccodes "google.golang.org/grpc/codes"
- "google.golang.org/grpc/credentials"
- "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
-const (
- OptionAuthority = "adbc.flight.sql.client_option.authority"
- OptionMTLSCertChain =
"adbc.flight.sql.client_option.mtls_cert_chain"
- OptionMTLSPrivateKey =
"adbc.flight.sql.client_option.mtls_private_key"
- OptionSSLOverrideHostname =
"adbc.flight.sql.client_option.tls_override_hostname"
- OptionSSLSkipVerify =
"adbc.flight.sql.client_option.tls_skip_verify"
- OptionSSLRootCerts =
"adbc.flight.sql.client_option.tls_root_certs"
- OptionWithBlock = "adbc.flight.sql.client_option.with_block"
- OptionWithMaxMsgSize =
"adbc.flight.sql.client_option.with_max_msg_size"
- OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
- OptionTimeoutFetch = "adbc.flight.sql.rpc.timeout_seconds.fetch"
- OptionTimeoutQuery = "adbc.flight.sql.rpc.timeout_seconds.query"
- OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
- OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
- OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
- infoDriverName = "ADBC Flight SQL Driver - Go"
-)
-
-var (
- infoDriverVersion string
- infoDriverArrowVersion string
- infoSupportedCodes []adbc.InfoCode
-)
-
-var errNoTransactionSupport = adbc.Error{
- Msg: "[Flight SQL] server does not report transaction support",
- Code: adbc.StatusNotImplemented,
-}
-
-func init() {
- if info, ok := debug.ReadBuildInfo(); ok {
- for _, dep := range info.Deps {
- switch {
- case dep.Path ==
"github.com/apache/arrow-adbc/go/adbc/driver/flightsql":
- infoDriverVersion = dep.Version
- case strings.HasPrefix(dep.Path,
"github.com/apache/arrow/go/"):
- infoDriverArrowVersion = dep.Version
- }
- }
- }
- // XXX: Deps not populated in tests
- // https://github.com/golang/go/issues/33976
- if infoDriverVersion == "" {
- infoDriverVersion = "(unknown or development build)"
- }
- if infoDriverArrowVersion == "" {
- infoDriverArrowVersion = "(unknown or development build)"
- }
-
- infoSupportedCodes = []adbc.InfoCode{
- adbc.InfoDriverName,
- adbc.InfoDriverVersion,
- adbc.InfoDriverArrowVersion,
- adbc.InfoDriverADBCVersion,
- adbc.InfoVendorName,
- adbc.InfoVendorVersion,
- adbc.InfoVendorArrowVersion,
- }
-}
-
-type Driver struct {
- Alloc memory.Allocator
-}
-
-func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
- opts = maps.Clone(opts)
- uri, ok := opts[adbc.OptionKeyURI]
- if !ok {
- return nil, adbc.Error{
- Msg: "URI required for a FlightSQL DB",
- Code: adbc.StatusInvalidArgument,
- }
- }
- delete(opts, adbc.OptionKeyURI)
-
- db := &database{alloc: d.Alloc, hdrs: make(metadata.MD)}
- if db.alloc == nil {
- db.alloc = memory.DefaultAllocator
- }
-
- var err error
- if db.uri, err = url.Parse(uri); err != nil {
- return nil, adbc.Error{Msg: err.Error(), Code:
adbc.StatusInvalidArgument}
- }
-
- // Do not set WithBlock since it converts some types of connection
- // errors to infinite hangs
- // Use WithMaxMsgSize(16 MiB) since Flight services tend to send large
messages
- db.dialOpts.block = false
- db.dialOpts.maxMsgSize = 16 * 1024 * 1024
-
- db.logger = nilLogger()
- db.options = make(map[string]string)
-
- return db, db.SetOptions(opts)
-}
-
-type dbDialOpts struct {
- opts []grpc.DialOption
- block bool
- maxMsgSize int
- authority string
-}
-
-func (d *dbDialOpts) rebuild() {
- d.opts = []grpc.DialOption{
-
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(d.maxMsgSize),
- grpc.MaxCallSendMsgSize(d.maxMsgSize)),
- grpc.WithUserAgent("ADBC Flight SQL Driver " +
infoDriverVersion),
- }
- if d.block {
- d.opts = append(d.opts, grpc.WithBlock())
- }
- if d.authority != "" {
- d.opts = append(d.opts, grpc.WithAuthority(d.authority))
- }
-}
-
-type database struct {
- uri *url.URL
- creds credentials.TransportCredentials
- user, pass string
- hdrs metadata.MD
- timeout timeoutOption
- dialOpts dbDialOpts
- enableCookies bool
- logger *slog.Logger
- options map[string]string
-
- alloc memory.Allocator
-}
-
-func (d *database) SetLogger(logger *slog.Logger) {
- if logger != nil {
- d.logger = logger
- } else {
- d.logger = nilLogger()
- }
-}
-
-func (d *database) SetOptions(cnOptions map[string]string) error {
- var tlsConfig tls.Config
-
- for k, v := range cnOptions {
- d.options[k] = v
- }
-
- if authority, ok := cnOptions[OptionAuthority]; ok {
- d.dialOpts.authority = authority
- delete(cnOptions, OptionAuthority)
- }
-
- mtlsCert := cnOptions[OptionMTLSCertChain]
- mtlsKey := cnOptions[OptionMTLSPrivateKey]
- switch {
- case mtlsCert != "" && mtlsKey != "":
- cert, err := tls.X509KeyPair([]byte(mtlsCert), []byte(mtlsKey))
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid mTLS certificate:
%#v", err),
- Code: adbc.StatusInvalidArgument,
- }
- }
- tlsConfig.Certificates = []tls.Certificate{cert}
- delete(cnOptions, OptionMTLSCertChain)
- delete(cnOptions, OptionMTLSPrivateKey)
- case mtlsCert != "":
- return adbc.Error{
- Msg: fmt.Sprintf("Must provide both '%s' and '%s',
only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey,
OptionMTLSCertChain),
- Code: adbc.StatusInvalidArgument,
- }
- case mtlsKey != "":
- return adbc.Error{
- Msg: fmt.Sprintf("Must provide both '%s' and '%s',
only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey,
OptionMTLSPrivateKey),
- Code: adbc.StatusInvalidArgument,
- }
- }
-
- if hostname, ok := cnOptions[OptionSSLOverrideHostname]; ok {
- tlsConfig.ServerName = hostname
- delete(cnOptions, OptionSSLOverrideHostname)
- }
-
- if val, ok := cnOptions[OptionSSLSkipVerify]; ok {
- if val == adbc.OptionValueEnabled {
- tlsConfig.InsecureSkipVerify = true
- } else if val == adbc.OptionValueDisabled {
- tlsConfig.InsecureSkipVerify = false
- } else {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for database
option '%s': '%s'", OptionSSLSkipVerify, val),
- Code: adbc.StatusInvalidArgument,
- }
- }
- delete(cnOptions, OptionSSLSkipVerify)
- }
-
- if cert, ok := cnOptions[OptionSSLRootCerts]; ok {
- cp := x509.NewCertPool()
- if !cp.AppendCertsFromPEM([]byte(cert)) {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for database
option '%s': failed to append certificates", OptionSSLRootCerts),
- Code: adbc.StatusInvalidArgument,
- }
- }
- tlsConfig.RootCAs = cp
- delete(cnOptions, OptionSSLRootCerts)
- }
-
- d.creds = credentials.NewTLS(&tlsConfig)
-
- if auth, ok := cnOptions[OptionAuthorizationHeader]; ok {
- d.hdrs.Set("authorization", auth)
- delete(cnOptions, OptionAuthorizationHeader)
- }
-
- if u, ok := cnOptions[adbc.OptionKeyUsername]; ok {
- if d.hdrs.Len() > 0 {
- return adbc.Error{
- Msg: "Authorization header already provided,
do not provide user/pass also",
- Code: adbc.StatusInvalidArgument,
- }
- }
- d.user = u
- delete(cnOptions, adbc.OptionKeyUsername)
- }
-
- if p, ok := cnOptions[adbc.OptionKeyPassword]; ok {
- if d.hdrs.Len() > 0 {
- return adbc.Error{
- Msg: "Authorization header already provided,
do not provide user/pass also",
- Code: adbc.StatusInvalidArgument,
- }
- }
- d.pass = p
- delete(cnOptions, adbc.OptionKeyPassword)
- }
-
- var err error
- if tv, ok := cnOptions[OptionTimeoutFetch]; ok {
- if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv);
err != nil {
- return err
- }
- delete(cnOptions, OptionTimeoutFetch)
- }
-
- if tv, ok := cnOptions[OptionTimeoutQuery]; ok {
- if err = d.timeout.setTimeoutString(OptionTimeoutQuery, tv);
err != nil {
- return err
- }
- delete(cnOptions, OptionTimeoutQuery)
- }
-
- if tv, ok := cnOptions[OptionTimeoutUpdate]; ok {
- if err = d.timeout.setTimeoutString(OptionTimeoutUpdate, tv);
err != nil {
- return err
- }
- delete(cnOptions, OptionTimeoutUpdate)
- }
-
- if val, ok := cnOptions[OptionWithBlock]; ok {
- if val == adbc.OptionValueEnabled {
- d.dialOpts.block = true
- } else if val == adbc.OptionValueDisabled {
- d.dialOpts.block = false
- } else {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for database
option '%s': '%s'", OptionWithBlock, val),
- Code: adbc.StatusInvalidArgument,
- }
- }
- delete(cnOptions, OptionWithBlock)
- }
-
- if val, ok := cnOptions[OptionWithMaxMsgSize]; ok {
- var err error
- var size int
- if size, err = strconv.Atoi(val); err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for database
option '%s': '%s' is not a positive integer", OptionWithMaxMsgSize, val),
- Code: adbc.StatusInvalidArgument,
- }
- } else if size <= 0 {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for database
option '%s': '%s' is not a positive integer", OptionWithMaxMsgSize, val),
- Code: adbc.StatusInvalidArgument,
- }
- }
- d.dialOpts.maxMsgSize = size
- delete(cnOptions, OptionWithMaxMsgSize)
- }
- d.dialOpts.rebuild()
-
- if val, ok := cnOptions[OptionCookieMiddleware]; ok {
- if val == adbc.OptionValueEnabled {
- d.enableCookies = true
- } else if val == adbc.OptionValueDisabled {
- d.enableCookies = false
- } else {
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for database
option '%s': '%s'", OptionCookieMiddleware, val),
- Code: adbc.StatusInvalidArgument,
- }
- }
- delete(cnOptions, OptionCookieMiddleware)
- }
-
- for key, val := range cnOptions {
- if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
- d.hdrs.Append(strings.TrimPrefix(key,
OptionRPCCallHeaderPrefix), val)
- continue
- }
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option
'%s'", key),
- Code: adbc.StatusInvalidArgument,
- }
- }
-
- return nil
-}
-
-func (d *database) GetOption(key string) (string, error) {
- switch key {
- case OptionTimeoutFetch:
- return d.timeout.fetchTimeout.String(), nil
- case OptionTimeoutQuery:
- return d.timeout.queryTimeout.String(), nil
- case OptionTimeoutUpdate:
- return d.timeout.updateTimeout.String(), nil
- }
- if val, ok := d.options[key]; ok {
- return val, nil
- }
- return "", adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) GetOptionBytes(key string) ([]byte, error) {
- return nil, adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) GetOptionInt(key string) (int64, error) {
- switch key {
- case OptionTimeoutFetch:
- fallthrough
- case OptionTimeoutQuery:
- fallthrough
- case OptionTimeoutUpdate:
- val, err := d.GetOptionDouble(key)
- if err != nil {
- return 0, err
- }
- return int64(val), nil
- }
-
- return 0, adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) GetOptionDouble(key string) (float64, error) {
- switch key {
- case OptionTimeoutFetch:
- return d.timeout.fetchTimeout.Seconds(), nil
- case OptionTimeoutQuery:
- return d.timeout.queryTimeout.Seconds(), nil
- case OptionTimeoutUpdate:
- return d.timeout.updateTimeout.Seconds(), nil
- }
-
- return 0, adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) SetOption(key, value string) error {
- // We can't change most options post-init
- switch key {
- case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
- return d.timeout.setTimeoutString(key, value)
- }
- if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
- d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix),
value)
- }
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-func (d *database) SetOptionBytes(key string, value []byte) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-func (d *database) SetOptionInt(key string, value int64) error {
- switch key {
- case OptionTimeoutFetch:
- fallthrough
- case OptionTimeoutQuery:
- fallthrough
- case OptionTimeoutUpdate:
- return d.timeout.setTimeout(key, float64(value))
- }
-
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-func (d *database) SetOptionDouble(key string, value float64) error {
- switch key {
- case OptionTimeoutFetch:
- fallthrough
- case OptionTimeoutQuery:
- fallthrough
- case OptionTimeoutUpdate:
- return d.timeout.setTimeout(key, value)
- }
-
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-type timeoutOption struct {
- grpc.EmptyCallOption
-
- // timeout for DoGet requests
- fetchTimeout time.Duration
- // timeout for GetFlightInfo requests
- queryTimeout time.Duration
- // timeout for DoPut or DoAction requests
- updateTimeout time.Duration
-}
-
-func (t *timeoutOption) setTimeout(key string, value float64) error {
- if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 {
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] invalid timeout option
value %s = %f: timeouts must be non-negative and finite",
- key, value),
- Code: adbc.StatusInvalidArgument,
- }
- }
-
- timeout := time.Duration(value * float64(time.Second))
-
- switch key {
- case OptionTimeoutFetch:
- t.fetchTimeout = timeout
- case OptionTimeoutQuery:
- t.queryTimeout = timeout
- case OptionTimeoutUpdate:
- t.updateTimeout = timeout
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] Unknown timeout option
'%s'", key),
- Code: adbc.StatusNotImplemented,
- }
- }
- return nil
-}
-
-func (t *timeoutOption) setTimeoutString(key string, value string) error {
- timeout, err := strconv.ParseFloat(value, 64)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("[Flight SQL] invalid timeout option
value %s = %s: %s",
- key, value, err.Error()),
- Code: adbc.StatusInvalidArgument,
- }
- }
- return t.setTimeout(key, timeout)
-}
-
-func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration,
bool) {
- for _, opt := range callOptions {
- if to, ok := opt.(timeoutOption); ok {
- var tm time.Duration
- switch {
- case strings.HasSuffix(method, "DoGet"):
- tm = to.fetchTimeout
- case strings.HasSuffix(method, "GetFlightInfo"):
- tm = to.queryTimeout
- case strings.HasSuffix(method, "DoPut") ||
strings.HasSuffix(method, "DoAction"):
- tm = to.updateTimeout
- }
-
- return tm, tm > 0
- }
- }
-
- return 0, false
-}
-
-func unaryTimeoutInterceptor(ctx context.Context, method string, req, reply
any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption)
error {
- if tm, ok := getTimeout(method, opts); ok {
- ctx, cancel := context.WithTimeout(ctx, tm)
- defer cancel()
- return invoker(ctx, method, req, reply, cc, opts...)
- }
-
- return invoker(ctx, method, req, reply, cc, opts...)
-}
-
-type streamEventType int
-
-const (
- receiveEndEvent streamEventType = iota
- errorEvent
-)
-
-type streamEvent struct {
- Type streamEventType
- Err error
-}
-
-type wrappedClientStream struct {
- grpc.ClientStream
-
- desc *grpc.StreamDesc
- events chan streamEvent
- eventsDone chan struct{}
-}
-
-func (w *wrappedClientStream) RecvMsg(m any) error {
- err := w.ClientStream.RecvMsg(m)
-
- switch {
- case err == nil && !w.desc.ServerStreams:
- w.sendStreamEvent(receiveEndEvent, nil)
- case err == io.EOF:
- w.sendStreamEvent(receiveEndEvent, nil)
- case err != nil:
- w.sendStreamEvent(errorEvent, err)
- }
-
- return err
-}
-
-func (w *wrappedClientStream) SendMsg(m any) error {
- err := w.ClientStream.SendMsg(m)
- if err != nil {
- w.sendStreamEvent(errorEvent, err)
- }
- return err
-}
-
-func (w *wrappedClientStream) Header() (metadata.MD, error) {
- md, err := w.ClientStream.Header()
- if err != nil {
- w.sendStreamEvent(errorEvent, err)
- }
- return md, err
-}
-
-func (w *wrappedClientStream) CloseSend() error {
- err := w.ClientStream.CloseSend()
- if err != nil {
- w.sendStreamEvent(errorEvent, err)
- }
- return err
-}
-
-func (w *wrappedClientStream) sendStreamEvent(eventType streamEventType, err
error) {
- select {
- case <-w.eventsDone:
- case w.events <- streamEvent{Type: eventType, Err: err}:
- }
-}
-
-func streamTimeoutInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc
*grpc.ClientConn, method string, streamer grpc.Streamer, opts
...grpc.CallOption) (grpc.ClientStream, error) {
- if tm, ok := getTimeout(method, opts); ok {
- ctx, cancel := context.WithTimeout(ctx, tm)
- s, err := streamer(ctx, desc, cc, method, opts...)
- if err != nil {
- defer cancel()
- return s, err
- }
-
- events, eventsDone := make(chan streamEvent), make(chan
struct{})
- go func() {
- defer close(eventsDone)
- defer cancel()
-
- for {
- select {
- case event := <-events:
- // split by event type in case we want
to add more logging
- // or even adding in some telemetry in
the future.
- // Errors will already be propagated by
the RecvMsg, SendMsg
- // methods.
- switch event.Type {
- case receiveEndEvent:
- return
- case errorEvent:
- return
- }
- case <-ctx.Done():
- return
- }
- }
- }()
-
- stream := &wrappedClientStream{
- ClientStream: s,
- desc: desc,
- events: events,
- eventsDone: eventsDone,
- }
- return stream, nil
- }
-
- return streamer(ctx, desc, cc, method, opts...)
-}
-
-type bearerAuthMiddleware struct {
- mutex sync.RWMutex
- hdrs metadata.MD
-}
-
-func (b *bearerAuthMiddleware) StartCall(ctx context.Context) context.Context {
- md, _ := metadata.FromOutgoingContext(ctx)
- b.mutex.RLock()
- defer b.mutex.RUnlock()
- return metadata.NewOutgoingContext(ctx, metadata.Join(md, b.hdrs))
-}
-
-func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md
metadata.MD) {
- // apache/arrow-adbc#584
- headers := md.Get("authorization")
- if len(headers) > 0 {
- b.mutex.Lock()
- defer b.mutex.Unlock()
- b.hdrs.Set("authorization", headers...)
- }
-}
-
-func getFlightClient(ctx context.Context, loc string, d *database)
(*flightsql.Client, error) {
- authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
- middleware := []flight.ClientMiddleware{
- {
- Unary: makeUnaryLoggingInterceptor(d.logger),
- Stream: makeStreamLoggingInterceptor(d.logger),
- },
- flight.CreateClientMiddleware(authMiddle),
- {
- Unary: unaryTimeoutInterceptor,
- Stream: streamTimeoutInterceptor,
- },
- }
-
- if d.enableCookies {
- middleware = append(middleware,
flight.NewClientCookieMiddleware())
- }
-
- uri, err := url.Parse(loc)
- if err != nil {
- return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s",
loc, err), Code: adbc.StatusInvalidArgument}
- }
- creds := d.creds
-
- target := uri.Host
- if uri.Scheme == "grpc" || uri.Scheme == "grpc+tcp" {
- creds = insecure.NewCredentials()
- } else if uri.Scheme == "grpc+unix" {
- creds = insecure.NewCredentials()
- target = "unix:" + uri.Path
- }
- dialOpts := append(d.dialOpts.opts,
grpc.WithTransportCredentials(creds))
-
- cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
- if err != nil {
- return nil, adbc.Error{
- Msg: err.Error(),
- Code: adbc.StatusIO,
- }
- }
-
- cl.Alloc = d.alloc
- if d.user != "" || d.pass != "" {
- var header, trailer metadata.MD
- ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user,
d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
- if err != nil {
- return nil, adbcFromFlightStatusWithDetails(err,
header, trailer, "AuthenticateBasicToken")
- }
-
- if md, ok := metadata.FromOutgoingContext(ctx); ok {
- // No need to worry about lock here since we are sole
owner
- authMiddle.hdrs.Set("authorization",
md.Get("Authorization")[0])
- }
- }
-
- return cl, nil
-}
-
-type support struct {
- transactions bool
-}
-
-func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
- cl, err := getFlightClient(ctx, d.uri.String(), d)
- if err != nil {
- return nil, err
- }
-
- cache := gcache.New(20).LRU().
- Expiration(5 * time.Minute).
- LoaderFunc(func(loc interface{}) (interface{}, error) {
- uri, ok := loc.(string)
- if !ok {
- return nil, adbc.Error{Msg:
fmt.Sprintf("Location must be a string, got %#v", uri), Code:
adbc.StatusInternal}
- }
-
- cl, err := getFlightClient(context.Background(), uri, d)
- if err != nil {
- return nil, err
- }
-
- cl.Alloc = d.alloc
- return cl, nil
- }).
- EvictedFunc(func(_, client interface{}) {
- conn := client.(*flightsql.Client)
- conn.Close()
- }).Build()
-
- var cnxnSupport support
-
- info, err := cl.GetSqlInfo(ctx,
[]flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, d.timeout)
- // ignore this if it fails
- if err == nil {
- const int32code = 3
-
- for _, endpoint := range info.Endpoint {
- rdr, err := doGet(ctx, cl, endpoint, cache, d.timeout)
- if err != nil {
- continue
- }
- defer rdr.Release()
-
- for rdr.Next() {
- rec := rdr.Record()
- codes := rec.Column(0).(*array.Uint32)
- values := rec.Column(1).(*array.DenseUnion)
- int32Value :=
values.Field(int32code).(*array.Int32)
-
- for i := 0; i < int(rec.NumRows()); i++ {
- switch codes.Value(i) {
- case
uint32(flightsql.SqlInfoFlightSqlServerTransaction):
- if values.TypeCode(i) !=
int32code {
- continue
- }
-
- idx := values.ValueOffset(i)
- if
!int32Value.IsValid(int(idx)) {
- continue
- }
-
- value :=
int32Value.Value(int(idx))
- cnxnSupport.transactions =
- value ==
int32(flightsql.SqlTransactionTransaction) ||
- value ==
int32(flightsql.SqlTransactionSavepoint)
- }
- }
- }
- }
- }
-
- return &cnxn{cl: cl, db: d, clientCache: cache,
- hdrs: make(metadata.MD), timeouts: d.timeout,
- supportInfo: cnxnSupport}, nil
-}
-
type cnxn struct {
cl *flightsql.Client
- db *database
+ db *databaseImpl
clientCache gcache.Cache
hdrs metadata.MD
timeouts timeoutOption
@@ -1290,7 +486,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.Re
func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, dbSchema *string, tableName *string, columnName *string, tableType
[]string) (array.RecordReader, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
- if err := g.Init(c.db.alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
+ if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
return nil, err
}
defer g.Release()
@@ -1335,7 +531,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth
adbc.ObjectDepth, catalog *
// Helper function to read and validate a metadata stream
func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema,
info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) {
// use a default queueSize for the reader
- rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache,
5, opts...)
+ rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache,
5, opts...)
if err != nil {
return nil, adbcFromFlightStatus(err, "DoGet")
}
@@ -1530,7 +726,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog
*string, dbSchema *st
// 3: table_type: utf8 not null
// 4: table_schema: bytes not null
schemaBytes := rec.Column(4).(*array.Binary).Value(i)
- s, err = flight.DeserializeSchema(schemaBytes,
c.db.alloc)
+ s, err = flight.DeserializeSchema(schemaBytes,
c.db.Alloc)
if err != nil {
return nil, adbcFromFlightStatus(err,
"GetTableSchema")
}
@@ -1559,7 +755,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context)
(array.RecordReader, error) {
return nil, adbcFromFlightStatusWithDetails(err, header,
trailer, "GetTableTypes")
}
- return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
+ return newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache, 5)
}
// Commit commits any pending transactions on this connection, it should
@@ -1633,7 +829,7 @@ func (c *cnxn) Rollback(ctx context.Context) error {
// NewStatement initializes a new statement object tied to this connection
func (c *cnxn) NewStatement() (adbc.Statement, error) {
return &statement{
- alloc: c.db.alloc,
+ alloc: c.db.Alloc,
clientCache: c.clientCache,
hdrs: c.hdrs.Copy(),
queueSize: 5,
diff --git a/go/adbc/driver/flightsql/flightsql_database.go
b/go/adbc/driver/flightsql/flightsql_database.go
new file mode 100644
index 00000000..fc1469e2
--- /dev/null
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -0,0 +1,491 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package flightsql
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "net/url"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
+ "github.com/apache/arrow/go/v13/arrow/array"
+ "github.com/apache/arrow/go/v13/arrow/flight"
+ "github.com/apache/arrow/go/v13/arrow/flight/flightsql"
+ "github.com/bluele/gcache"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/metadata"
+)
+
+// dbDialOpts holds the gRPC dial settings configured on a database together
+// with the derived []grpc.DialOption in opts.
+type dbDialOpts struct {
+	// opts is derived from the fields below; call rebuild after changing them.
+	opts       []grpc.DialOption
+	block      bool
+	maxMsgSize int
+	authority  string
+}
+
+// rebuild regenerates d.opts from block, maxMsgSize and authority: the
+// message-size call options and user agent always come first, followed by
+// WithBlock and WithAuthority when those are configured.
+func (d *dbDialOpts) rebuild() {
+	sizeLimits := grpc.WithDefaultCallOptions(
+		grpc.MaxCallRecvMsgSize(d.maxMsgSize),
+		grpc.MaxCallSendMsgSize(d.maxMsgSize))
+	userAgent := grpc.WithUserAgent("ADBC Flight SQL Driver " + infoDriverVersion)
+
+	built := []grpc.DialOption{sizeLimits, userAgent}
+	if d.block {
+		built = append(built, grpc.WithBlock())
+	}
+	if d.authority != "" {
+		built = append(built, grpc.WithAuthority(d.authority))
+	}
+	d.opts = built
+}
+
+// databaseImpl is the Flight SQL implementation of an ADBC database. Shared
+// state (allocator, logger, error helper) comes from the embedded
+// driverbase.DatabaseImplBase; the remaining fields are the settings used by
+// Open and getFlightClient.
+type databaseImpl struct {
+	driverbase.DatabaseImplBase
+
+	// uri is the database location, parsed from adbc.OptionKeyURI.
+	uri *url.URL
+	// creds are the TLS credentials built by SetOptions; getFlightClient
+	// swaps in insecure credentials for grpc, grpc+tcp and grpc+unix URIs.
+	creds credentials.TransportCredentials
+	// user and pass trigger basic-token authentication when either is set.
+	user, pass string
+	// hdrs are sent with every call (authorization and custom call headers).
+	hdrs metadata.MD
+	// timeout holds the per-RPC-kind timeouts.
+	timeout timeoutOption
+	// dialOpts holds the gRPC dial configuration.
+	dialOpts dbDialOpts
+	// enableCookies installs the Flight cookie middleware when true.
+	enableCookies bool
+	// options records every option passed to SetOptions, for GetOption.
+	options map[string]string
+}
+
+// SetOptions applies a batch of database options (NewDatabase passes the
+// options given to it, minus the URI). Every entry is first recorded in
+// d.options so GetOption can echo it back. Recognized keys are then consumed
+// to configure the gRPC authority, TLS/mTLS, the authorization header or
+// username/password, RPC timeouts, blocking dial, maximum message size and
+// cookie middleware. Keys starting with OptionRPCCallHeaderPrefix become
+// call headers; any other key is rejected with StatusInvalidArgument.
+//
+// The caller's map is not modified: keys are consumed from a private copy.
+func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
+	var tlsConfig tls.Config
+
+	// Record every option, and consume recognized keys from a private copy
+	// so the deletes below don't remove entries from the caller's map.
+	remaining := make(map[string]string, len(cnOptions))
+	for k, v := range cnOptions {
+		d.options[k] = v
+		remaining[k] = v
+	}
+	cnOptions = remaining
+
+	if authority, ok := cnOptions[OptionAuthority]; ok {
+		d.dialOpts.authority = authority
+		delete(cnOptions, OptionAuthority)
+	}
+
+	mtlsCert := cnOptions[OptionMTLSCertChain]
+	mtlsKey := cnOptions[OptionMTLSPrivateKey]
+	switch {
+	case mtlsCert != "" && mtlsKey != "":
+		cert, err := tls.X509KeyPair([]byte(mtlsCert), []byte(mtlsKey))
+		if err != nil {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("Invalid mTLS certificate: %#v", err),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		tlsConfig.Certificates = []tls.Certificate{cert}
+	case mtlsCert != "":
+		return adbc.Error{
+			Msg:  fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSCertChain),
+			Code: adbc.StatusInvalidArgument,
+		}
+	case mtlsKey != "":
+		return adbc.Error{
+			Msg:  fmt.Sprintf("Must provide both '%s' and '%s', only provided '%s'", OptionMTLSCertChain, OptionMTLSPrivateKey, OptionMTLSPrivateKey),
+			Code: adbc.StatusInvalidArgument,
+		}
+	}
+	// Both keys are now either consumed or empty. Drop them so an empty
+	// value isn't later reported as an unknown option.
+	delete(cnOptions, OptionMTLSCertChain)
+	delete(cnOptions, OptionMTLSPrivateKey)
+
+	if hostname, ok := cnOptions[OptionSSLOverrideHostname]; ok {
+		tlsConfig.ServerName = hostname
+		delete(cnOptions, OptionSSLOverrideHostname)
+	}
+
+	if val, ok := cnOptions[OptionSSLSkipVerify]; ok {
+		switch val {
+		case adbc.OptionValueEnabled:
+			tlsConfig.InsecureSkipVerify = true
+		case adbc.OptionValueDisabled:
+			tlsConfig.InsecureSkipVerify = false
+		default:
+			return adbc.Error{
+				Msg:  fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionSSLSkipVerify, val),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		delete(cnOptions, OptionSSLSkipVerify)
+	}
+
+	if cert, ok := cnOptions[OptionSSLRootCerts]; ok {
+		cp := x509.NewCertPool()
+		if !cp.AppendCertsFromPEM([]byte(cert)) {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("Invalid value for database option '%s': failed to append certificates", OptionSSLRootCerts),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		tlsConfig.RootCAs = cp
+		delete(cnOptions, OptionSSLRootCerts)
+	}
+
+	d.creds = credentials.NewTLS(&tlsConfig)
+
+	if auth, ok := cnOptions[OptionAuthorizationHeader]; ok {
+		d.hdrs.Set("authorization", auth)
+		delete(cnOptions, OptionAuthorizationHeader)
+	}
+
+	// Username/password and an explicit authorization header are mutually
+	// exclusive.
+	if u, ok := cnOptions[adbc.OptionKeyUsername]; ok {
+		if d.hdrs.Len() > 0 {
+			return adbc.Error{
+				Msg:  "Authorization header already provided, do not provide user/pass also",
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		d.user = u
+		delete(cnOptions, adbc.OptionKeyUsername)
+	}
+
+	if p, ok := cnOptions[adbc.OptionKeyPassword]; ok {
+		if d.hdrs.Len() > 0 {
+			return adbc.Error{
+				Msg:  "Authorization header already provided, do not provide user/pass also",
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		d.pass = p
+		delete(cnOptions, adbc.OptionKeyPassword)
+	}
+
+	// Checked in a fixed order so the first reported error is deterministic.
+	for _, key := range []string{OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate} {
+		if tv, ok := cnOptions[key]; ok {
+			if err := d.timeout.setTimeoutString(key, tv); err != nil {
+				return err
+			}
+			delete(cnOptions, key)
+		}
+	}
+
+	if val, ok := cnOptions[OptionWithBlock]; ok {
+		switch val {
+		case adbc.OptionValueEnabled:
+			d.dialOpts.block = true
+		case adbc.OptionValueDisabled:
+			d.dialOpts.block = false
+		default:
+			return adbc.Error{
+				Msg:  fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionWithBlock, val),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		delete(cnOptions, OptionWithBlock)
+	}
+
+	if val, ok := cnOptions[OptionWithMaxMsgSize]; ok {
+		size, err := strconv.Atoi(val)
+		if err != nil || size <= 0 {
+			return adbc.Error{
+				Msg:  fmt.Sprintf("Invalid value for database option '%s': '%s' is not a positive integer", OptionWithMaxMsgSize, val),
+				Code: adbc.StatusInvalidArgument,
+			}
+		}
+		d.dialOpts.maxMsgSize = size
+		delete(cnOptions, OptionWithMaxMsgSize)
+	}
+	d.dialOpts.rebuild()
+
+	if val, ok := cnOptions[OptionCookieMiddleware]; ok {
+		switch val {
+		case adbc.OptionValueEnabled:
+			d.enableCookies = true
+		case adbc.OptionValueDisabled:
+			d.enableCookies = false
+		default:
+			return d.ErrorHelper.Errorf(adbc.StatusInvalidArgument, "Invalid value for database option '%s': '%s'", OptionCookieMiddleware, val)
+		}
+		delete(cnOptions, OptionCookieMiddleware)
+	}
+
+	// Whatever is left must be a custom call header; anything else is unknown.
+	for key, val := range cnOptions {
+		if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
+			d.hdrs.Append(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), val)
+			continue
+		}
+		return d.ErrorHelper.Errorf(adbc.StatusInvalidArgument, "[Flight SQL] Unknown database option '%s'", key)
+	}
+
+	return nil
+}
+
+func (d *databaseImpl) GetOption(key string) (string, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ return d.timeout.fetchTimeout.String(), nil
+ case OptionTimeoutQuery:
+ return d.timeout.queryTimeout.String(), nil
+ case OptionTimeoutUpdate:
+ return d.timeout.updateTimeout.String(), nil
+ }
+ if val, ok := d.options[key]; ok {
+ return val, nil
+ }
+ return d.DatabaseImplBase.GetOption(key)
+}
+
+func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ fallthrough
+ case OptionTimeoutQuery:
+ fallthrough
+ case OptionTimeoutUpdate:
+ val, err := d.GetOptionDouble(key)
+ if err != nil {
+ return 0, err
+ }
+ return int64(val), nil
+ }
+
+ return d.DatabaseImplBase.GetOptionInt(key)
+}
+
+func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
+ switch key {
+ case OptionTimeoutFetch:
+ return d.timeout.fetchTimeout.Seconds(), nil
+ case OptionTimeoutQuery:
+ return d.timeout.queryTimeout.Seconds(), nil
+ case OptionTimeoutUpdate:
+ return d.timeout.updateTimeout.Seconds(), nil
+ }
+
+ return d.DatabaseImplBase.GetOptionDouble(key)
+}
+
+// SetOption changes a single option after initialization. Only the RPC
+// timeouts and custom call headers (keys prefixed with
+// OptionRPCCallHeaderPrefix) can be changed here; any other key is
+// delegated to driverbase.DatabaseImplBase.
+func (d *databaseImpl) SetOption(key, value string) error {
+	// We can't change most options post-init
+	switch key {
+	case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
+		return d.timeout.setTimeoutString(key, value)
+	}
+	if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
+		d.hdrs.Set(strings.TrimPrefix(key, OptionRPCCallHeaderPrefix), value)
+		// The header has been applied. Stop here: handing the key to the base
+		// implementation could report it as unsupported even though it took
+		// effect.
+		return nil
+	}
+	return d.DatabaseImplBase.SetOption(key, value)
+}
+
+// SetOptionInt sets a timeout option to a whole number of seconds; other
+// keys are delegated to driverbase.DatabaseImplBase.
+func (d *databaseImpl) SetOptionInt(key string, value int64) error {
+	switch key {
+	case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
+		return d.timeout.setTimeout(key, float64(value))
+	default:
+		return d.DatabaseImplBase.SetOptionInt(key, value)
+	}
+}
+
+// SetOptionDouble sets a timeout option to a (fractional) number of seconds;
+// other keys are delegated to driverbase.DatabaseImplBase.
+func (d *databaseImpl) SetOptionDouble(key string, value float64) error {
+	switch key {
+	case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
+		return d.timeout.setTimeout(key, value)
+	default:
+		return d.DatabaseImplBase.SetOptionDouble(key, value)
+	}
+}
+
+// getFlightClient creates a Flight SQL client for loc using the database's
+// settings:
+//   - middleware: logging, bearer-auth and timeout interceptors, plus the
+//     cookie middleware when enabled;
+//   - credentials: the configured TLS credentials, replaced by insecure
+//     credentials for the grpc, grpc+tcp and grpc+unix schemes;
+//   - dial options: the configured ones.
+//
+// If a username or password is set, it authenticates with basic auth and
+// installs the returned bearer token for subsequent calls.
+//
+// Errors: StatusInvalidArgument for an unparseable loc, StatusIO if the
+// client cannot be created, or the translated Flight status if
+// authentication fails (the client is closed in that case).
+func getFlightClient(ctx context.Context, loc string, d *databaseImpl) (*flightsql.Client, error) {
+	authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
+	middleware := []flight.ClientMiddleware{
+		{
+			Unary:  makeUnaryLoggingInterceptor(d.Logger),
+			Stream: makeStreamLoggingInterceptor(d.Logger),
+		},
+		flight.CreateClientMiddleware(authMiddle),
+		{
+			Unary:  unaryTimeoutInterceptor,
+			Stream: streamTimeoutInterceptor,
+		},
+	}
+
+	if d.enableCookies {
+		middleware = append(middleware, flight.NewClientCookieMiddleware())
+	}
+
+	uri, err := url.Parse(loc)
+	if err != nil {
+		return nil, adbc.Error{Msg: fmt.Sprintf("Invalid URI '%s': %s", loc, err), Code: adbc.StatusInvalidArgument}
+	}
+	creds := d.creds
+
+	target := uri.Host
+	if uri.Scheme == "grpc" || uri.Scheme == "grpc+tcp" {
+		creds = insecure.NewCredentials()
+	} else if uri.Scheme == "grpc+unix" {
+		creds = insecure.NewCredentials()
+		target = "unix:" + uri.Path
+	}
+	// Build a fresh slice. Appending straight to d.dialOpts.opts can write
+	// into its spare capacity, which is shared by concurrent calls (e.g. the
+	// connection's client-cache loader). One call's credentials could then
+	// overwrite another's.
+	dialOpts := make([]grpc.DialOption, 0, len(d.dialOpts.opts)+1)
+	dialOpts = append(dialOpts, d.dialOpts.opts...)
+	dialOpts = append(dialOpts, grpc.WithTransportCredentials(creds))
+
+	cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
+	if err != nil {
+		return nil, adbc.Error{
+			Msg:  err.Error(),
+			Code: adbc.StatusIO,
+		}
+	}
+
+	cl.Alloc = d.Alloc
+	if d.user != "" || d.pass != "" {
+		var header, trailer metadata.MD
+		ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
+		if err != nil {
+			// Don't leak the connection when authentication fails.
+			cl.Close()
+			return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken")
+		}
+
+		if md, ok := metadata.FromOutgoingContext(ctx); ok {
+			// Guard the index rather than assume a token was returned.
+			if tokens := md.Get("Authorization"); len(tokens) > 0 {
+				// No need to worry about lock here since we are sole owner
+				authMiddle.hdrs.Set("authorization", tokens[0])
+			}
+		}
+	}
+
+	return cl, nil
+}
+
+// support records optional server capabilities probed when a connection is
+// opened.
+type support struct {
+	// transactions is true when the server reports transaction or savepoint
+	// support via SqlInfoFlightSqlServerTransaction.
+	transactions bool
+}
+
+// Open dials the database URI and returns a new Flight SQL connection.
+//
+// The connection gets an LRU cache (20 entries, 5 minute expiry) of clients
+// for other endpoint locations; evicted clients are closed. Open also probes,
+// best effort, whether the server supports transactions. Any failure or
+// unexpected response during that probe leaves transaction support disabled
+// instead of failing Open.
+func (impl *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
+	cl, err := getFlightClient(ctx, impl.uri.String(), impl)
+	if err != nil {
+		return nil, err
+	}
+
+	cache := gcache.New(20).LRU().
+		Expiration(5 * time.Minute).
+		LoaderFunc(func(loc interface{}) (interface{}, error) {
+			uri, ok := loc.(string)
+			if !ok {
+				// Report the offending key itself; uri is just "" here.
+				return nil, adbc.Error{Msg: fmt.Sprintf("Location must be a string, got %#v", loc), Code: adbc.StatusInternal}
+			}
+
+			// getFlightClient already sets the client's allocator.
+			cl, err := getFlightClient(context.Background(), uri, impl)
+			if err != nil {
+				return nil, err
+			}
+			return cl, nil
+		}).
+		EvictedFunc(func(_, client interface{}) {
+			conn := client.(*flightsql.Client)
+			conn.Close()
+		}).Build()
+
+	var cnxnSupport support
+
+	info, err := cl.GetSqlInfo(ctx, []flightsql.SqlInfo{flightsql.SqlInfoFlightSqlServerTransaction}, impl.timeout)
+	// ignore this if it fails
+	if err == nil {
+		// Union child/type code holding int32 values in the SqlInfo schema.
+		const int32code = 3
+
+		for _, endpoint := range info.Endpoint {
+			// One closure per endpoint, so each reader is released after it
+			// is read rather than when Open returns.
+			func() {
+				rdr, err := doGet(ctx, cl, endpoint, cache, impl.timeout)
+				if err != nil {
+					return
+				}
+				defer rdr.Release()
+
+				for rdr.Next() {
+					rec := rdr.Record()
+					// Checked assertions: a malformed response from the
+					// server should only disable the probe, not panic.
+					if rec.NumCols() < 2 {
+						return
+					}
+					codes, ok := rec.Column(0).(*array.Uint32)
+					if !ok {
+						return
+					}
+					values, ok := rec.Column(1).(*array.DenseUnion)
+					if !ok || values.NumFields() <= int32code {
+						return
+					}
+					int32Value, ok := values.Field(int32code).(*array.Int32)
+					if !ok {
+						return
+					}
+
+					for i := 0; i < int(rec.NumRows()); i++ {
+						if codes.Value(i) != uint32(flightsql.SqlInfoFlightSqlServerTransaction) {
+							continue
+						}
+						if values.TypeCode(i) != int32code {
+							continue
+						}
+						idx := int(values.ValueOffset(i))
+						if !int32Value.IsValid(idx) {
+							continue
+						}
+						value := int32Value.Value(idx)
+						cnxnSupport.transactions =
+							value == int32(flightsql.SqlTransactionTransaction) ||
+								value == int32(flightsql.SqlTransactionSavepoint)
+					}
+				}
+			}()
+		}
+	}
+
+	return &cnxn{cl: cl, db: impl, clientCache: cache,
+		hdrs: make(metadata.MD), timeouts: impl.timeout,
+		supportInfo: cnxnSupport}, nil
+}
+
+// bearerAuthMiddleware attaches its headers to every outgoing call and
+// records any "authorization" header the server sends back. hdrs is guarded
+// by mutex because HeadersReceived can update it while StartCall reads it.
+type bearerAuthMiddleware struct {
+	mutex sync.RWMutex
+	hdrs  metadata.MD
+}
+
+// StartCall returns ctx with the middleware's headers joined onto any
+// outgoing metadata already present.
+func (b *bearerAuthMiddleware) StartCall(ctx context.Context) context.Context {
+	b.mutex.RLock()
+	defer b.mutex.RUnlock()
+	existing, _ := metadata.FromOutgoingContext(ctx)
+	return metadata.NewOutgoingContext(ctx, metadata.Join(existing, b.hdrs))
+}
+
+// HeadersReceived stores the server's "authorization" header values, if
+// any, so later calls send them (apache/arrow-adbc#584).
+func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md metadata.MD) {
+	tokens := md.Get("authorization")
+	if len(tokens) == 0 {
+		return
+	}
+	b.mutex.Lock()
+	defer b.mutex.Unlock()
+	b.hdrs.Set("authorization", tokens...)
+}
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go
b/go/adbc/driver/flightsql/flightsql_driver.go
new file mode 100644
index 00000000..f556dcf6
--- /dev/null
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package flightsql is an ADBC driver implementation for Flight SQL,
+// written natively in Go.
+//
+// It can be used to register a driver for database/sql by importing
+// github.com/apache/arrow-adbc/go/adbc/sqldriver and running:
+//
+// sql.Register("flightsql", sqldriver.Driver{flightsql.NewDriver(memory.DefaultAllocator)})
+//
+// where memory is github.com/apache/arrow/go/v13/arrow/memory.
+//
+// You can then open a flightsql connection with the database/sql
+// standard package by using:
+//
+// db, err := sql.Open("flightsql", "uri=<flight sql db url>")
+//
+// The URI passed *must* contain a scheme, most likely "grpc+tcp://"
+package flightsql
+
+import (
+ "net/url"
+ "runtime/debug"
+ "strings"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
+ "github.com/apache/arrow/go/v13/arrow/memory"
+ "golang.org/x/exp/maps"
+ "google.golang.org/grpc/metadata"
+)
+
+// Database options understood by the Flight SQL driver, in addition to the
+// standard adbc.OptionKey* options.
+const (
+	// OptionAuthority overrides the authority used for the gRPC connection.
+	OptionAuthority = "adbc.flight.sql.client_option.authority"
+
+	// OptionMTLSCertChain and OptionMTLSPrivateKey hold the PEM-encoded client
+	// certificate chain and private key for mutual TLS; both must be given.
+	OptionMTLSCertChain  = "adbc.flight.sql.client_option.mtls_cert_chain"
+	OptionMTLSPrivateKey = "adbc.flight.sql.client_option.mtls_private_key"
+
+	// OptionSSLOverrideHostname sets the server name used to verify the
+	// server's TLS certificate.
+	OptionSSLOverrideHostname = "adbc.flight.sql.client_option.tls_override_hostname"
+
+	// OptionSSLSkipVerify disables TLS certificate verification when set to
+	// adbc.OptionValueEnabled.
+	OptionSSLSkipVerify = "adbc.flight.sql.client_option.tls_skip_verify"
+
+	// OptionSSLRootCerts holds PEM-encoded root certificates used to verify
+	// the server.
+	OptionSSLRootCerts = "adbc.flight.sql.client_option.tls_root_certs"
+
+	// OptionWithBlock makes dialing blocking when set to
+	// adbc.OptionValueEnabled.
+	OptionWithBlock = "adbc.flight.sql.client_option.with_block"
+
+	// OptionWithMaxMsgSize sets the maximum gRPC send/receive message size in
+	// bytes; it must be a positive integer.
+	OptionWithMaxMsgSize = "adbc.flight.sql.client_option.with_max_msg_size"
+
+	// OptionAuthorizationHeader sets the authorization header sent with every
+	// call. It cannot be combined with a username/password.
+	OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
+
+	// OptionTimeoutFetch, OptionTimeoutQuery and OptionTimeoutUpdate set the
+	// timeout, in seconds, for DoGet, GetFlightInfo, and DoPut/DoAction calls
+	// respectively. Zero disables the timeout.
+	OptionTimeoutFetch  = "adbc.flight.sql.rpc.timeout_seconds.fetch"
+	OptionTimeoutQuery  = "adbc.flight.sql.rpc.timeout_seconds.query"
+	OptionTimeoutUpdate = "adbc.flight.sql.rpc.timeout_seconds.update"
+
+	// OptionRPCCallHeaderPrefix prefixes options that add a custom call
+	// header; the remainder of the key is the header name.
+	OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
+
+	// OptionCookieMiddleware enables the Flight cookie middleware when set to
+	// adbc.OptionValueEnabled.
+	OptionCookieMiddleware = "adbc.flight.sql.rpc.with_cookie_middleware"
+
+	infoDriverName = "ADBC Flight SQL Driver - Go"
+)
+
+// Driver metadata, populated by init. infoDriverVersion is also used in the
+// gRPC user agent (see dbDialOpts.rebuild).
+var (
+	infoDriverVersion      string
+	infoDriverArrowVersion string
+	infoSupportedCodes     []adbc.InfoCode
+)
+
+// errNoTransactionSupport reports that the server does not advertise
+// transaction support.
+var errNoTransactionSupport = adbc.Error{
+	Msg:  "[Flight SQL] server does not report transaction support",
+	Code: adbc.StatusNotImplemented,
+}
+
+func init() {
+ if info, ok := debug.ReadBuildInfo(); ok {
+ for _, dep := range info.Deps {
+ switch {
+ case dep.Path ==
"github.com/apache/arrow-adbc/go/adbc/driver/flightsql":
+ infoDriverVersion = dep.Version
+ case strings.HasPrefix(dep.Path,
"github.com/apache/arrow/go/"):
+ infoDriverArrowVersion = dep.Version
+ }
+ }
+ }
+ // XXX: Deps not populated in tests
+ // https://github.com/golang/go/issues/33976
+ if infoDriverVersion == "" {
+ infoDriverVersion = "(unknown or development build)"
+ }
+ if infoDriverArrowVersion == "" {
+ infoDriverArrowVersion = "(unknown or development build)"
+ }
+
+ infoSupportedCodes = []adbc.InfoCode{
+ adbc.InfoDriverName,
+ adbc.InfoDriverVersion,
+ adbc.InfoDriverArrowVersion,
+ adbc.InfoDriverADBCVersion,
+ adbc.InfoVendorName,
+ adbc.InfoVendorVersion,
+ adbc.InfoVendorArrowVersion,
+ }
+}
+
+// driverImpl is the Flight SQL driver; shared behavior comes from the
+// embedded driverbase.DriverImplBase.
+type driverImpl struct {
+	driverbase.DriverImplBase
+}
+
+// NewDriver creates a new Flight SQL driver that allocates Arrow memory
+// from alloc.
+func NewDriver(alloc memory.Allocator) adbc.Driver {
+	base := driverbase.NewDriverImplBase("Flight SQL", alloc)
+	return driverbase.NewDriver(&driverImpl{DriverImplBase: base})
+}
+
+func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database,
error) {
+ opts = maps.Clone(opts)
+ uri, ok := opts[adbc.OptionKeyURI]
+ if !ok {
+ return nil, adbc.Error{
+ Msg: "URI required for a FlightSQL DB",
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+ delete(opts, adbc.OptionKeyURI)
+
+ db := &databaseImpl{
+ DatabaseImplBase:
driverbase.NewDatabaseImplBase(&d.DriverImplBase),
+ hdrs: make(metadata.MD),
+ }
+
+ var err error
+ if db.uri, err = url.Parse(uri); err != nil {
+ return nil, adbc.Error{Msg: err.Error(), Code:
adbc.StatusInvalidArgument}
+ }
+
+ // Do not set WithBlock since it converts some types of connection
+ // errors to infinite hangs
+ // Use WithMaxMsgSize(16 MiB) since Flight services tend to send large
messages
+ db.dialOpts.block = false
+ db.dialOpts.maxMsgSize = 16 * 1024 * 1024
+
+ db.options = make(map[string]string)
+
+ if err := db.SetOptions(opts); err != nil {
+ return nil, err
+ }
+ return driverbase.NewDatabase(db), nil
+}
diff --git a/go/adbc/driver/flightsql/logging.go
b/go/adbc/driver/flightsql/logging.go
index 3d0f51c8..6957218b 100644
--- a/go/adbc/driver/flightsql/logging.go
+++ b/go/adbc/driver/flightsql/logging.go
@@ -20,7 +20,6 @@ package flightsql
import (
"context"
"io"
- "os"
"time"
"golang.org/x/exp/maps"
@@ -30,14 +29,6 @@ import (
"google.golang.org/grpc/metadata"
)
-func nilLogger() *slog.Logger {
- h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
- AddSource: false,
- Level: slog.LevelError,
- })
- return slog.New(h)
-}
-
func makeUnaryLoggingInterceptor(logger *slog.Logger)
grpc.UnaryClientInterceptor {
interceptor := func(ctx context.Context, method string, req, reply any,
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
start := time.Now()
diff --git a/go/adbc/driver/flightsql/timeouts.go
b/go/adbc/driver/flightsql/timeouts.go
new file mode 100644
index 00000000..77375268
--- /dev/null
+++ b/go/adbc/driver/flightsql/timeouts.go
@@ -0,0 +1,223 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package flightsql
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "math"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/metadata"
+)
+
+// timeoutOption carries per-RPC-kind timeouts. It embeds
+// grpc.EmptyCallOption so it can travel as a grpc.CallOption, where the
+// timeout interceptors find it via getTimeout. A zero duration means no
+// timeout.
+type timeoutOption struct {
+	grpc.EmptyCallOption
+
+	fetchTimeout  time.Duration // DoGet
+	queryTimeout  time.Duration // GetFlightInfo
+	updateTimeout time.Duration // DoPut and DoAction
+}
+
+// setTimeout sets the timeout selected by key (OptionTimeoutFetch,
+// OptionTimeoutQuery or OptionTimeoutUpdate) to value seconds.
+//
+// Zero disables the timeout. Values too large for a time.Duration are
+// clamped to the maximum duration. Errors: StatusInvalidArgument for NaN,
+// infinite or negative values; StatusNotImplemented for an unknown key.
+func (t *timeoutOption) setTimeout(key string, value float64) error {
+	if math.IsNaN(value) || math.IsInf(value, 0) || value < 0 {
+		return adbc.Error{
+			Msg: fmt.Sprintf("[Flight SQL] invalid timeout option value %s = %f: timeouts must be non-negative and finite",
+				key, value),
+			Code: adbc.StatusInvalidArgument,
+		}
+	}
+
+	// An out-of-range float-to-integer conversion yields an
+	// implementation-dependent result in Go. It could wrap to a negative
+	// duration, which silently disables the timeout, so clamp instead.
+	var timeout time.Duration
+	if nanos := value * float64(time.Second); nanos >= math.MaxInt64 {
+		timeout = time.Duration(math.MaxInt64)
+	} else {
+		timeout = time.Duration(nanos)
+	}
+
+	switch key {
+	case OptionTimeoutFetch:
+		t.fetchTimeout = timeout
+	case OptionTimeoutQuery:
+		t.queryTimeout = timeout
+	case OptionTimeoutUpdate:
+		t.updateTimeout = timeout
+	default:
+		return adbc.Error{
+			Msg:  fmt.Sprintf("[Flight SQL] Unknown timeout option '%s'", key),
+			Code: adbc.StatusNotImplemented,
+		}
+	}
+	return nil
+}
+
+func (t *timeoutOption) setTimeoutString(key string, value string) error {
+ timeout, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return adbc.Error{
+ Msg: fmt.Sprintf("[Flight SQL] invalid timeout option
value %s = %s: %s",
+ key, value, err.Error()),
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+ return t.setTimeout(key, timeout)
+}
+
+// getTimeout picks the timeout for method from the first timeoutOption in
+// callOptions. It uses fetchTimeout for DoGet, queryTimeout for
+// GetFlightInfo, and updateTimeout for DoPut/DoAction. The boolean result
+// is false when there is no timeoutOption, the method matches none of
+// these, or the selected timeout is zero.
+func getTimeout(method string, callOptions []grpc.CallOption) (time.Duration, bool) {
+	for _, opt := range callOptions {
+		to, isTimeout := opt.(timeoutOption)
+		if !isTimeout {
+			continue
+		}
+
+		var chosen time.Duration
+		if strings.HasSuffix(method, "DoGet") {
+			chosen = to.fetchTimeout
+		} else if strings.HasSuffix(method, "GetFlightInfo") {
+			chosen = to.queryTimeout
+		} else if strings.HasSuffix(method, "DoPut") || strings.HasSuffix(method, "DoAction") {
+			chosen = to.updateTimeout
+		}
+		return chosen, chosen > 0
+	}
+	return 0, false
+}
+
+// unaryTimeoutInterceptor bounds a unary RPC by the timeout found in opts
+// (see getTimeout); without one the call is invoked unchanged.
+func unaryTimeoutInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
+	tm, ok := getTimeout(method, opts)
+	if !ok {
+		return invoker(ctx, method, req, reply, cc, opts...)
+	}
+	ctx, cancel := context.WithTimeout(ctx, tm)
+	defer cancel()
+	return invoker(ctx, method, req, reply, cc, opts...)
+}
+
+// streamEventType identifies what a wrappedClientStream observed.
+type streamEventType int
+
+const (
+	// receiveEndEvent: receiving finished (io.EOF, or the single response of
+	// a non-server-streaming call arrived).
+	receiveEndEvent streamEventType = iota
+	// errorEvent: a stream operation returned an error.
+	errorEvent
+)
+
+// streamEvent is sent by a wrappedClientStream to the watcher goroutine
+// started in streamTimeoutInterceptor.
+type streamEvent struct {
+	Type streamEventType
+	Err  error
+}
+
+// wrappedClientStream reports the end of a stream, or any error, to the
+// watcher goroutine started by streamTimeoutInterceptor, which then cancels
+// the stream's timeout context.
+type wrappedClientStream struct {
+	grpc.ClientStream
+
+	desc       *grpc.StreamDesc
+	events     chan streamEvent
+	eventsDone chan struct{}
+}
+
+// RecvMsg forwards to the wrapped stream. It signals receiveEndEvent on
+// io.EOF or after the single response of a non-server-streaming call, and
+// errorEvent on any other error.
+func (w *wrappedClientStream) RecvMsg(m any) error {
+	err := w.ClientStream.RecvMsg(m)
+	if err == nil {
+		if !w.desc.ServerStreams {
+			w.sendStreamEvent(receiveEndEvent, nil)
+		}
+	} else if err == io.EOF {
+		w.sendStreamEvent(receiveEndEvent, nil)
+	} else {
+		w.sendStreamEvent(errorEvent, err)
+	}
+	return err
+}
+
+// SendMsg forwards to the wrapped stream, signalling errorEvent on failure.
+func (w *wrappedClientStream) SendMsg(m any) error {
+	if err := w.ClientStream.SendMsg(m); err != nil {
+		w.sendStreamEvent(errorEvent, err)
+		return err
+	}
+	return nil
+}
+
+// Header forwards to the wrapped stream, signalling errorEvent on failure.
+func (w *wrappedClientStream) Header() (metadata.MD, error) {
+	md, err := w.ClientStream.Header()
+	if err == nil {
+		return md, nil
+	}
+	w.sendStreamEvent(errorEvent, err)
+	return md, err
+}
+
+// CloseSend forwards to the wrapped stream, signalling errorEvent on failure.
+func (w *wrappedClientStream) CloseSend() error {
+	if err := w.ClientStream.CloseSend(); err != nil {
+		w.sendStreamEvent(errorEvent, err)
+		return err
+	}
+	return nil
+}
+
+// sendStreamEvent hands an event to the watcher goroutine. If that goroutine
+// has already exited (eventsDone is closed), the event is dropped.
+func (w *wrappedClientStream) sendStreamEvent(eventType streamEventType, err error) {
+	ev := streamEvent{Type: eventType, Err: err}
+	select {
+	case w.events <- ev:
+	case <-w.eventsDone:
+	}
+}
+
+// streamTimeoutInterceptor bounds a streaming RPC by the timeout found in
+// opts (see getTimeout). The stream outlives this call, so a watcher
+// goroutine cancels the timeout context once the returned stream reports
+// its end or an error, or once the timeout expires.
+func streamTimeoutInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
+	tm, ok := getTimeout(method, opts)
+	if !ok {
+		return streamer(ctx, desc, cc, method, opts...)
+	}
+
+	ctx, cancel := context.WithTimeout(ctx, tm)
+	s, err := streamer(ctx, desc, cc, method, opts...)
+	if err != nil {
+		cancel()
+		return s, err
+	}
+
+	events := make(chan streamEvent)
+	eventsDone := make(chan struct{})
+	go func() {
+		defer close(eventsDone)
+		defer cancel()
+
+		for {
+			select {
+			case ev := <-events:
+				// The stream methods already return errors to the caller;
+				// here we only need to know the stream is finished. Events
+				// are distinguished by type to leave room for future logging
+				// or telemetry.
+				if ev.Type == receiveEndEvent || ev.Type == errorEvent {
+					return
+				}
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	return &wrappedClientStream{
+		ClientStream: s,
+		desc:         desc,
+		events:       events,
+		eventsDone:   eventsDone,
+	}, nil
+}
diff --git a/go/adbc/driver/panicdummy/panicdummy_adbc.go
b/go/adbc/driver/panicdummy/panicdummy_adbc.go
index 57d8b993..2a6e71ac 100644
--- a/go/adbc/driver/panicdummy/panicdummy_adbc.go
+++ b/go/adbc/driver/panicdummy/panicdummy_adbc.go
@@ -44,6 +44,11 @@ type Driver struct {
Alloc memory.Allocator
}
+// NewDriver creates a new PanicDummy driver using the given Arrow allocator.
+func NewDriver(alloc memory.Allocator) adbc.Driver {
+ return Driver{Alloc: alloc}
+}
+
func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
maybePanic("NewDatabase")
return &database{}, nil
diff --git a/go/adbc/driver/snowflake/connection.go
b/go/adbc/driver/snowflake/connection.go
index c8f1dc75..a96f104b 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -51,7 +51,7 @@ type snowflakeConn interface {
type cnxn struct {
cn snowflakeConn
- db *database
+ db *databaseImpl
ctor gosnowflake.Connector
sqldb *sql.DB
@@ -107,7 +107,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.Re
infoCodes = infoSupportedCodes
}
- bldr := array.NewRecordBuilder(c.db.alloc, adbc.GetInfoSchema)
+ bldr := array.NewRecordBuilder(c.db.Alloc, adbc.GetInfoSchema)
defer bldr.Release()
bldr.Reserve(len(infoCodes))
@@ -238,7 +238,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.Re
// earlier).
func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, dbSchema *string, tableName *string, columnName *string, tableType
[]string) (array.RecordReader, error) {
g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
- if err := g.Init(c.db.alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
+ if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
return nil, err
}
defer g.Release()
@@ -903,7 +903,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog
*string, dbSchema *st
// ----------------|--------------
// table_type | utf8 not null
func (c *cnxn) GetTableTypes(_ context.Context) (array.RecordReader, error) {
- bldr := array.NewRecordBuilder(c.db.alloc, adbc.TableTypesSchema)
+ bldr := array.NewRecordBuilder(c.db.Alloc, adbc.TableTypesSchema)
defer bldr.Release()
bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"BASE
TABLE", "TEMPORARY TABLE", "VIEW"}, nil)
@@ -957,7 +957,7 @@ func (c *cnxn) Rollback(_ context.Context) error {
// NewStatement initializes a new statement object tied to this connection
func (c *cnxn) NewStatement() (adbc.Statement, error) {
return &statement{
- alloc: c.db.alloc,
+ alloc: c.db.Alloc,
cnxn: c,
queueSize: defaultStatementQueueSize,
prefetchConcurrency: defaultPrefetchConcurrency,
diff --git a/go/adbc/driver/snowflake/driver.go
b/go/adbc/driver/snowflake/driver.go
index b4f063f7..738dd0e0 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -18,19 +18,12 @@
package snowflake
import (
- "context"
- "crypto/x509"
- "database/sql"
"errors"
- "fmt"
- "net/url"
- "os"
"runtime/debug"
- "strconv"
"strings"
- "time"
"github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
"github.com/apache/arrow/go/v13/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
@@ -181,389 +174,20 @@ func errToAdbcErr(code adbc.Status, err error) error {
}
}
-type Driver struct {
- Alloc memory.Allocator
+// driverImpl is the Snowflake implementation passed to driverbase.NewDriver.
+// It embeds driverbase.DriverImplBase, which NewDriver builds from the
+// driver name ("Snowflake") and the caller's Arrow allocator.
+type driverImpl struct {
+ driverbase.DriverImplBase
 }
-func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
- db := &database{alloc: d.Alloc}
-
- opts = maps.Clone(opts)
- if db.alloc == nil {
- db.alloc = memory.DefaultAllocator
- }
-
- return db, db.SetOptions(opts)
-}
-
-var (
- drv = gosnowflake.SnowflakeDriver{}
- authTypeMap = map[string]gosnowflake.AuthType{
- OptionValueAuthSnowflake: gosnowflake.AuthTypeSnowflake,
- OptionValueAuthOAuth: gosnowflake.AuthTypeOAuth,
- OptionValueAuthExternalBrowser:
gosnowflake.AuthTypeExternalBrowser,
- OptionValueAuthOkta: gosnowflake.AuthTypeOkta,
- OptionValueAuthJwt: gosnowflake.AuthTypeJwt,
- OptionValueAuthUserPassMFA:
gosnowflake.AuthTypeUsernamePasswordMFA,
- }
-)
-
-type database struct {
- cfg *gosnowflake.Config
- alloc memory.Allocator
-}
-
-func (d *database) GetOption(key string) (string, error) {
- switch key {
- case adbc.OptionKeyUsername:
- return d.cfg.User, nil
- case adbc.OptionKeyPassword:
- return d.cfg.Password, nil
- case OptionDatabase:
- return d.cfg.Database, nil
- case OptionSchema:
- return d.cfg.Schema, nil
- case OptionWarehouse:
- return d.cfg.Warehouse, nil
- case OptionRole:
- return d.cfg.Role, nil
- case OptionRegion:
- return d.cfg.Region, nil
- case OptionAccount:
- return d.cfg.Account, nil
- case OptionProtocol:
- return d.cfg.Protocol, nil
- case OptionHost:
- return d.cfg.Host, nil
- case OptionPort:
- return strconv.Itoa(d.cfg.Port), nil
- case OptionAuthType:
- return d.cfg.Authenticator.String(), nil
- case OptionLoginTimeout:
- return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f',
-1, 64), nil
- case OptionRequestTimeout:
- return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f',
-1, 64), nil
- case OptionJwtExpireTimeout:
- return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(),
'f', -1, 64), nil
- case OptionClientTimeout:
- return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f',
-1, 64), nil
- case OptionApplicationName:
- return d.cfg.Application, nil
- case OptionSSLSkipVerify:
- if d.cfg.InsecureMode {
- return adbc.OptionValueEnabled, nil
- }
- return adbc.OptionValueDisabled, nil
- case OptionOCSPFailOpenMode:
- return strconv.FormatUint(uint64(d.cfg.OCSPFailOpen), 10), nil
- case OptionAuthToken:
- return d.cfg.Token, nil
- case OptionAuthOktaUrl:
- return d.cfg.OktaURL.String(), nil
- case OptionKeepSessionAlive:
- if d.cfg.KeepSessionAlive {
- return adbc.OptionValueEnabled, nil
- }
- return adbc.OptionValueDisabled, nil
- case OptionDisableTelemetry:
- if d.cfg.DisableTelemetry {
- return adbc.OptionValueEnabled, nil
- }
- return adbc.OptionValueDisabled, nil
- case OptionClientRequestMFAToken:
- if d.cfg.ClientRequestMfaToken == gosnowflake.ConfigBoolTrue {
- return adbc.OptionValueEnabled, nil
- }
- return adbc.OptionValueDisabled, nil
- case OptionClientStoreTempCred:
- if d.cfg.ClientStoreTemporaryCredential ==
gosnowflake.ConfigBoolTrue {
- return adbc.OptionValueEnabled, nil
- }
- return adbc.OptionValueDisabled, nil
- case OptionLogTracing:
- return d.cfg.Tracing, nil
- default:
- val, ok := d.cfg.Params[key]
- if ok {
- return *val, nil
- }
- }
- return "", adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) GetOptionBytes(key string) ([]byte, error) {
- return nil, adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) GetOptionInt(key string) (int64, error) {
- return 0, adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *database) GetOptionDouble(key string) (float64, error) {
- return 0, adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
+// NewDriver returns a Snowflake adbc.Driver built on the driverbase
+// framework, using alloc for Arrow allocations.
+func NewDriver(alloc memory.Allocator) adbc.Driver {
+	base := driverbase.NewDriverImplBase("Snowflake", alloc)
+	return driverbase.NewDriver(&driverImpl{DriverImplBase: base})
 }
-func (d *database) SetOptions(cnOptions map[string]string) error {
- uri, ok := cnOptions[adbc.OptionKeyURI]
- if ok {
- cfg, err := gosnowflake.ParseDSN(uri)
- if err != nil {
- return errToAdbcErr(adbc.StatusInvalidArgument, err)
- }
-
- d.cfg = cfg
- delete(cnOptions, adbc.OptionKeyURI)
- } else {
- d.cfg = &gosnowflake.Config{
- Params: make(map[string]*string),
- }
- }
-
- var err error
- for k, v := range cnOptions {
- v := v // copy into loop scope
- switch k {
- case adbc.OptionKeyUsername:
- d.cfg.User = v
- case adbc.OptionKeyPassword:
- d.cfg.Password = v
- case OptionDatabase:
- d.cfg.Database = v
- case OptionSchema:
- d.cfg.Schema = v
- case OptionWarehouse:
- d.cfg.Warehouse = v
- case OptionRole:
- d.cfg.Role = v
- case OptionRegion:
- d.cfg.Region = v
- case OptionAccount:
- d.cfg.Account = v
- case OptionProtocol:
- d.cfg.Protocol = v
- case OptionHost:
- d.cfg.Host = v
- case OptionPort:
- d.cfg.Port, err = strconv.Atoi(v)
- if err != nil {
- return adbc.Error{
- Msg: "error encountered parsing Port
option: " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionAuthType:
- d.cfg.Authenticator, ok = authTypeMap[v]
- if !ok {
- return adbc.Error{
- Msg: "invalid option value for " +
OptionAuthType + ": '" + v + "'",
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionLoginTimeout:
- dur, err := time.ParseDuration(v)
- if err != nil {
- return adbc.Error{
- Msg: "could not parse duration for '"
+ OptionLoginTimeout + "': " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
- if dur < 0 {
- dur = -dur
- }
- d.cfg.LoginTimeout = dur
- case OptionRequestTimeout:
- dur, err := time.ParseDuration(v)
- if err != nil {
- return adbc.Error{
- Msg: "could not parse duration for '"
+ OptionRequestTimeout + "': " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
- if dur < 0 {
- dur = -dur
- }
- d.cfg.RequestTimeout = dur
- case OptionJwtExpireTimeout:
- dur, err := time.ParseDuration(v)
- if err != nil {
- return adbc.Error{
- Msg: "could not parse duration for '"
+ OptionJwtExpireTimeout + "': " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
- if dur < 0 {
- dur = -dur
- }
- d.cfg.JWTExpireTimeout = dur
- case OptionClientTimeout:
- dur, err := time.ParseDuration(v)
- if err != nil {
- return adbc.Error{
- Msg: "could not parse duration for '"
+ OptionClientTimeout + "': " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
- if dur < 0 {
- dur = -dur
- }
- d.cfg.ClientTimeout = dur
- case OptionApplicationName:
- d.cfg.Application = v
- case OptionSSLSkipVerify:
- switch v {
- case adbc.OptionValueEnabled:
- d.cfg.InsecureMode = true
- case adbc.OptionValueDisabled:
- d.cfg.InsecureMode = false
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for
database option '%s': '%s'", OptionSSLSkipVerify, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionOCSPFailOpenMode:
- switch v {
- case adbc.OptionValueEnabled:
- d.cfg.OCSPFailOpen =
gosnowflake.OCSPFailOpenTrue
- case adbc.OptionValueDisabled:
- d.cfg.OCSPFailOpen =
gosnowflake.OCSPFailOpenFalse
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for
database option '%s': '%s'", OptionSSLSkipVerify, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionAuthToken:
- d.cfg.Token = v
- case OptionAuthOktaUrl:
- d.cfg.OktaURL, err = url.Parse(v)
- if err != nil {
- return adbc.Error{
- Msg: fmt.Sprintf("error parsing URL
for database option '%s': '%s'", k, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionKeepSessionAlive:
- switch v {
- case adbc.OptionValueEnabled:
- d.cfg.KeepSessionAlive = true
- case adbc.OptionValueDisabled:
- d.cfg.KeepSessionAlive = false
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for
database option '%s': '%s'", OptionSSLSkipVerify, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionDisableTelemetry:
- switch v {
- case adbc.OptionValueEnabled:
- d.cfg.DisableTelemetry = true
- case adbc.OptionValueDisabled:
- d.cfg.DisableTelemetry = false
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for
database option '%s': '%s'", OptionSSLSkipVerify, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionJwtPrivateKey:
- data, err := os.ReadFile(v)
- if err != nil {
- return adbc.Error{
- Msg: "could not read private key file
'" + v + "': " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
-
- d.cfg.PrivateKey, err = x509.ParsePKCS1PrivateKey(data)
- if err != nil {
- return adbc.Error{
- Msg: "failed parsing private key file
'" + v + "': " + err.Error(),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionClientRequestMFAToken:
- switch v {
- case adbc.OptionValueEnabled:
- d.cfg.ClientRequestMfaToken =
gosnowflake.ConfigBoolTrue
- case adbc.OptionValueDisabled:
- d.cfg.ClientRequestMfaToken =
gosnowflake.ConfigBoolFalse
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for
database option '%s': '%s'", OptionSSLSkipVerify, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionClientStoreTempCred:
- switch v {
- case adbc.OptionValueEnabled:
- d.cfg.ClientStoreTemporaryCredential =
gosnowflake.ConfigBoolTrue
- case adbc.OptionValueDisabled:
- d.cfg.ClientStoreTemporaryCredential =
gosnowflake.ConfigBoolFalse
- default:
- return adbc.Error{
- Msg: fmt.Sprintf("Invalid value for
database option '%s': '%s'", OptionSSLSkipVerify, v),
- Code: adbc.StatusInvalidArgument,
- }
- }
- case OptionLogTracing:
- d.cfg.Tracing = v
- default:
- d.cfg.Params[k] = &v
- }
- }
- return nil
-}
-
-func (d *database) SetOption(key string, val string) error {
- // Can't set options after init
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) SetOptionBytes(key string, value []byte) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) SetOptionInt(key string, value int64) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) SetOptionDouble(key string, value float64) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
- connector := gosnowflake.NewConnector(drv, *d.cfg)
-
- ctx = gosnowflake.WithArrowAllocator(
- gosnowflake.WithArrowBatches(ctx), d.alloc)
-
- cn, err := connector.Connect(ctx)
- if err != nil {
- return nil, errToAdbcErr(adbc.StatusIO, err)
+// NewDatabase creates a Snowflake database and applies opts immediately via
+// SetOptions. If any option is rejected, no database is returned and the
+// SetOptions error is passed through to the caller.
+func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database,
error) {
+ // Work on a copy: SetOptions deletes the URI key from the map it is
+ // given (when present), and that must not leak into the caller's map.
+ opts = maps.Clone(opts)
+ db := &databaseImpl{DatabaseImplBase:
driverbase.NewDatabaseImplBase(&d.DriverImplBase)}
+ if err := db.SetOptions(opts); err != nil {
+ return nil, err
 }
-
- return &cnxn{cn: cn.(snowflakeConn), db: d, ctor: connector, sqldb:
sql.OpenDB(connector)}, nil
+ // Expose the implementation through driverbase as an adbc.Database.
+ return driverbase.NewDatabase(db), nil
 }
diff --git a/go/adbc/driver/snowflake/driver_test.go
b/go/adbc/driver/snowflake/driver_test.go
index c848be23..e35dea79 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -54,7 +54,7 @@ func (s *SnowflakeQuirks) SetupDriver(t *testing.T)
adbc.Driver {
cfg.Schema = s.schemaName
s.connector = gosnowflake.NewConnector(gosnowflake.SnowflakeDriver{},
*cfg)
- return driver.Driver{Alloc: s.mem}
+ return driver.NewDriver(s.mem)
}
func (s *SnowflakeQuirks) TearDownDriver(t *testing.T, _ adbc.Driver) {
diff --git a/go/adbc/driver/snowflake/driver.go
b/go/adbc/driver/snowflake/snowflake_database.go
similarity index 58%
copy from go/adbc/driver/snowflake/driver.go
copy to go/adbc/driver/snowflake/snowflake_database.go
index b4f063f7..27aa9032 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -21,181 +21,17 @@ import (
"context"
"crypto/x509"
"database/sql"
- "errors"
"fmt"
"net/url"
"os"
- "runtime/debug"
"strconv"
- "strings"
"time"
"github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow/go/v13/arrow/memory"
+ "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
"github.com/snowflakedb/gosnowflake"
- "golang.org/x/exp/maps"
-)
-
-const (
- infoDriverName = "ADBC Snowflake Driver - Go"
- infoVendorName = "Snowflake"
-
- OptionDatabase = "adbc.snowflake.sql.db"
- OptionSchema = "adbc.snowflake.sql.schema"
- OptionWarehouse = "adbc.snowflake.sql.warehouse"
- OptionRole = "adbc.snowflake.sql.role"
- OptionRegion = "adbc.snowflake.sql.region"
- OptionAccount = "adbc.snowflake.sql.account"
- OptionProtocol = "adbc.snowflake.sql.uri.protocol"
- OptionPort = "adbc.snowflake.sql.uri.port"
- OptionHost = "adbc.snowflake.sql.uri.host"
- // Specify auth type to use for snowflake connection based on
- // what is supported by the snowflake driver. Default is
- // "auth_snowflake" (use OptionValueAuth* consts to specify desired
- // authentication type).
- OptionAuthType = "adbc.snowflake.sql.auth_type"
- // Login retry timeout EXCLUDING network roundtrip and reading http
response
- // use format like http://pkg.go.dev/time#ParseDuration such as
- // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
- // but the absolute value will be used.
- OptionLoginTimeout = "adbc.snowflake.sql.client_option.login_timeout"
- // request retry timeout EXCLUDING network roundtrip and reading http
response
- // use format like http://pkg.go.dev/time#ParseDuration such as
- // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
- // but the absolute value will be used.
- OptionRequestTimeout =
"adbc.snowflake.sql.client_option.request_timeout"
- // JWT expiration after timeout
- // use format like http://pkg.go.dev/time#ParseDuration such as
- // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
- // but the absolute value will be used.
- OptionJwtExpireTimeout =
"adbc.snowflake.sql.client_option.jwt_expire_timeout"
- // Timeout for network round trip + reading http response
- // use format like http://pkg.go.dev/time#ParseDuration such as
- // "300ms", "1.5s" or "1m30s". ParseDuration accepts negative values
- // but the absolute value will be used.
- OptionClientTimeout = "adbc.snowflake.sql.client_option.client_timeout"
-
- OptionApplicationName = "adbc.snowflake.sql.client_option.app_name"
- OptionSSLSkipVerify =
"adbc.snowflake.sql.client_option.tls_skip_verify"
- OptionOCSPFailOpenMode =
"adbc.snowflake.sql.client_option.ocsp_fail_open_mode"
- // specify the token to use for OAuth or other forms of authentication
- OptionAuthToken = "adbc.snowflake.sql.client_option.auth_token"
- // specify the OKTAUrl to use for OKTA Authentication
- OptionAuthOktaUrl = "adbc.snowflake.sql.client_option.okta_url"
- // enable the session to persist even after the connection is closed
- OptionKeepSessionAlive =
"adbc.snowflake.sql.client_option.keep_session_alive"
- // specify the RSA private key to use to sign the JWT
- // this should point to a file containing a PKCS1 private key to be
- // loaded. Commonly encoded in PEM blocks of type "RSA PRIVATE KEY"
- OptionJwtPrivateKey =
"adbc.snowflake.sql.client_option.jwt_private_key"
- OptionDisableTelemetry =
"adbc.snowflake.sql.client_option.disable_telemetry"
- // snowflake driver logging level
- OptionLogTracing = "adbc.snowflake.sql.client_option.tracing"
- // When true, the MFA token is cached in the credential manager. True
by default
- // on Windows/OSX, false for Linux
- OptionClientRequestMFAToken =
"adbc.snowflake.sql.client_option.cache_mfa_token"
- // When true, the ID token is cached in the credential manager. True by
default
- // on Windows/OSX, false for Linux
- OptionClientStoreTempCred =
"adbc.snowflake.sql.client_option.store_temp_creds"
-
- // auth types are implemented by the Snowflake driver in gosnowflake
- // general username password authentication
- OptionValueAuthSnowflake = "auth_snowflake"
- // use OAuth authentication for snowflake connection
- OptionValueAuthOAuth = "auth_oauth"
- // use an external browser to access a FED and perform SSO auth
- OptionValueAuthExternalBrowser = "auth_ext_browser"
- // use a native OKTA URL to perform SSO authentication on Okta
- OptionValueAuthOkta = "auth_okta"
- // use a JWT to perform authentication
- OptionValueAuthJwt = "auth_jwt"
- // use a username and password with mfa
- OptionValueAuthUserPassMFA = "auth_mfa"
)
-var (
- infoDriverVersion string
- infoDriverArrowVersion string
- infoSupportedCodes []adbc.InfoCode
-)
-
-func init() {
- if info, ok := debug.ReadBuildInfo(); ok {
- for _, dep := range info.Deps {
- switch {
- case dep.Path ==
"github.com/apache/arrow-adbc/go/adbc/driver/snowflake":
- infoDriverVersion = dep.Version
- case strings.HasPrefix(dep.Path,
"github.com/apache/arrow/go/"):
- infoDriverArrowVersion = dep.Version
- }
- }
- }
- // XXX: Deps not populated in tests
- // https://github.com/golang/go/issues/33976
- if infoDriverVersion == "" {
- infoDriverVersion = "(unknown or development build)"
- }
- if infoDriverArrowVersion == "" {
- infoDriverArrowVersion = "(unknown or development build)"
- }
-
- infoSupportedCodes = []adbc.InfoCode{
- adbc.InfoDriverName,
- adbc.InfoDriverVersion,
- adbc.InfoDriverArrowVersion,
- adbc.InfoVendorName,
- }
-}
-
-func errToAdbcErr(code adbc.Status, err error) error {
- if err == nil {
- return nil
- }
-
- var e adbc.Error
- if errors.As(err, &e) {
- e.Code = code
- return e
- }
-
- var sferr *gosnowflake.SnowflakeError
- if errors.As(err, &sferr) {
- var sqlstate [5]byte
- copy(sqlstate[:], []byte(sferr.SQLState))
-
- if sferr.SQLState == "42S02" {
- code = adbc.StatusNotFound
- }
-
- return adbc.Error{
- Code: code,
- Msg: sferr.Error(),
- VendorCode: int32(sferr.Number),
- SqlState: sqlstate,
- }
- }
-
- return adbc.Error{
- Msg: err.Error(),
- Code: code,
- }
-}
-
-type Driver struct {
- Alloc memory.Allocator
-}
-
-func (d Driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
- db := &database{alloc: d.Alloc}
-
- opts = maps.Clone(opts)
- if db.alloc == nil {
- db.alloc = memory.DefaultAllocator
- }
-
- return db, db.SetOptions(opts)
-}
-
var (
drv = gosnowflake.SnowflakeDriver{}
authTypeMap = map[string]gosnowflake.AuthType{
@@ -208,12 +44,12 @@ var (
}
)
-type database struct {
- cfg *gosnowflake.Config
- alloc memory.Allocator
+// databaseImpl is the Snowflake database implementation used with the
+// driverbase framework. It embeds driverbase.DatabaseImplBase, which
+// provides the Alloc field that Open passes to gosnowflake.
+type databaseImpl struct {
+ driverbase.DatabaseImplBase
+ // cfg is built by SetOptions and copied into a new connector by Open.
+ cfg *gosnowflake.Config
 }
-func (d *database) GetOption(key string) (string, error) {
+func (d *databaseImpl) GetOption(key string) (string, error) {
switch key {
case adbc.OptionKeyUsername:
return d.cfg.User, nil
@@ -293,26 +129,26 @@ func (d *database) GetOption(key string) (string, error) {
Code: adbc.StatusNotFound,
}
}
-func (d *database) GetOptionBytes(key string) ([]byte, error) {
+func (d *databaseImpl) GetOptionBytes(key string) ([]byte, error) {
return nil, adbc.Error{
Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
Code: adbc.StatusNotFound,
}
}
-func (d *database) GetOptionInt(key string) (int64, error) {
+func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
return 0, adbc.Error{
Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
Code: adbc.StatusNotFound,
}
}
-func (d *database) GetOptionDouble(key string) (float64, error) {
+func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
return 0, adbc.Error{
Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
Code: adbc.StatusNotFound,
}
}
-func (d *database) SetOptions(cnOptions map[string]string) error {
+func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
uri, ok := cnOptions[adbc.OptionKeyURI]
if ok {
cfg, err := gosnowflake.ParseDSN(uri)
@@ -525,40 +361,11 @@ func (d *database) SetOptions(cnOptions
map[string]string) error {
return nil
}
-func (d *database) SetOption(key string, val string) error {
- // Can't set options after init
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) SetOptionBytes(key string, value []byte) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) SetOptionInt(key string, value int64) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) SetOptionDouble(key string, value float64) error {
- return adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (d *database) Open(ctx context.Context) (adbc.Connection, error) {
+func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
connector := gosnowflake.NewConnector(drv, *d.cfg)
ctx = gosnowflake.WithArrowAllocator(
- gosnowflake.WithArrowBatches(ctx), d.alloc)
+ gosnowflake.WithArrowBatches(ctx), d.Alloc)
cn, err := connector.Connect(ctx)
if err != nil {
diff --git a/go/adbc/pkg/_tmpl/driver.go.tmpl b/go/adbc/pkg/_tmpl/driver.go.tmpl
index 1ffa1f85..f63a8d5a 100644
--- a/go/adbc/pkg/_tmpl/driver.go.tmpl
+++ b/go/adbc/pkg/_tmpl/driver.go.tmpl
@@ -66,7 +66,7 @@ import (
)
// Must use malloc() to respect CGO rules
-var drv = {{.Driver}}{Alloc: mallocator.NewMallocator()}
+var drv = {{.Driver}}(mallocator.NewMallocator())
// Flag set if any method panic()ed - afterwards all calls to driver will fail
// since internal state of driver is unknown
// (Can't use atomic.Bool since that's Go 1.19)
diff --git a/go/adbc/pkg/flightsql/driver.go b/go/adbc/pkg/flightsql/driver.go
index 270ffaaf..8ccb437f 100644
--- a/go/adbc/pkg/flightsql/driver.go
+++ b/go/adbc/pkg/flightsql/driver.go
@@ -69,7 +69,7 @@ import (
)
// Must use malloc() to respect CGO rules
-var drv = flightsql.Driver{Alloc: mallocator.NewMallocator()}
+var drv = flightsql.NewDriver(mallocator.NewMallocator())
// Flag set if any method panic()ed - afterwards all calls to driver will fail
// since internal state of driver is unknown
diff --git a/go/adbc/pkg/gen/main.go b/go/adbc/pkg/gen/main.go
index 16f0001f..5367c488 100644
--- a/go/adbc/pkg/gen/main.go
+++ b/go/adbc/pkg/gen/main.go
@@ -88,7 +88,7 @@ func main() {
var (
prefix = flag.String("prefix", "", "function prefix")
driverPkg = flag.String("driver", "", "path to driver package")
- driverType = flag.String("type", "Driver", "name of the driver
type")
+ driverCtor = flag.String("type", "NewDriver", "name of the
driver constructor")
outDir = flag.String("o", "", "output directory")
tmplDir = flag.String("in", "./_tmpl", "template directory
[default=./_tmpl]")
)
@@ -127,7 +127,7 @@ func main() {
}
process(tmplData{
- Driver: pkg[0].Name + "." + *driverType,
+ Driver: pkg[0].Name + "." + *driverCtor,
Prefix: *prefix,
PrefixUpper: strings.ToUpper(*prefix),
}, specs)
diff --git a/go/adbc/pkg/panicdummy/driver.go b/go/adbc/pkg/panicdummy/driver.go
index 1e2279e0..06f424d2 100644
--- a/go/adbc/pkg/panicdummy/driver.go
+++ b/go/adbc/pkg/panicdummy/driver.go
@@ -69,7 +69,7 @@ import (
)
// Must use malloc() to respect CGO rules
-var drv = panicdummy.Driver{Alloc: mallocator.NewMallocator()}
+var drv = panicdummy.NewDriver(mallocator.NewMallocator())
// Flag set if any method panic()ed - afterwards all calls to driver will fail
// since internal state of driver is unknown
diff --git a/go/adbc/pkg/snowflake/driver.go b/go/adbc/pkg/snowflake/driver.go
index e468924d..826fb517 100644
--- a/go/adbc/pkg/snowflake/driver.go
+++ b/go/adbc/pkg/snowflake/driver.go
@@ -69,7 +69,7 @@ import (
)
// Must use malloc() to respect CGO rules
-var drv = snowflake.Driver{Alloc: mallocator.NewMallocator()}
+var drv = snowflake.NewDriver(mallocator.NewMallocator())
// Flag set if any method panic()ed - afterwards all calls to driver will fail
// since internal state of driver is unknown
diff --git a/go/adbc/sqldriver/flightsql/flightsql.go
b/go/adbc/sqldriver/flightsql/flightsql.go
index f318cbb6..9cb9046f 100644
--- a/go/adbc/sqldriver/flightsql/flightsql.go
+++ b/go/adbc/sqldriver/flightsql/flightsql.go
@@ -19,12 +19,14 @@ package flightsql
import (
"database/sql"
+
"github.com/apache/arrow-adbc/go/adbc/driver/flightsql"
"github.com/apache/arrow-adbc/go/adbc/sqldriver"
+ "github.com/apache/arrow/go/v13/arrow/memory"
)
 func init() {
+ // Make the Flight SQL ADBC driver available to database/sql under the
+ // name "flightsql", backed by the default Go Arrow allocator.
 sql.Register("flightsql", sqldriver.Driver{
- Driver: flightsql.Driver{},
+ Driver: flightsql.NewDriver(memory.DefaultAllocator),
 })
 }