This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 302242849 refactor(go/adbc/driver): driverbase implementation for
connection (#1590)
302242849 is described below
commit 302242849ba09dbb4f4b6d95155421dffafb6105
Author: Joel Lubinitsky <[email protected]>
AuthorDate: Tue Mar 19 11:43:42 2024 -0400
refactor(go/adbc/driver): driverbase implementation for connection (#1590)
Implementation of Connection driver base, along with a refactor of
Driver and Database bases.
The bases have been refactored in the following way:
- The `*Impl` interface (e.g. `DatabaseImpl`) now explicitly extends
the corresponding `adbc` interface (e.g. `adbc.Database`).
- We now check that `DatabaseImplBase` implements the entire
`DatabaseImpl` interface, using stub methods or default implementations.
- A new interface has been added (e.g. `driverbase.Database`) which
contains all the methods that the _output_ of the driverbase constructor
`NewDatabase()` should have. This helps document and guarantee the "extra"
behavior provided by using the driverbase. This interface should be
internal to the library.
- By embedding `DatabaseImpl` in the `database` struct (and similarly
for the other bases), the struct automatically inherits the implementations
provided by the `DatabaseImpl`. This way we don't need to write out all the
implementations a second time, hence the deletions.
- The Connection base uses a builder for its constructor to register any
helper methods (see discussion in comments). The Driver and Database
bases use simple function constructors because they don't have any
helpers to register. This felt simpler, but I can make those into trivial
builders as well if we prefer consistency between them.
A new `DriverInfo` type has been introduced to help consolidate the
collection and validation of metadata for `GetInfo()`.
There are also smaller changes, such as refactors of the flightsql and
snowflake drivers to make use of the added functionality, as well as a
new set of tests for the driverbase. Please let me know if anything else
could use clarification.
Resolves #1105.
---
go/adbc/adbc.go | 11 +
go/adbc/driver/driverbase/driver.go | 66 ---
go/adbc/driver/flightsql/flightsql_connection.go | 578 +++++++-------------
go/adbc/driver/flightsql/flightsql_database.go | 25 +-
go/adbc/driver/flightsql/flightsql_driver.go | 45 +-
go/adbc/driver/flightsql/flightsql_statement.go | 12 +-
go/adbc/driver/internal/driverbase/connection.go | 497 +++++++++++++++++
.../driver/{ => internal}/driverbase/database.go | 111 ++--
go/adbc/driver/internal/driverbase/driver.go | 116 ++++
go/adbc/driver/internal/driverbase/driver_info.go | 176 ++++++
.../driver/internal/driverbase/driver_info_test.go | 88 +++
go/adbc/driver/internal/driverbase/driver_test.go | 595 +++++++++++++++++++++
go/adbc/driver/{ => internal}/driverbase/error.go | 0
.../driver/{ => internal}/driverbase/logging.go | 0
go/adbc/driver/snowflake/connection.go | 293 +++-------
go/adbc/driver/snowflake/driver.go | 45 +-
go/adbc/driver/snowflake/driver_test.go | 4 +
go/adbc/driver/snowflake/snowflake_database.go | 41 +-
go/adbc/driver/snowflake/statement.go | 2 +-
go/adbc/go.mod | 1 +
go/adbc/go.sum | 1 +
21 files changed, 1861 insertions(+), 846 deletions(-)
diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go
index f5514626a..6968faacf 100644
--- a/go/adbc/adbc.go
+++ b/go/adbc/adbc.go
@@ -355,6 +355,17 @@ const (
InfoDriverADBCVersion InfoCode = 103 // DriverADBCVersion
)
+type InfoValueTypeCode = arrow.UnionTypeCode
+
+const (
+ InfoValueStringType InfoValueTypeCode = 0
+ InfoValueBooleanType InfoValueTypeCode = 1
+ InfoValueInt64Type InfoValueTypeCode = 2
+ InfoValueInt32BitmaskType InfoValueTypeCode = 3
+ InfoValueStringListType InfoValueTypeCode = 4
+ InfoValueInt32ToInt32ListMapType InfoValueTypeCode = 5
+)
+
type ObjectDepth int
const (
diff --git a/go/adbc/driver/driverbase/driver.go
b/go/adbc/driver/driverbase/driver.go
deleted file mode 100644
index e4cfb9960..000000000
--- a/go/adbc/driver/driverbase/driver.go
+++ /dev/null
@@ -1,66 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-// Package driverbase provides a framework for implementing ADBC drivers in
-// Go. It intends to reduce boilerplate for common functionality and managing
-// state transitions.
-package driverbase
-
-import (
- "github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow/go/v16/arrow/memory"
-)
-
-// DriverImpl is an interface that drivers implement to provide
-// vendor-specific functionality.
-type DriverImpl interface {
- Base() *DriverImplBase
- NewDatabase(opts map[string]string) (adbc.Database, error)
-}
-
-// DriverImplBase is a struct that provides default implementations of the
-// DriverImpl interface. It is meant to be used as a composite struct for a
-// driver's DriverImpl implementation.
-type DriverImplBase struct {
- Alloc memory.Allocator
- ErrorHelper ErrorHelper
-}
-
-func NewDriverImplBase(name string, alloc memory.Allocator) DriverImplBase {
- if alloc == nil {
- alloc = memory.DefaultAllocator
- }
- return DriverImplBase{Alloc: alloc, ErrorHelper:
ErrorHelper{DriverName: name}}
-}
-
-func (base *DriverImplBase) Base() *DriverImplBase {
- return base
-}
-
-// driver is the actual implementation of adbc.Driver.
-type driver struct {
- impl DriverImpl
-}
-
-// NewDriver wraps a DriverImpl to create an adbc.Driver.
-func NewDriver(impl DriverImpl) adbc.Driver {
- return &driver{impl}
-}
-
-func (drv *driver) NewDatabase(opts map[string]string) (adbc.Database, error) {
- return drv.impl.NewDatabase(opts)
-}
diff --git a/go/adbc/driver/flightsql/flightsql_connection.go
b/go/adbc/driver/flightsql/flightsql_connection.go
index e71ac308d..83807856e 100644
--- a/go/adbc/driver/flightsql/flightsql_connection.go
+++ b/go/adbc/driver/flightsql/flightsql_connection.go
@@ -28,6 +28,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
"github.com/apache/arrow/go/v16/arrow/flight"
@@ -43,7 +44,9 @@ import (
"google.golang.org/protobuf/proto"
)
-type cnxn struct {
+type connectionImpl struct {
+ driverbase.ConnectionImplBase
+
cl *flightsql.Client
db *databaseImpl
@@ -54,6 +57,82 @@ type cnxn struct {
supportInfo support
}
+// GetCurrentCatalog implements driverbase.CurrentNamespacer.
+func (c *connectionImpl) GetCurrentCatalog() (string, error) {
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ if catalog, ok := options["catalog"]; ok {
+ if val, ok := catalog.(string); ok {
+ return val, nil
+ }
+ return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal,
"server returned non-string catalog %#v", catalog)
+ }
+ return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current
catalog not supported")
+}
+
+// GetCurrentDbSchema implements driverbase.CurrentNamespacer.
+func (c *connectionImpl) GetCurrentDbSchema() (string, error) {
+ options, err := c.getSessionOptions(context.Background())
+ if err != nil {
+ return "", err
+ }
+ if schema, ok := options["schema"]; ok {
+ if val, ok := schema.(string); ok {
+ return val, nil
+ }
+ return "", c.Base().ErrorHelper.Errorf(adbc.StatusInternal,
"server returned non-string schema %#v", schema)
+ }
+ return "", c.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "current
schema not supported")
+}
+
+// SetCurrentCatalog implements driverbase.CurrentNamespacer.
+func (c *connectionImpl) SetCurrentCatalog(value string) error {
+ return c.setSessionOptions(context.Background(), "catalog", value)
+}
+
+// SetCurrentDbSchema implements driverbase.CurrentNamespacer.
+func (c *connectionImpl) SetCurrentDbSchema(value string) error {
+ return c.setSessionOptions(context.Background(), "schema", value)
+}
+
+func (c *connectionImpl) SetAutocommit(enabled bool) error {
+ if enabled && c.txn == nil {
+ // no-op don't even error if the server didn't support
transactions
+ return nil
+ }
+
+ if !c.supportInfo.transactions {
+ return errNoTransactionSupport
+ }
+
+ ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs)
+ var err error
+ if c.txn != nil {
+ if err = c.txn.Commit(ctx, c.timeouts); err != nil {
+ return adbc.Error{
+ Msg: "[Flight SQL] failed to update
autocommit: " + err.Error(),
+ Code: adbc.StatusIO,
+ }
+ }
+ }
+
+ if enabled {
+ c.txn = nil
+ return nil
+ }
+
+ if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err != nil {
+ return adbc.Error{
+ Msg: "[Flight SQL] failed to update autocommit: " +
err.Error(),
+ Code: adbc.StatusIO,
+ }
+ }
+
+ return nil
+}
+
var adbcToFlightSQLInfo = map[adbc.InfoCode]flightsql.SqlInfo{
adbc.InfoVendorName: flightsql.SqlInfoFlightSqlServerName,
adbc.InfoVendorVersion: flightsql.SqlInfoFlightSqlServerVersion,
@@ -97,7 +176,7 @@ func doGet(ctx context.Context, cl *flightsql.Client,
endpoint *flight.FlightEnd
return nil, err
}
-func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{},
error) {
+func (c *connectionImpl) getSessionOptions(ctx context.Context)
(map[string]interface{}, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
var header, trailer metadata.MD
rawOptions, err := c.cl.GetSessionOptions(ctx,
&flight.GetSessionOptionsRequest{}, grpc.Header(&header),
grpc.Trailer(&trailer), c.timeouts)
@@ -140,7 +219,7 @@ func (c *cnxn) getSessionOptions(ctx context.Context)
(map[string]interface{}, e
return options, nil
}
-func (c *cnxn) setSessionOptions(ctx context.Context, key string, val
interface{}) error {
+func (c *connectionImpl) setSessionOptions(ctx context.Context, key string,
val interface{}) error {
req := flight.SetSessionOptionsRequest{}
var err error
@@ -206,7 +285,7 @@ func getSessionOption[T any](options
map[string]interface{}, key string, default
return value, nil
}
-func (c *cnxn) GetOption(key string) (string, error) {
+func (c *connectionImpl) GetOption(key string) (string, error) {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
headers := c.hdrs.Get(name)
@@ -226,51 +305,6 @@ func (c *cnxn) GetOption(key string) (string, error) {
return c.timeouts.queryTimeout.String(), nil
case OptionTimeoutUpdate:
return c.timeouts.updateTimeout.String(), nil
- case adbc.OptionKeyAutoCommit:
- if c.txn != nil {
- // No autocommit
- return adbc.OptionValueDisabled, nil
- } else {
- // Autocommit
- return adbc.OptionValueEnabled, nil
- }
- case adbc.OptionKeyCurrentCatalog:
- options, err := c.getSessionOptions(context.Background())
- if err != nil {
- return "", err
- }
- if catalog, ok := options["catalog"]; ok {
- if val, ok := catalog.(string); ok {
- return val, nil
- }
- return "", adbc.Error{
- Msg: fmt.Sprintf("[FlightSQL] Server returned
non-string catalog %#v", catalog),
- Code: adbc.StatusInternal,
- }
- }
- return "", adbc.Error{
- Msg: "[FlightSQL] current catalog not supported",
- Code: adbc.StatusNotFound,
- }
-
- case adbc.OptionKeyCurrentDbSchema:
- options, err := c.getSessionOptions(context.Background())
- if err != nil {
- return "", err
- }
- if schema, ok := options["schema"]; ok {
- if val, ok := schema.(string); ok {
- return val, nil
- }
- return "", adbc.Error{
- Msg: fmt.Sprintf("[FlightSQL] Server returned
non-string schema %#v", schema),
- Code: adbc.StatusInternal,
- }
- }
- return "", adbc.Error{
- Msg: "[FlightSQL] current schema not supported",
- Code: adbc.StatusNotFound,
- }
case OptionSessionOptions:
options, err := c.getSessionOptions(context.Background())
if err != nil {
@@ -333,7 +367,7 @@ func (c *cnxn) GetOption(key string) (string, error) {
}
}
-func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
+func (c *connectionImpl) GetOptionBytes(key string) ([]byte, error) {
switch key {
case OptionSessionOptions:
options, err := c.getSessionOptions(context.Background())
@@ -356,7 +390,7 @@ func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
}
}
-func (c *cnxn) GetOptionInt(key string) (int64, error) {
+func (c *connectionImpl) GetOptionInt(key string) (int64, error) {
switch key {
case OptionTimeoutFetch:
fallthrough
@@ -378,13 +412,10 @@ func (c *cnxn) GetOptionInt(key string) (int64, error) {
return getSessionOption(options, name, int64(0), "an integer")
}
- return 0, adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotFound,
- }
+ return c.ConnectionImplBase.GetOptionInt(key)
}
-func (c *cnxn) GetOptionDouble(key string) (float64, error) {
+func (c *connectionImpl) GetOptionDouble(key string) (float64, error) {
switch key {
case OptionTimeoutFetch:
return c.timeouts.fetchTimeout.Seconds(), nil
@@ -402,13 +433,10 @@ func (c *cnxn) GetOptionDouble(key string) (float64,
error) {
return getSessionOption(options, name, float64(0.0), "a
floating-point")
}
- return 0.0, adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotFound,
- }
+ return c.ConnectionImplBase.GetOptionDouble(key)
}
-func (c *cnxn) SetOption(key, value string) error {
+func (c *connectionImpl) SetOption(key, value string) error {
if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
if value == "" {
@@ -422,56 +450,6 @@ func (c *cnxn) SetOption(key, value string) error {
switch key {
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
return c.timeouts.setTimeoutString(key, value)
- case adbc.OptionKeyAutoCommit:
- autocommit := true
- switch value {
- case adbc.OptionValueEnabled:
- // Do nothing
- case adbc.OptionValueDisabled:
- autocommit = false
- default:
- return adbc.Error{
- Msg: "[Flight SQL] invalid value for option "
+ key + ": " + value,
- Code: adbc.StatusInvalidArgument,
- }
- }
-
- if autocommit && c.txn == nil {
- // no-op don't even error if the server didn't support
transactions
- return nil
- }
-
- if !c.supportInfo.transactions {
- return errNoTransactionSupport
- }
-
- ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs)
- var err error
- if c.txn != nil {
- if err = c.txn.Commit(ctx, c.timeouts); err != nil {
- return adbc.Error{
- Msg: "[Flight SQL] failed to update
autocommit: " + err.Error(),
- Code: adbc.StatusIO,
- }
- }
- }
-
- if autocommit {
- c.txn = nil
- return nil
- }
-
- if c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts); err !=
nil {
- return adbc.Error{
- Msg: "[Flight SQL] failed to update
autocommit: " + err.Error(),
- Code: adbc.StatusIO,
- }
- }
- return nil
- case adbc.OptionKeyCurrentCatalog:
- return c.setSessionOptions(context.Background(), "catalog",
value)
- case adbc.OptionKeyCurrentDbSchema:
- return c.setSessionOptions(context.Background(), "schema",
value)
}
switch {
@@ -506,20 +484,10 @@ func (c *cnxn) SetOption(key, value string) error {
return c.setSessionOptions(context.Background(), name, nil)
}
- return adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
+ return c.ConnectionImplBase.SetOption(key, value)
}
-func (c *cnxn) SetOptionBytes(key string, value []byte) error {
- return adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (c *cnxn) SetOptionInt(key string, value int64) error {
+func (c *connectionImpl) SetOptionInt(key string, value int64) error {
switch key {
case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
return c.timeouts.setTimeout(key, float64(value))
@@ -529,13 +497,10 @@ func (c *cnxn) SetOptionInt(key string, value int64)
error {
return c.setSessionOptions(context.Background(), name, value)
}
- return adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
+ return c.ConnectionImplBase.SetOptionInt(key, value)
}
-func (c *cnxn) SetOptionDouble(key string, value float64) error {
+func (c *connectionImpl) SetOptionDouble(key string, value float64) error {
switch key {
case OptionTimeoutFetch:
fallthrough
@@ -549,231 +514,117 @@ func (c *cnxn) SetOptionDouble(key string, value
float64) error {
return c.setSessionOptions(context.Background(), name, value)
}
- return adbc.Error{
- Msg: "[Flight SQL] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
+ return c.ConnectionImplBase.SetOptionDouble(key, value)
}
-// GetInfo returns metadata about the database/driver.
-//
-// The result is an Arrow dataset with the following schema:
-//
-// Field Name
| Field Type
-// ----------------------------|-----------------------------
-// info_name
| uint32 not null
-// info_value
| INFO_SCHEMA
-//
-// INFO_SCHEMA is a dense union with members:
-//
-// Field Name (Type Code) | Field Type
-// ----------------------------|-----------------------------
-// string_value (0) | utf8
-// bool_value (1) | bool
-// int64_value (2) | int64
-// int32_bitmask (3) | int32
-// string_list (4) |
list<utf8>
-// int32_to_int32_list_map (5) | map<int32, list<int32>>
-//
-// Each metadatum is identified by an integer code. The recognized
-// codes are defined as constants. Codes [0, 10_000) are reserved
-// for ADBC usage. Drivers/vendors will ignore requests for unrecognized
-// codes (the row will be omitted from the result).
-func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode)
(array.RecordReader, error) {
- const strValTypeID arrow.UnionTypeCode = 0
- const intValTypeID arrow.UnionTypeCode = 2
+func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) error {
+ driverInfo := c.ConnectionImplBase.DriverInfo
if len(infoCodes) == 0 {
- infoCodes = infoSupportedCodes
+ infoCodes = driverInfo.InfoSupportedCodes()
}
- bldr := array.NewRecordBuilder(c.cl.Alloc, adbc.GetInfoSchema)
- defer bldr.Release()
- bldr.Reserve(len(infoCodes))
-
- infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
- infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
- strInfoBldr :=
infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder)
- intInfoBldr :=
infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder)
-
translated := make([]flightsql.SqlInfo, 0, len(infoCodes))
for _, code := range infoCodes {
if t, ok := adbcToFlightSQLInfo[code]; ok {
translated = append(translated, t)
- continue
}
+ }
- switch code {
- case adbc.InfoDriverName:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoDriverName)
- case adbc.InfoDriverVersion:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoDriverVersion)
- case adbc.InfoDriverArrowVersion:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoDriverArrowVersion)
- case adbc.InfoDriverADBCVersion:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(intValTypeID)
- intInfoBldr.Append(adbc.AdbcVersion1_1_0)
- }
+ // None of the requested info codes are available on the server, so
just return the local info
+ if len(translated) == 0 {
+ return nil
}
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
var header, trailer metadata.MD
info, err := c.cl.GetSqlInfo(ctx, translated, grpc.Header(&header),
grpc.Trailer(&trailer), c.timeouts)
- if err == nil {
- for i, endpoint := range info.Endpoint {
- var header, trailer metadata.MD
- rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache,
grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
- if err != nil {
- return nil,
adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint
%d: %s", i, endpoint.Location)
- }
- for rdr.Next() {
- rec := rdr.Record()
- field := rec.Column(0).(*array.Uint32)
- info := rec.Column(1).(*array.DenseUnion)
-
- for i := 0; i < int(rec.NumRows()); i++ {
- switch
flightsql.SqlInfo(field.Value(i)) {
- case
flightsql.SqlInfoFlightSqlServerName:
-
infoNameBldr.Append(uint32(adbc.InfoVendorName))
- case
flightsql.SqlInfoFlightSqlServerVersion:
-
infoNameBldr.Append(uint32(adbc.InfoVendorVersion))
- case
flightsql.SqlInfoFlightSqlServerArrowVersion:
-
infoNameBldr.Append(uint32(adbc.InfoVendorArrowVersion))
- default:
- continue
- }
+ // Just return local driver info if GetSqlInfo hasn't been implemented
on the server
+ if grpcstatus.Code(err) == grpccodes.Unimplemented {
+ return nil
+ }
+
+ if err != nil {
+ return adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)")
+ }
+
+ // No error, go get the SqlInfo from the server
+ for i, endpoint := range info.Endpoint {
+ var header, trailer metadata.MD
+ rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache,
grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
+ if err != nil {
+ return adbcFromFlightStatusWithDetails(err, header,
trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
+ }
- infoValueBldr.Append(info.TypeCode(i))
- // we know we're only doing string
fields here right now
- v :=
info.Field(info.ChildID(i)).(*array.String).
- Value(int(info.ValueOffset(i)))
- strInfoBldr.Append(v)
+ for rdr.Next() {
+ rec := rdr.Record()
+ field := rec.Column(0).(*array.Uint32)
+ info := rec.Column(1).(*array.DenseUnion)
+
+ var adbcInfoCode adbc.InfoCode
+ for i := 0; i < int(rec.NumRows()); i++ {
+ switch flightsql.SqlInfo(field.Value(i)) {
+ case flightsql.SqlInfoFlightSqlServerName:
+ adbcInfoCode = adbc.InfoVendorName
+ case flightsql.SqlInfoFlightSqlServerVersion:
+ adbcInfoCode = adbc.InfoVendorVersion
+ case
flightsql.SqlInfoFlightSqlServerArrowVersion:
+ adbcInfoCode =
adbc.InfoVendorArrowVersion
+ default:
+ continue
}
- }
- if err := checkContext(rdr.Err(), ctx); err != nil {
- return nil,
adbcFromFlightStatusWithDetails(err, header, trailer, "GetInfo(DoGet): endpoint
%d: %s", i, endpoint.Location)
+ // we know we're only doing string fields here
right now
+ v :=
info.Field(info.ChildID(i)).(*array.String).
+ Value(int(info.ValueOffset(i)))
+ if err :=
driverInfo.RegisterInfoCode(adbcInfoCode, strings.Clone(v)); err != nil {
+ return err
+ }
}
}
- } else if grpcstatus.Code(err) != grpccodes.Unimplemented {
- return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)")
+
+ if err := checkContext(rdr.Err(), ctx); err != nil {
+ return adbcFromFlightStatusWithDetails(err, header,
trailer, "GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
+ }
}
- final := bldr.NewRecord()
- defer final.Release()
- return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final})
+ return nil
}
-// GetObjects gets a hierarchical view of all catalogs, database schemas,
-// tables, and columns.
-//
-// The result is an Arrow Dataset with the following schema:
-//
-// Field Name
| Field Type
-// ----------------------------|----------------------------
-// catalog_name
| utf8
-// catalog_db_schemas |
list<DB_SCHEMA_SCHEMA>
-//
-// DB_SCHEMA_SCHEMA is a Struct with the fields:
-//
-// Field Name
| Field Type
-// ----------------------------|----------------------------
-// db_schema_name | utf8
-// db_schema_tables |
list<TABLE_SCHEMA>
-//
-// TABLE_SCHEMA is a Struct with the fields:
-//
-// Field Name
| Field Type
-// ----------------------------|----------------------------
-// table_name
| utf8 not null
-// table_type
| utf8 not null
-// table_columns
| list<COLUMN_SCHEMA>
-// table_constraints |
list<CONSTRAINT_SCHEMA>
-//
-// COLUMN_SCHEMA is a Struct with the fields:
-//
-// Field Name
| Field Type | Comments
-// ----------------------------|---------------------|---------
-// column_name
| utf8 not null |
-// ordinal_position
| int32 | (1)
-// remarks
| utf8
| (2)
-// xdbc_data_type
| int16 | (3)
-// xdbc_type_name
| utf8 | (3)
-// xdbc_column_size
| int32 | (3)
-// xdbc_decimal_digits | int16
| (3)
-// xdbc_num_prec_radix | int16
| (3)
-// xdbc_nullable
| int16 | (3)
-// xdbc_column_def
| utf8 | (3)
-// xdbc_sql_data_type | int16
| (3)
-// xdbc_datetime_sub
| int16 | (3)
-// xdbc_char_octet_length | int32
| (3)
-// xdbc_is_nullable
| utf8 | (3)
-// xdbc_scope_catalog | utf8
| (3)
-// xdbc_scope_schema
| utf8 | (3)
-// xdbc_scope_table
| utf8 | (3)
-// xdbc_is_autoincrement | bool
| (3)
-// xdbc_is_generatedcolumn | bool
| (3)
-//
-// 1. The column's ordinal position in the table (starting from 1).
-// 2. Database-specific description of the column.
-// 3. Optional Value. Should be null if not supported by the driver.
-// xdbc_values are meant to provide JDBC/ODBC-compatible metadata
-// in an agnostic manner.
-//
-// CONSTRAINT_SCHEMA is a Struct with the fields:
-//
-// Field Name
| Field Type | Comments
-// ----------------------------|---------------------|---------
-// constraint_name | utf8
|
-// constraint_type | utf8
not null | (1)
-// constraint_column_names | list<utf8> not null | (2)
-// constraint_column_usage | list<USAGE_SCHEMA> | (3)
-//
-// 1. One of 'CHECK', 'FOREIGN KEY', 'PRIMARY KEY', or 'UNIQUE'.
-// 2. The columns on the current table that are constrained, in order.
-// 3. For FOREIGN KEY only, the referenced table and columns.
-//
-// USAGE_SCHEMA is a Struct with fields:
-//
-// Field Name
| Field Type
-// ----------------------------|----------------------------
-// fk_catalog
| utf8
-// fk_db_schema
| utf8
-// fk_table
| utf8 not null
-// fk_column_name | utf8
not null
-//
-// For the parameters: If nil is passed, then that parameter will not
-// be filtered by at all. If an empty string, then only objects without
-// that property (ie: catalog or db schema) will be returned.
-//
-// tableName and columnName must be either nil (do not filter by
-// table name or column name) or non-empty.
-//
-// All non-empty, non-nil strings should be a search pattern (as described
-// earlier).
-func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, dbSchema *string, tableName *string, columnName *string, tableType
[]string) (array.RecordReader, error) {
- ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
- g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
- if err := g.Init(c.db.Alloc, c.getObjectsDbSchemas,
c.getObjectsTables); err != nil {
- return nil, err
+// Helper function to read and validate a metadata stream
+func (c *connectionImpl) readInfo(ctx context.Context, expectedSchema
*arrow.Schema, info *flight.FlightInfo, opts ...grpc.CallOption)
(array.RecordReader, error) {
+ // use a default queueSize for the reader
+ rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache,
5, opts...)
+ if err != nil {
+ return nil, adbcFromFlightStatus(err, "DoGet")
}
- defer g.Release()
- var header, trailer metadata.MD
+ if !rdr.Schema().Equal(expectedSchema) {
+ rdr.Release()
+ return nil, adbc.Error{
+ Msg: fmt.Sprintf("Invalid schema returned for:
expected %s, got %s", expectedSchema.String(), rdr.Schema().String()),
+ Code: adbc.StatusInternal,
+ }
+ }
+ return rdr, nil
+}
+
+func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog
*string) ([]string, error) {
+ var (
+ header, trailer metadata.MD
+ numCatalogs int64
+ )
// To avoid an N+1 query problem, we assume result sets here will fit
in memory and build up a single response.
info, err := c.cl.GetCatalogs(ctx, grpc.Header(&header),
grpc.Trailer(&trailer), c.timeouts)
if err != nil {
return nil, adbcFromFlightStatusWithDetails(err, header,
trailer, "GetObjects(GetCatalogs)")
}
+ if info.TotalRecords > 0 {
+ numCatalogs = info.TotalRecords
+ }
+
header = metadata.MD{}
trailer = metadata.MD{}
rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info, c.timeouts,
grpc.Header(&header), grpc.Trailer(&trailer))
@@ -782,48 +633,25 @@ func (c *cnxn) GetObjects(ctx context.Context, depth
adbc.ObjectDepth, catalog *
}
defer rdr.Release()
- foundCatalog := false
+ catalogs := make([]string, 0, numCatalogs)
for rdr.Next() {
arr := rdr.Record().Column(0).(*array.String)
for i := 0; i < arr.Len(); i++ {
// XXX: force copy since accessor is unsafe
catalogName := string([]byte(arr.Value(i)))
- g.AppendCatalog(catalogName)
- foundCatalog = true
+ catalogs = append(catalogs, catalogName)
}
}
- // Implementations like Dremio report no catalogs, but still have
schemas
- if !foundCatalog && depth != adbc.ObjectDepthCatalogs {
- g.AppendCatalog("")
- }
-
if err := checkContext(rdr.Err(), ctx); err != nil {
return nil, adbcFromFlightStatusWithDetails(err, header,
trailer, "GetObjects(GetCatalogs)")
}
- return g.Finish()
-}
-
-// Helper function to read and validate a metadata stream
-func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema,
info *flight.FlightInfo, opts ...grpc.CallOption) (array.RecordReader, error) {
- // use a default queueSize for the reader
- rdr, err := newRecordReader(ctx, c.db.Alloc, c.cl, info, c.clientCache,
5, opts...)
- if err != nil {
- return nil, adbcFromFlightStatus(err, "DoGet")
- }
- if !rdr.Schema().Equal(expectedSchema) {
- rdr.Release()
- return nil, adbc.Error{
- Msg: fmt.Sprintf("Invalid schema returned for:
expected %s, got %s", expectedSchema.String(), rdr.Schema().String()),
- Code: adbc.StatusInternal,
- }
- }
- return rdr, nil
+ return catalogs, nil
}
// Helper function to build up a map of catalogs to DB schemas
-func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords
[]internal.Metadata) (result map[string][]string, err error) {
+func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords
[]internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
@@ -864,7 +692,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context,
depth adbc.ObjectDepth,
return
}
-func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string, metadataRecords []internal.Metadata) (result
internal.SchemaToTableInfo, err error) {
+func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string, metadataRecords []internal.Metadata)
(result internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth ==
adbc.ObjectDepthDBSchemas {
return
}
@@ -944,7 +772,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth
adbc.ObjectDepth, cat
return
}
-func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema
*string, tableName string) (*arrow.Schema, error) {
+func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string,
dbSchema *string, tableName string) (*arrow.Schema, error) {
opts := &flightsql.GetTablesOpts{
Catalog: catalog,
DbSchemaFilterPattern: dbSchema,
@@ -1023,7 +851,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, catalog
*string, dbSchema *st
// Field Name | Field Type
// ----------------|--------------
// table_type | utf8 not null
-func (c *cnxn) GetTableTypes(ctx context.Context) (array.RecordReader, error) {
+func (c *connectionImpl) GetTableTypes(ctx context.Context)
(array.RecordReader, error) {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
var header, trailer metadata.MD
info, err := c.cl.GetTableTypes(ctx, c.timeouts, grpc.Header(&header),
grpc.Trailer(&trailer))
@@ -1040,18 +868,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context)
(array.RecordReader, error) {
// Behavior is undefined if this is mixed with SQL transaction statements.
// When not supported, the convention is that it should act as if autocommit
// is enabled and return INVALID_STATE errors.
-func (c *cnxn) Commit(ctx context.Context) error {
- if c.txn == nil {
- return adbc.Error{
- Msg: "[Flight SQL] Cannot commit when autocommit is
enabled",
- Code: adbc.StatusInvalidState,
- }
- }
-
- if !c.supportInfo.transactions {
- return errNoTransactionSupport
- }
-
+func (c *connectionImpl) Commit(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
var header, trailer metadata.MD
err := c.txn.Commit(ctx, c.timeouts, grpc.Header(&header),
grpc.Trailer(&trailer))
@@ -1074,18 +891,7 @@ func (c *cnxn) Commit(ctx context.Context) error {
// Behavior is undefined if this is mixed with SQL transaction statements.
// When not supported, the convention is that it should act as if autocommit
// is enabled and return INVALID_STATE errors.
-func (c *cnxn) Rollback(ctx context.Context) error {
- if c.txn == nil {
- return adbc.Error{
- Msg: "[Flight SQL] Cannot rollback when autocommit is
enabled",
- Code: adbc.StatusInvalidState,
- }
- }
-
- if !c.supportInfo.transactions {
- return errNoTransactionSupport
- }
-
+func (c *connectionImpl) Rollback(ctx context.Context) error {
ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
var header, trailer metadata.MD
err := c.txn.Rollback(ctx, c.timeouts, grpc.Header(&header),
grpc.Trailer(&trailer))
@@ -1103,7 +909,7 @@ func (c *cnxn) Rollback(ctx context.Context) error {
}
// NewStatement initializes a new statement object tied to this connection
-func (c *cnxn) NewStatement() (adbc.Statement, error) {
+func (c *connectionImpl) NewStatement() (adbc.Statement, error) {
return &statement{
alloc: c.db.Alloc,
clientCache: c.clientCache,
@@ -1114,7 +920,7 @@ func (c *cnxn) NewStatement() (adbc.Statement, error) {
}, nil
}
-func (c *cnxn) execute(ctx context.Context, query string, opts
...grpc.CallOption) (*flight.FlightInfo, error) {
+func (c *connectionImpl) execute(ctx context.Context, query string, opts
...grpc.CallOption) (*flight.FlightInfo, error) {
if c.txn != nil {
return c.txn.Execute(ctx, query, opts...)
}
@@ -1122,7 +928,7 @@ func (c *cnxn) execute(ctx context.Context, query string,
opts ...grpc.CallOptio
return c.cl.Execute(ctx, query, opts...)
}
-func (c *cnxn) executeSchema(ctx context.Context, query string, opts
...grpc.CallOption) (*flight.SchemaResult, error) {
+func (c *connectionImpl) executeSchema(ctx context.Context, query string, opts
...grpc.CallOption) (*flight.SchemaResult, error) {
if c.txn != nil {
return c.txn.GetExecuteSchema(ctx, query, opts...)
}
@@ -1130,7 +936,7 @@ func (c *cnxn) executeSchema(ctx context.Context, query
string, opts ...grpc.Cal
return c.cl.GetExecuteSchema(ctx, query, opts...)
}
-func (c *cnxn) executeSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
+func (c *connectionImpl) executeSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
if c.txn != nil {
return c.txn.ExecuteSubstrait(ctx, plan, opts...)
}
@@ -1138,7 +944,7 @@ func (c *cnxn) executeSubstrait(ctx context.Context, plan
flightsql.SubstraitPla
return c.cl.ExecuteSubstrait(ctx, plan, opts...)
}
-func (c *cnxn) executeSubstraitSchema(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error)
{
+func (c *connectionImpl) executeSubstraitSchema(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error)
{
if c.txn != nil {
return c.txn.GetExecuteSubstraitSchema(ctx, plan, opts...)
}
@@ -1146,7 +952,7 @@ func (c *cnxn) executeSubstraitSchema(ctx context.Context,
plan flightsql.Substr
return c.cl.GetExecuteSubstraitSchema(ctx, plan, opts...)
}
-func (c *cnxn) executeUpdate(ctx context.Context, query string, opts
...grpc.CallOption) (n int64, err error) {
+func (c *connectionImpl) executeUpdate(ctx context.Context, query string, opts
...grpc.CallOption) (n int64, err error) {
if c.txn != nil {
return c.txn.ExecuteUpdate(ctx, query, opts...)
}
@@ -1154,7 +960,7 @@ func (c *cnxn) executeUpdate(ctx context.Context, query
string, opts ...grpc.Cal
return c.cl.ExecuteUpdate(ctx, query, opts...)
}
-func (c *cnxn) executeSubstraitUpdate(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) {
+func (c *connectionImpl) executeSubstraitUpdate(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption) (n int64, err error) {
if c.txn != nil {
return c.txn.ExecuteSubstraitUpdate(ctx, plan, opts...)
}
@@ -1162,7 +968,7 @@ func (c *cnxn) executeSubstraitUpdate(ctx context.Context,
plan flightsql.Substr
return c.cl.ExecuteSubstraitUpdate(ctx, plan, opts...)
}
-func (c *cnxn) poll(ctx context.Context, query string, retryDescriptor
*flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+func (c *connectionImpl) poll(ctx context.Context, query string,
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption)
(*flight.PollInfo, error) {
if c.txn != nil {
return c.txn.ExecutePoll(ctx, query, retryDescriptor, opts...)
}
@@ -1170,7 +976,7 @@ func (c *cnxn) poll(ctx context.Context, query string,
retryDescriptor *flight.F
return c.cl.ExecutePoll(ctx, query, retryDescriptor, opts...)
}
-func (c *cnxn) pollSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts
...grpc.CallOption) (*flight.PollInfo, error) {
+func (c *connectionImpl) pollSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts
...grpc.CallOption) (*flight.PollInfo, error) {
if c.txn != nil {
return c.txn.ExecuteSubstraitPoll(ctx, plan, retryDescriptor,
opts...)
}
@@ -1178,7 +984,7 @@ func (c *cnxn) pollSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan,
return c.cl.ExecuteSubstraitPoll(ctx, plan, retryDescriptor, opts...)
}
-func (c *cnxn) prepare(ctx context.Context, query string, opts
...grpc.CallOption) (*flightsql.PreparedStatement, error) {
+func (c *connectionImpl) prepare(ctx context.Context, query string, opts
...grpc.CallOption) (*flightsql.PreparedStatement, error) {
if c.txn != nil {
return c.txn.Prepare(ctx, query, opts...)
}
@@ -1186,7 +992,7 @@ func (c *cnxn) prepare(ctx context.Context, query string,
opts ...grpc.CallOptio
return c.cl.Prepare(ctx, query, opts...)
}
-func (c *cnxn) prepareSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption)
(*flightsql.PreparedStatement, error) {
+func (c *connectionImpl) prepareSubstrait(ctx context.Context, plan
flightsql.SubstraitPlan, opts ...grpc.CallOption)
(*flightsql.PreparedStatement, error) {
if c.txn != nil {
return c.txn.PrepareSubstrait(ctx, plan, opts...)
}
@@ -1195,7 +1001,7 @@ func (c *cnxn) prepareSubstrait(ctx context.Context, plan
flightsql.SubstraitPla
}
// Close closes this connection and releases any associated resources.
-func (c *cnxn) Close() error {
+func (c *connectionImpl) Close() error {
if c.cl == nil {
return adbc.Error{
Msg: "[Flight SQL Connection] trying to close already
closed connection",
@@ -1225,7 +1031,7 @@ func (c *cnxn) Close() error {
// results can then be read independently using the returned RecordReader.
//
// A partition can be retrieved by using ExecutePartitions on a statement.
-func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte)
(rdr array.RecordReader, err error) {
+func (c *connectionImpl) ReadPartition(ctx context.Context,
serializedPartition []byte) (rdr array.RecordReader, err error) {
var info flight.FlightInfo
if err := proto.Unmarshal(serializedPartition, &info); err != nil {
return nil, adbc.Error{
@@ -1251,5 +1057,5 @@ func (c *cnxn) ReadPartition(ctx context.Context,
serializedPartition []byte) (r
}
var (
- _ adbc.PostInitOptions = (*cnxn)(nil)
+ _ adbc.PostInitOptions = (*connectionImpl)(nil)
)
diff --git a/go/adbc/driver/flightsql/flightsql_database.go
b/go/adbc/driver/flightsql/flightsql_database.go
index 5e5e3af97..9f0848c3f 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -29,7 +29,7 @@ import (
"time"
"github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v16/arrow/array"
"github.com/apache/arrow/go/v16/arrow/flight"
"github.com/apache/arrow/go/v16/arrow/flight/flightsql"
@@ -51,7 +51,6 @@ func (d *dbDialOpts) rebuild() {
d.opts = []grpc.DialOption{
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(d.maxMsgSize),
grpc.MaxCallSendMsgSize(d.maxMsgSize)),
- grpc.WithUserAgent("ADBC Flight SQL Driver " +
infoDriverVersion),
}
if d.block {
d.opts = append(d.opts, grpc.WithBlock())
@@ -383,7 +382,12 @@ func getFlightClient(ctx context.Context, loc string, d
*databaseImpl, authMiddl
creds = insecure.NewCredentials()
target = "unix:" + uri.Path
}
- dialOpts := append(d.dialOpts.opts,
grpc.WithConnectParams(d.timeout.connectParams()),
grpc.WithTransportCredentials(creds))
+
+ driverVersion, ok :=
d.DatabaseImplBase.DriverInfo.GetInfoDriverVersion()
+ if !ok {
+ driverVersion = driverbase.UnknownVersion
+ }
+ dialOpts := append(d.dialOpts.opts,
grpc.WithConnectParams(d.timeout.connectParams()),
grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL
Driver "+driverVersion))
d.Logger.DebugContext(ctx, "new client", "location", loc)
cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
@@ -503,9 +507,18 @@ func (d *databaseImpl) Open(ctx context.Context)
(adbc.Connection, error) {
}
}
- return &cnxn{cl: cl, db: d, clientCache: cache,
- hdrs: make(metadata.MD), timeouts: d.timeout,
- supportInfo: cnxnSupport}, nil
+ conn := &connectionImpl{
+ cl: cl, db: d, clientCache: cache,
+ hdrs: make(metadata.MD), timeouts: d.timeout, supportInfo:
cnxnSupport,
+ ConnectionImplBase:
driverbase.NewConnectionImplBase(&d.DatabaseImplBase),
+ }
+
+ return driverbase.NewConnectionBuilder(conn).
+ WithDriverInfoPreparer(conn).
+ WithAutocommitSetter(conn).
+ WithDbObjectsEnumerator(conn).
+ WithCurrentNamespacer(conn).
+ Connection(), nil
}
type bearerAuthMiddleware struct {
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go
b/go/adbc/driver/flightsql/flightsql_driver.go
index d437f0829..441370a9e 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -33,12 +33,10 @@ package flightsql
import (
"net/url"
- "runtime/debug"
- "strings"
"time"
"github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v16/arrow/memory"
"golang.org/x/exp/maps"
"google.golang.org/grpc/metadata"
@@ -69,56 +67,19 @@ const (
infoDriverName = "ADBC Flight SQL Driver - Go"
)
-var (
- infoDriverVersion string
- infoDriverArrowVersion string
- infoSupportedCodes []adbc.InfoCode
-)
-
var errNoTransactionSupport = adbc.Error{
Msg: "[Flight SQL] server does not report transaction support",
Code: adbc.StatusNotImplemented,
}
-func init() {
- if info, ok := debug.ReadBuildInfo(); ok {
- for _, dep := range info.Deps {
- switch {
- case dep.Path ==
"github.com/apache/arrow-adbc/go/adbc/driver/flightsql":
- infoDriverVersion = dep.Version
- case strings.HasPrefix(dep.Path,
"github.com/apache/arrow/go/"):
- infoDriverArrowVersion = dep.Version
- }
- }
- }
- // XXX: Deps not populated in tests
- // https://github.com/golang/go/issues/33976
- if infoDriverVersion == "" {
- infoDriverVersion = "(unknown or development build)"
- }
- if infoDriverArrowVersion == "" {
- infoDriverArrowVersion = "(unknown or development build)"
- }
-
- infoSupportedCodes = []adbc.InfoCode{
- adbc.InfoDriverName,
- adbc.InfoDriverVersion,
- adbc.InfoDriverArrowVersion,
- adbc.InfoDriverADBCVersion,
- adbc.InfoVendorName,
- adbc.InfoVendorVersion,
- adbc.InfoVendorArrowVersion,
- }
-}
-
type driverImpl struct {
driverbase.DriverImplBase
}
// NewDriver creates a new Flight SQL driver using the given Arrow allocator.
func NewDriver(alloc memory.Allocator) adbc.Driver {
- impl := driverImpl{DriverImplBase: driverbase.NewDriverImplBase("Flight
SQL", alloc)}
- return driverbase.NewDriver(&impl)
+ info := driverbase.DefaultDriverInfo("Flight SQL")
+ return driverbase.NewDriver(&driverImpl{DriverImplBase:
driverbase.NewDriverImplBase(info, alloc)})
}
func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database,
error) {
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go
b/go/adbc/driver/flightsql/flightsql_statement.go
index d78b653c8..c68eba8cd 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -72,7 +72,7 @@ func (s *sqlOrSubstrait) setSubstraitPlan(plan []byte) {
s.substraitPlan = plan
}
-func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *cnxn, opts
...grpc.CallOption) (*flight.FlightInfo, error) {
+func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn *connectionImpl,
opts ...grpc.CallOption) (*flight.FlightInfo, error) {
if s.sqlQuery != "" {
return cnxn.execute(ctx, s.sqlQuery, opts...)
} else if s.substraitPlan != nil {
@@ -85,7 +85,7 @@ func (s *sqlOrSubstrait) execute(ctx context.Context, cnxn
*cnxn, opts ...grpc.C
}
}
-func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn *cnxn, opts
...grpc.CallOption) (*arrow.Schema, error) {
+func (s *sqlOrSubstrait) executeSchema(ctx context.Context, cnxn
*connectionImpl, opts ...grpc.CallOption) (*arrow.Schema, error) {
var (
res *flight.SchemaResult
err error
@@ -108,7 +108,7 @@ func (s *sqlOrSubstrait) executeSchema(ctx context.Context,
cnxn *cnxn, opts ...
return flight.DeserializeSchema(res.Schema, cnxn.cl.Alloc)
}
-func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn *cnxn, opts
...grpc.CallOption) (int64, error) {
+func (s *sqlOrSubstrait) executeUpdate(ctx context.Context, cnxn
*connectionImpl, opts ...grpc.CallOption) (int64, error) {
if s.sqlQuery != "" {
return cnxn.executeUpdate(ctx, s.sqlQuery, opts...)
} else if s.substraitPlan != nil {
@@ -121,7 +121,7 @@ func (s *sqlOrSubstrait) executeUpdate(ctx context.Context,
cnxn *cnxn, opts ...
}
}
-func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *cnxn, retryDescriptor
*flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn *connectionImpl,
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption)
(*flight.PollInfo, error) {
if s.sqlQuery != "" {
return cnxn.poll(ctx, s.sqlQuery, retryDescriptor, opts...)
} else if s.substraitPlan != nil {
@@ -134,7 +134,7 @@ func (s *sqlOrSubstrait) poll(ctx context.Context, cnxn
*cnxn, retryDescriptor *
}
}
-func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *cnxn, opts
...grpc.CallOption) (*flightsql.PreparedStatement, error) {
+func (s *sqlOrSubstrait) prepare(ctx context.Context, cnxn *connectionImpl,
opts ...grpc.CallOption) (*flightsql.PreparedStatement, error) {
if s.sqlQuery != "" {
return cnxn.prepare(ctx, s.sqlQuery, opts...)
} else if s.substraitPlan != nil {
@@ -156,7 +156,7 @@ type incrementalState struct {
type statement struct {
alloc memory.Allocator
- cnxn *cnxn
+ cnxn *connectionImpl
clientCache gcache.Cache
hdrs metadata.MD
diff --git a/go/adbc/driver/internal/driverbase/connection.go
b/go/adbc/driver/internal/driverbase/connection.go
new file mode 100644
index 000000000..68b0a9bc6
--- /dev/null
+++ b/go/adbc/driver/internal/driverbase/connection.go
@@ -0,0 +1,497 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package driverbase
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal"
+ "github.com/apache/arrow/go/v16/arrow"
+ "github.com/apache/arrow/go/v16/arrow/array"
+ "github.com/apache/arrow/go/v16/arrow/memory"
+ "golang.org/x/exp/slog"
+)
+
+const (
+ ConnectionMessageOptionUnknown = "Unknown connection option"
+ ConnectionMessageOptionUnsupported = "Unsupported connection option"
+ ConnectionMessageCannotCommit = "Cannot commit when autocommit is
enabled"
+ ConnectionMessageCannotRollback = "Cannot rollback when autocommit
is enabled"
+)
+
+// ConnectionImpl is an interface that drivers implement to provide
+// vendor-specific functionality.
+type ConnectionImpl interface {
+ adbc.Connection
+ adbc.GetSetOptions
+ Base() *ConnectionImplBase
+}
+
+// CurrentNamespacer is an interface that drivers may implement to delegate
+// stateful namespacing with DB catalogs and schemas. The appropriate
(Get/Set)Options
+// implementations will be provided using the results of these methods.
+type CurrentNamespacer interface {
+ GetCurrentCatalog() (string, error)
+ GetCurrentDbSchema() (string, error)
+ SetCurrentCatalog(string) error
+ SetCurrentDbSchema(string) error
+}
+
+// DriverInfoPreparer is an interface that drivers may implement to add/update
+// DriverInfo values whenever adbc.Connection.GetInfo() is called.
+type DriverInfoPreparer interface {
+ PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error
+}
+
+// TableTypeLister is an interface that drivers may implement to simplify the
+// implementation of adbc.Connection.GetTableTypes() for backends that do not
natively
+// send these values as arrow records. The conversion of the result to a
RecordReader
+// is handled automatically.
+type TableTypeLister interface {
+ ListTableTypes(ctx context.Context) ([]string, error)
+}
+
+// AutocommitSetter is an interface that drivers may implement to simplify the
+// implementation of autocommit state management. There is no need to implement
+// this for backends that do not support autocommit, as this is already the
default
+// behavior. SetAutocommit should only attempt to update the autocommit state
in the
+// backend. Local driver state is automatically updated if the result of this
call
+// does not produce an error. (Get/Set)Options implementations are provided
automatically
+// as well/
+type AutocommitSetter interface {
+ SetAutocommit(enabled bool) error
+}
+
+// DbObjectsEnumerator is an interface that drivers may implement to simplify
the
+// implementation of adbc.Connection.GetObjects(). By independently
implementing lookup
+// for catalogs, dbSchemas and tables, the driverbase is able to provide the
full
+// GetObjects functionality for arbitrary search patterns and lookup depth.
+type DbObjectsEnumerator interface {
+ GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string,
error)
+ GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, schema *string, metadataRecords []internal.Metadata)
(map[string][]string, error)
+ GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, schema *string, tableName *string, columnName *string, tableType
[]string, metadataRecords []internal.Metadata)
(map[internal.CatalogAndSchema][]internal.TableInfo, error)
+}
+
+// Connection is the interface satisfied by the result of the NewConnection
constructor,
+// given that an input is provided satisfying the ConnectionImpl interface.
+type Connection interface {
+ adbc.Connection
+ adbc.GetSetOptions
+}
+
+// ConnectionImplBase is a struct that provides default implementations of the
+// ConnectionImpl interface. It is meant to be used as a composite struct for a
+// driver's ConnectionImpl implementation.
+type ConnectionImplBase struct {
+ Alloc memory.Allocator
+ ErrorHelper ErrorHelper
+ DriverInfo *DriverInfo
+ Logger *slog.Logger
+
+ Autocommit bool
+ Closed bool
+}
+
+// NewConnectionImplBase instantiates ConnectionImplBase.
+//
+// - database is a DatabaseImplBase containing the common resources from the
parent
+// database, allowing the Arrow allocator, error handler, and logger to be
reused.
+func NewConnectionImplBase(database *DatabaseImplBase) ConnectionImplBase {
+ return ConnectionImplBase{
+ Alloc: database.Alloc,
+ ErrorHelper: database.ErrorHelper,
+ DriverInfo: database.DriverInfo,
+ Logger: database.Logger,
+ Autocommit: true,
+ Closed: false,
+ }
+}
+
+func (base *ConnectionImplBase) Base() *ConnectionImplBase {
+ return base
+}
+
+func (base *ConnectionImplBase) Commit(ctx context.Context) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Commit")
+}
+
+func (base *ConnectionImplBase) Rollback(context.Context) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Rollback")
+}
+
+func (base *ConnectionImplBase) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.RecordReader, error) {
+
+ if len(infoCodes) == 0 {
+ infoCodes = base.DriverInfo.InfoSupportedCodes()
+ }
+
+ bldr := array.NewRecordBuilder(base.Alloc, adbc.GetInfoSchema)
+ defer bldr.Release()
+ bldr.Reserve(len(infoCodes))
+
+ infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
+ infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
+ strInfoBldr :=
infoValueBldr.Child(int(adbc.InfoValueStringType)).(*array.StringBuilder)
+ intInfoBldr :=
infoValueBldr.Child(int(adbc.InfoValueInt64Type)).(*array.Int64Builder)
+
+ for _, code := range infoCodes {
+ switch code {
+ case adbc.InfoDriverName:
+ name, ok := base.DriverInfo.GetInfoDriverName()
+ if !ok {
+ continue
+ }
+
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(adbc.InfoValueStringType)
+ strInfoBldr.Append(name)
+ case adbc.InfoDriverVersion:
+ version, ok := base.DriverInfo.GetInfoDriverVersion()
+ if !ok {
+ continue
+ }
+
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(adbc.InfoValueStringType)
+ strInfoBldr.Append(version)
+ case adbc.InfoDriverArrowVersion:
+ arrowVersion, ok :=
base.DriverInfo.GetInfoDriverArrowVersion()
+ if !ok {
+ continue
+ }
+
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(adbc.InfoValueStringType)
+ strInfoBldr.Append(arrowVersion)
+ case adbc.InfoDriverADBCVersion:
+ adbcVersion, ok :=
base.DriverInfo.GetInfoDriverADBCVersion()
+ if !ok {
+ continue
+ }
+
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(adbc.InfoValueInt64Type)
+ intInfoBldr.Append(adbcVersion)
+ case adbc.InfoVendorName:
+ name, ok := base.DriverInfo.GetInfoVendorName()
+ if !ok {
+ continue
+ }
+
+ infoNameBldr.Append(uint32(code))
+ infoValueBldr.Append(adbc.InfoValueStringType)
+ strInfoBldr.Append(name)
+ default:
+ infoNameBldr.Append(uint32(code))
+ value, ok := base.DriverInfo.GetInfoForInfoCode(code)
+ if !ok {
+ infoValueBldr.AppendNull()
+ continue
+ }
+
+ // TODO: Handle other custom info types
+ infoValueBldr.Append(adbc.InfoValueStringType)
+ strInfoBldr.Append(fmt.Sprint(value))
+ }
+ }
+
+ final := bldr.NewRecord()
+ defer final.Release()
+ return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final})
+}
+
+func (base *ConnectionImplBase) Close() error {
+ return nil
+}
+
+func (base *ConnectionImplBase) GetObjects(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string) (array.RecordReader, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented,
"GetObjects")
+}
+
+func (base *ConnectionImplBase) GetTableSchema(ctx context.Context, catalog
*string, dbSchema *string, tableName string) (*arrow.Schema, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented,
"GetTableSchema")
+}
+
+func (base *ConnectionImplBase) GetTableTypes(context.Context)
(array.RecordReader, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented,
"GetTableTypes")
+}
+
+func (base *ConnectionImplBase) NewStatement() (adbc.Statement, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented,
"NewStatement")
+}
+
+func (base *ConnectionImplBase) ReadPartition(ctx context.Context,
serializedPartition []byte) (array.RecordReader, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented,
"ReadPartition")
+}
+
+func (base *ConnectionImplBase) GetOption(key string) (string, error) {
+ return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) GetOptionBytes(key string) ([]byte, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) GetOptionDouble(key string) (float64, error) {
+ return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) GetOptionInt(key string) (int64, error) {
+ return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) SetOption(key string, val string) error {
+ switch key {
+ case adbc.OptionKeyAutoCommit:
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s
'%s'", ConnectionMessageOptionUnsupported, key)
+ }
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) SetOptionBytes(key string, val []byte) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) SetOptionDouble(key string, val float64) error
{
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+func (base *ConnectionImplBase) SetOptionInt(key string, val int64) error {
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
ConnectionMessageOptionUnknown, key)
+}
+
+type connection struct {
+ ConnectionImpl
+
+ dbObjectsEnumerator DbObjectsEnumerator
+ currentNamespacer CurrentNamespacer
+ driverInfoPreparer DriverInfoPreparer
+ tableTypeLister TableTypeLister
+ autocommitSetter AutocommitSetter
+}
+
+type ConnectionBuilder struct {
+ connection *connection
+}
+
+func NewConnectionBuilder(impl ConnectionImpl) *ConnectionBuilder {
+ return &ConnectionBuilder{connection: &connection{ConnectionImpl: impl}}
+}
+
+func (b *ConnectionBuilder) WithDbObjectsEnumerator(helper
DbObjectsEnumerator) *ConnectionBuilder {
+ if b == nil {
+ panic("nil ConnectionBuilder: cannot reuse after calling
Connection()")
+ }
+ b.connection.dbObjectsEnumerator = helper
+ return b
+}
+
+func (b *ConnectionBuilder) WithCurrentNamespacer(helper CurrentNamespacer)
*ConnectionBuilder {
+ if b == nil {
+ panic("nil ConnectionBuilder: cannot reuse after calling
Connection()")
+ }
+ b.connection.currentNamespacer = helper
+ return b
+}
+
+func (b *ConnectionBuilder) WithDriverInfoPreparer(helper DriverInfoPreparer)
*ConnectionBuilder {
+ if b == nil {
+ panic("nil ConnectionBuilder: cannot reuse after calling
Connection()")
+ }
+ b.connection.driverInfoPreparer = helper
+ return b
+}
+
+func (b *ConnectionBuilder) WithAutocommitSetter(helper AutocommitSetter)
*ConnectionBuilder {
+ if b == nil {
+ panic("nil ConnectionBuilder: cannot reuse after calling
Connection()")
+ }
+ b.connection.autocommitSetter = helper
+ return b
+}
+
+func (b *ConnectionBuilder) WithTableTypeLister(helper TableTypeLister)
*ConnectionBuilder {
+ if b == nil {
+ panic("nil ConnectionBuilder: cannot reuse after calling
Connection()")
+ }
+ b.connection.tableTypeLister = helper
+ return b
+}
+
+func (b *ConnectionBuilder) Connection() Connection {
+ conn := b.connection
+ b.connection = nil
+ return conn
+}
+
+// GetObjects implements Connection.
+func (cnxn *connection) GetObjects(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string) (array.RecordReader, error) {
+ helper := cnxn.dbObjectsEnumerator
+
+ // If the dbObjectsEnumerator has not been set, then the driver
implementor has elected to provide their own GetObjects implementation
+ if helper == nil {
+ return cnxn.ConnectionImpl.GetObjects(ctx, depth, catalog,
dbSchema, tableName, columnName, tableType)
+ }
+
+ // To avoid an N+1 query problem, we assume result sets here will fit
in memory and build up a single response.
+ g := internal.GetObjects{Ctx: ctx, Depth: depth, Catalog: catalog,
DbSchema: dbSchema, TableName: tableName, ColumnName: columnName, TableType:
tableType}
+ if err := g.Init(cnxn.Base().Alloc, helper.GetObjectsDbSchemas,
helper.GetObjectsTables); err != nil {
+ return nil, err
+ }
+ defer g.Release()
+
+ catalogs, err := helper.GetObjectsCatalogs(ctx, catalog)
+ if err != nil {
+ return nil, err
+ }
+
+ foundCatalog := false
+ for _, catalog := range catalogs {
+ g.AppendCatalog(catalog)
+ foundCatalog = true
+ }
+
+ // Implementations like Dremio report no catalogs, but still have
schemas
+ if !foundCatalog && depth != adbc.ObjectDepthCatalogs {
+ g.AppendCatalog("")
+ }
+ return g.Finish()
+}
+
+func (cnxn *connection) GetOption(key string) (string, error) {
+ switch key {
+ case adbc.OptionKeyAutoCommit:
+ if cnxn.Base().Autocommit {
+ return adbc.OptionValueEnabled, nil
+ } else {
+ return adbc.OptionValueDisabled, nil
+ }
+ case adbc.OptionKeyCurrentCatalog:
+ if cnxn.currentNamespacer != nil {
+ val, err := cnxn.currentNamespacer.GetCurrentCatalog()
+ if err != nil {
+ return "",
cnxn.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "failed to get current
catalog: %s", err)
+ }
+ return val, nil
+ }
+ case adbc.OptionKeyCurrentDbSchema:
+ if cnxn.currentNamespacer != nil {
+ val, err := cnxn.currentNamespacer.GetCurrentDbSchema()
+ if err != nil {
+ return "",
cnxn.Base().ErrorHelper.Errorf(adbc.StatusNotFound, "failed to get current db
schema: %s", err)
+ }
+ return val, nil
+ }
+ }
+ return cnxn.ConnectionImpl.GetOption(key)
+}
+
+func (cnxn *connection) SetOption(key string, val string) error {
+ switch key {
+ case adbc.OptionKeyAutoCommit:
+ if cnxn.autocommitSetter != nil {
+
+ var autocommit bool
+ switch val {
+ case adbc.OptionValueEnabled:
+ autocommit = true
+ case adbc.OptionValueDisabled:
+ autocommit = false
+ default:
+ return
cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidArgument, "cannot set value %s
for key %s", val, key)
+ }
+
+ err := cnxn.autocommitSetter.SetAutocommit(autocommit)
+ if err == nil {
+ // Only update the driver state if the action
was successful
+ cnxn.Base().Autocommit = autocommit
+ }
+
+ return err
+ }
+ case adbc.OptionKeyCurrentCatalog:
+ if cnxn.currentNamespacer != nil {
+ return cnxn.currentNamespacer.SetCurrentCatalog(val)
+ }
+ case adbc.OptionKeyCurrentDbSchema:
+ if cnxn.currentNamespacer != nil {
+ return cnxn.currentNamespacer.SetCurrentDbSchema(val)
+ }
+ }
+ return cnxn.ConnectionImpl.SetOption(key, val)
+}
+
+func (cnxn *connection) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.RecordReader, error) {
+ if cnxn.driverInfoPreparer != nil {
+ if err := cnxn.driverInfoPreparer.PrepareDriverInfo(ctx,
infoCodes); err != nil {
+ return nil, err
+ }
+ }
+
+ return cnxn.Base().GetInfo(ctx, infoCodes)
+}
+
+func (cnxn *connection) GetTableTypes(ctx context.Context)
(array.RecordReader, error) {
+ if cnxn.tableTypeLister == nil {
+ return cnxn.ConnectionImpl.GetTableTypes(ctx)
+ }
+
+ tableTypes, err := cnxn.tableTypeLister.ListTableTypes(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ bldr := array.NewRecordBuilder(cnxn.Base().Alloc, adbc.TableTypesSchema)
+ defer bldr.Release()
+
+ bldr.Field(0).(*array.StringBuilder).AppendValues(tableTypes, nil)
+ final := bldr.NewRecord()
+ defer final.Release()
+ return array.NewRecordReader(adbc.TableTypesSchema,
[]arrow.Record{final})
+}
+
+func (cnxn *connection) Commit(ctx context.Context) error {
+ if cnxn.Base().Autocommit {
+ return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState,
ConnectionMessageCannotCommit)
+ }
+ return cnxn.ConnectionImpl.Commit(ctx)
+}
+
+func (cnxn *connection) Rollback(ctx context.Context) error {
+ if cnxn.Base().Autocommit {
+ return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState,
ConnectionMessageCannotRollback)
+ }
+ return cnxn.ConnectionImpl.Rollback(ctx)
+}
+
+func (cnxn *connection) Close() error {
+ if cnxn.Base().Closed {
+ return cnxn.Base().ErrorHelper.Errorf(adbc.StatusInvalidState,
"Trying to close already closed connection")
+ }
+
+ err := cnxn.ConnectionImpl.Close()
+ if err == nil {
+ cnxn.Base().Closed = true
+ }
+
+ return err
+}
+
+var _ ConnectionImpl = (*ConnectionImplBase)(nil)
diff --git a/go/adbc/driver/driverbase/database.go
b/go/adbc/driver/internal/driverbase/database.go
similarity index 52%
rename from go/adbc/driver/driverbase/database.go
rename to go/adbc/driver/internal/driverbase/database.go
index b08b77fca..9ab00967a 100644
--- a/go/adbc/driver/driverbase/database.go
+++ b/go/adbc/driver/internal/driverbase/database.go
@@ -25,14 +25,24 @@ import (
"golang.org/x/exp/slog"
)
// DatabaseMessageOptionUnknown is the message prefix used by the
// DatabaseImplBase option stubs when a key is not recognized. The quoted key
// is appended after it.
const DatabaseMessageOptionUnknown = "Unknown database option"
+
// DatabaseImpl is an interface that drivers implement to provide
// vendor-specific functionality.
type DatabaseImpl interface {
+ adbc.Database
adbc.GetSetOptions
Base() *DatabaseImplBase
- Open(context.Context) (adbc.Connection, error)
- Close() error
- SetOptions(map[string]string) error
+}
+
+// Database is the interface satisfied by the result of the NewDatabase
constructor,
+// given an input is provided satisfying the DatabaseImpl interface.
+type Database interface {
+ adbc.Database
+ adbc.GetSetOptions
+ adbc.DatabaseLogging
}
// DatabaseImplBase is a struct that provides default implementations of the
@@ -41,14 +51,16 @@ type DatabaseImpl interface {
type DatabaseImplBase struct {
+ // Alloc is the Arrow allocator; NewDatabaseImplBase copies it from the parent driver.
Alloc memory.Allocator
+ // ErrorHelper formats errors prefixed with the driver name; copied from the parent driver.
ErrorHelper ErrorHelper
+ // DriverInfo is the parent driver's *DriverInfo (the pointer is shared, not copied).
+ DriverInfo *DriverInfo
+ // Logger starts as nilLogger() and is replaced by the database wrapper's SetLogger.
Logger *slog.Logger
}
-// NewDatabaseImplBase instantiates DatabaseImplBase. name is the driver's
-// name and is used to construct error messages. alloc is an Arrow allocator
-// to use.
+// NewDatabaseImplBase instantiates DatabaseImplBase.
+//
+// - driver is a DriverImplBase containing the common resources from the
parent
+// driver, allowing the Arrow allocator and error handler to be reused.
func NewDatabaseImplBase(driver *DriverImplBase) DatabaseImplBase {
- return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper:
driver.ErrorHelper, Logger: nilLogger()}
+ return DatabaseImplBase{Alloc: driver.Alloc, ErrorHelper:
driver.ErrorHelper, DriverInfo: driver.DriverInfo, Logger: nilLogger()}
}
+// Base returns the shared DatabaseImplBase state. The database wrapper uses it
+// (e.g. in SetLogger) to reach fields such as Logger.
+// NOTE(review): presumably this simply returns the receiver, as
+// DriverImplBase.Base does — confirm.
func (base *DatabaseImplBase) Base() *DatabaseImplBase {
@@ -56,97 +68,72 @@ func (base *DatabaseImplBase) Base() *DatabaseImplBase {
}
func (base *DatabaseImplBase) GetOption(key string) (string, error) {
- return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+ return "", base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) GetOptionBytes(key string) ([]byte, error) {
- return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) GetOptionDouble(key string) (float64, error) {
- return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+ return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) GetOptionInt(key string) (int64, error) {
- return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "Unknown
database option '%s'", key)
+ return 0, base.ErrorHelper.Errorf(adbc.StatusNotFound, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) SetOption(key string, val string) error {
- return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) SetOptionBytes(key string, val []byte) error {
- return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) SetOptionDouble(key string, val float64) error {
- return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
func (base *DatabaseImplBase) SetOptionInt(key string, val int64) error {
- return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Unknown
database option '%s'", key)
-}
-
-// database is the implementation of adbc.Database.
-type database struct {
- impl DatabaseImpl
-}
-
-// NewDatabase wraps a DatabaseImpl to create an adbc.Database.
-func NewDatabase(impl DatabaseImpl) adbc.Database {
- return &database{
- impl: impl,
- }
-}
-
-func (db *database) GetOption(key string) (string, error) {
- return db.impl.GetOption(key)
-}
-
-func (db *database) GetOptionBytes(key string) ([]byte, error) {
- return db.impl.GetOptionBytes(key)
-}
-
-func (db *database) GetOptionDouble(key string) (float64, error) {
- return db.impl.GetOptionDouble(key)
-}
-
-func (db *database) GetOptionInt(key string) (int64, error) {
- return db.impl.GetOptionInt(key)
-}
-
-func (db *database) SetOption(key string, val string) error {
- return db.impl.SetOption(key, val)
+ return base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "%s '%s'",
DatabaseMessageOptionUnknown, key)
}
-func (db *database) SetOptionBytes(key string, val []byte) error {
- return db.impl.SetOptionBytes(key, val)
+func (base *DatabaseImplBase) Close() error {
+ return nil
}
-func (db *database) SetOptionDouble(key string, val float64) error {
- return db.impl.SetOptionDouble(key, val)
+func (base *DatabaseImplBase) Open(ctx context.Context) (adbc.Connection,
error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented, "Open")
}
-func (db *database) SetOptionInt(key string, val int64) error {
- return db.impl.SetOptionInt(key, val)
+func (base *DatabaseImplBase) SetOptions(options map[string]string) error {
+ for key, val := range options {
+ if err := base.SetOption(key, val); err != nil {
+ return err
+ }
+ }
+ return nil
}
-func (db *database) Open(ctx context.Context) (adbc.Connection, error) {
- return db.impl.Open(ctx)
+// database is the implementation of adbc.Database.
+type database struct {
+ DatabaseImpl
}
-func (db *database) Close() error {
- return db.impl.Close()
+// NewDatabase wraps a DatabaseImpl to create an adbc.Database.
+func NewDatabase(impl DatabaseImpl) Database {
+ return &database{
+ DatabaseImpl: impl,
+ }
}
func (db *database) SetLogger(logger *slog.Logger) {
if logger != nil {
- db.impl.Base().Logger = logger
+ db.Base().Logger = logger
} else {
- db.impl.Base().Logger = nilLogger()
+ db.Base().Logger = nilLogger()
}
}
-func (db *database) SetOptions(opts map[string]string) error {
- return db.impl.SetOptions(opts)
-}
+var _ DatabaseImpl = (*DatabaseImplBase)(nil)
diff --git a/go/adbc/driver/internal/driverbase/driver.go
b/go/adbc/driver/internal/driverbase/driver.go
new file mode 100644
index 000000000..bd3e11c08
--- /dev/null
+++ b/go/adbc/driver/internal/driverbase/driver.go
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package driverbase provides a framework for implementing ADBC drivers in
+// Go. It aims to reduce the boilerplate needed for common functionality and
+// for managing state transitions.
+package driverbase
+
+import (
+ "runtime/debug"
+ "strings"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow/go/v16/arrow/memory"
+)
+
// Versions read from the binary's embedded build info.
//
// They stay empty when build info is unavailable or the module is not listed
// as a dependency (for example when it is the main module). In that case
// NewDriverImplBase keeps the DriverInfo defaults.
var (
	infoDriverVersion      string
	infoDriverArrowVersion string
)

// init scans the dependency list in the build info for the ADBC Go module and
// any Arrow Go module, and records their versions. If several Arrow Go modules
// match the prefix, the last one listed wins.
func init() {
	info, ok := debug.ReadBuildInfo()
	if !ok {
		return
	}
	for _, dep := range info.Deps {
		if dep.Path == "github.com/apache/arrow-adbc/go/adbc" {
			infoDriverVersion = dep.Version
		} else if strings.HasPrefix(dep.Path, "github.com/apache/arrow/go/") {
			infoDriverArrowVersion = dep.Version
		}
	}
}
+
+// DriverImpl is an interface that drivers implement to provide
+// vendor-specific functionality.
+type DriverImpl interface {
+ adbc.Driver
+ Base() *DriverImplBase
+}
+
+// Driver is the interface satisfied by the result of the NewDriver
constructor,
+// given an input is provided satisfying the DriverImpl interface.
+type Driver interface {
+ adbc.Driver
+}
+
+// DriverImplBase is a struct that provides default implementations of the
+// DriverImpl interface. It is meant to be used as a composite struct for a
+// driver's DriverImpl implementation.
+type DriverImplBase struct {
+ Alloc memory.Allocator
+ ErrorHelper ErrorHelper
+ DriverInfo *DriverInfo
+}
+
+func (base *DriverImplBase) NewDatabase(opts map[string]string)
(adbc.Database, error) {
+ return nil, base.ErrorHelper.Errorf(adbc.StatusNotImplemented,
"NewDatabase")
+}
+
+// NewDriverImplBase instantiates DriverImplBase.
+//
+// - info contains build and vendor info, as well as the name to construct
error messages.
+// - alloc is an Arrow allocator to use.
+func NewDriverImplBase(info *DriverInfo, alloc memory.Allocator)
DriverImplBase {
+ if alloc == nil {
+ alloc = memory.DefaultAllocator
+ }
+
+ if infoDriverVersion != "" {
+ if err := info.RegisterInfoCode(adbc.InfoDriverVersion,
infoDriverVersion); err != nil {
+ panic(err)
+ }
+ }
+
+ if infoDriverArrowVersion != "" {
+ if err := info.RegisterInfoCode(adbc.InfoDriverArrowVersion,
infoDriverArrowVersion); err != nil {
+ panic(err)
+ }
+ }
+
+ return DriverImplBase{
+ Alloc: alloc,
+ ErrorHelper: ErrorHelper{DriverName: info.GetName()},
+ DriverInfo: info,
+ }
+}
+
+func (base *DriverImplBase) Base() *DriverImplBase {
+ return base
+}
+
+type driver struct {
+ DriverImpl
+}
+
+// NewDriver wraps a DriverImpl to create a Driver.
+func NewDriver(impl DriverImpl) Driver {
+ return &driver{DriverImpl: impl}
+}
+
+var _ DriverImpl = (*DriverImplBase)(nil)
diff --git a/go/adbc/driver/internal/driverbase/driver_info.go
b/go/adbc/driver/internal/driverbase/driver_info.go
new file mode 100644
index 000000000..e68aa16c2
--- /dev/null
+++ b/go/adbc/driver/internal/driverbase/driver_info.go
@@ -0,0 +1,176 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package driverbase
+
import (
	"fmt"
	"reflect"
	"sort"

	"github.com/apache/arrow-adbc/go/adbc"
)
+
+const (
+ UnknownVersion = "(unknown or development build)"
+ DefaultInfoDriverADBCVersion = adbc.AdbcVersion1_1_0
+)
+
+func DefaultDriverInfo(name string) *DriverInfo {
+ defaultInfoVendorName := name
+ defaultInfoDriverName := fmt.Sprintf("ADBC %s Driver - Go", name)
+
+ return &DriverInfo{
+ name: name,
+ info: map[adbc.InfoCode]any{
+ adbc.InfoVendorName: defaultInfoVendorName,
+ adbc.InfoDriverName: defaultInfoDriverName,
+ adbc.InfoDriverVersion: UnknownVersion,
+ adbc.InfoDriverArrowVersion: UnknownVersion,
+ adbc.InfoVendorVersion: UnknownVersion,
+ adbc.InfoVendorArrowVersion: UnknownVersion,
+ adbc.InfoDriverADBCVersion:
DefaultInfoDriverADBCVersion,
+ },
+ }
+}
+
+type DriverInfo struct {
+ name string
+ info map[adbc.InfoCode]any
+}
+
+func (di *DriverInfo) GetName() string { return di.name }
+
+func (di *DriverInfo) InfoSupportedCodes() []adbc.InfoCode {
+ // The keys of the info map are used to determine which info codes are
supported.
+ // This means that any info codes the driver knows about should be set
to some default
+ // at init, even if we don't know the value yet.
+ codes := make([]adbc.InfoCode, 0, len(di.info))
+ for code := range di.info {
+ codes = append(codes, code)
+ }
+
+ // Sorting info codes helps present them to the client in a consistent
way.
+ // It also helps add some determinism to internal tests.
+ // The ordering is in no way part of the API contract and should not be
relied upon.
+ sort.SliceStable(codes, func(i, j int) bool {
+ return codes[i] < codes[j]
+ })
+ return codes
+}
+
+func (di *DriverInfo) RegisterInfoCode(code adbc.InfoCode, value any) error {
+ switch code {
+ case adbc.InfoVendorName:
+ if err := ensureType[string](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ case adbc.InfoVendorVersion:
+ if err := ensureType[string](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ case adbc.InfoVendorArrowVersion:
+ if err := ensureType[string](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ case adbc.InfoDriverName:
+ if err := ensureType[string](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ case adbc.InfoDriverVersion:
+ if err := ensureType[string](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ case adbc.InfoDriverArrowVersion:
+ if err := ensureType[string](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ case adbc.InfoDriverADBCVersion:
+ if err := ensureType[int64](value); err != nil {
+ return fmt.Errorf("info_code %d: %w", code, err)
+ }
+ }
+
+ di.info[code] = value
+ return nil
+}
+
+func (di *DriverInfo) GetInfoForInfoCode(code adbc.InfoCode) (any, bool) {
+ val, ok := di.info[code]
+ return val, ok
+}
+
+func (di *DriverInfo) GetInfoVendorName() (string, bool) {
+ return di.getStringInfoCode(adbc.InfoVendorName)
+}
+
+func (di *DriverInfo) GetInfoVendorVersion() (string, bool) {
+ return di.getStringInfoCode(adbc.InfoVendorVersion)
+}
+
+func (di *DriverInfo) GetInfoVendorArrowVersion() (string, bool) {
+ return di.getStringInfoCode(adbc.InfoVendorArrowVersion)
+}
+
+func (di *DriverInfo) GetInfoDriverName() (string, bool) {
+ return di.getStringInfoCode(adbc.InfoDriverName)
+}
+
+func (di *DriverInfo) GetInfoDriverVersion() (string, bool) {
+ return di.getStringInfoCode(adbc.InfoDriverVersion)
+}
+
+func (di *DriverInfo) GetInfoDriverArrowVersion() (string, bool) {
+ return di.getStringInfoCode(adbc.InfoDriverArrowVersion)
+}
+
+func (di *DriverInfo) GetInfoDriverADBCVersion() (int64, bool) {
+ return di.getInt64InfoCode(adbc.InfoDriverADBCVersion)
+}
+
+func (di *DriverInfo) getStringInfoCode(code adbc.InfoCode) (string, bool) {
+ val, ok := di.GetInfoForInfoCode(code)
+ if !ok {
+ return "", false
+ }
+
+ if err := ensureType[string](val); err != nil {
+ panic(err)
+ }
+
+ return val.(string), true
+}
+
+func (di *DriverInfo) getInt64InfoCode(code adbc.InfoCode) (int64, bool) {
+ val, ok := di.GetInfoForInfoCode(code)
+ if !ok {
+ return int64(0), false
+ }
+
+ if err := ensureType[int64](val); err != nil {
+ panic(err)
+ }
+
+ return val.(int64), true
+}
+
// ensureType returns nil when value holds a T. Otherwise it returns an error
// naming the offending value, the expected type, and the actual type.
//
// The expected type name comes from reflect, not from formatting a zero T with
// %T. When T is an interface type its zero value is a nil interface, which %T
// prints as "<nil>" and which would hide the expected type.
func ensureType[T any](value any) error {
	if _, ok := value.(T); ok {
		return nil
	}
	expected := reflect.TypeOf((*T)(nil)).Elem()
	return fmt.Errorf("expected info_value %v to be of type %s but found %T", value, expected, value)
}
diff --git a/go/adbc/driver/internal/driverbase/driver_info_test.go
b/go/adbc/driver/internal/driverbase/driver_info_test.go
new file mode 100644
index 000000000..2bad25d05
--- /dev/null
+++ b/go/adbc/driver/internal/driverbase/driver_info_test.go
@@ -0,0 +1,88 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package driverbase_test
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDriverInfo(t *testing.T) {
+ driverInfo := driverbase.DefaultDriverInfo("test")
+
+ // The provided name is used for ErrorHelper, certain info code values,
etc
+ require.Equal(t, "test", driverInfo.GetName())
+
+ // These are the info codes that are set for every driver
+ expectedDefaultInfoCodes := []adbc.InfoCode{
+ adbc.InfoVendorName,
+ adbc.InfoVendorVersion,
+ adbc.InfoVendorArrowVersion,
+ adbc.InfoDriverName,
+ adbc.InfoDriverVersion,
+ adbc.InfoDriverArrowVersion,
+ adbc.InfoDriverADBCVersion,
+ }
+ require.ElementsMatch(t, expectedDefaultInfoCodes,
driverInfo.InfoSupportedCodes())
+
+ // We get some formatted default values out of the box
+ vendorName, ok := driverInfo.GetInfoVendorName()
+ require.True(t, ok)
+ require.Equal(t, "test", vendorName)
+
+ driverName, ok := driverInfo.GetInfoDriverName()
+ require.True(t, ok)
+ require.Equal(t, "ADBC test Driver - Go", driverName)
+
+ // We can register a string value to an info code that expects a string
+ require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoDriverVersion,
"string_value"))
+
+ // We cannot register a non-string value to that same info code
+ err := driverInfo.RegisterInfoCode(adbc.InfoDriverVersion, 123)
+ require.Error(t, err)
+ require.Equal(t, "info_code 101: expected info_value 123 to be of type
string but found int", err.Error())
+
+ // We can also set vendor-specific info codes but they won't get type
checked
+ require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoCode(10_001),
"string_value"))
+ require.NoError(t, driverInfo.RegisterInfoCode(adbc.InfoCode(10_001),
123))
+
+ // Retrieving known info codes is type-safe
+ driverVersion, ok := driverInfo.GetInfoDriverName()
+ require.True(t, ok)
+ require.NotEmpty(t, strings.Clone(driverVersion)) // do string stuff
+
+ adbcVersion, ok := driverInfo.GetInfoDriverADBCVersion()
+ require.True(t, ok)
+ require.NotEmpty(t, adbcVersion+int64(123)) // do int64 stuff
+
+ // We can also retrieve arbitrary info codes, but the result's type
must be asserted
+ arrowVersion, ok :=
driverInfo.GetInfoForInfoCode(adbc.InfoDriverArrowVersion)
+ require.True(t, ok)
+ _, ok = arrowVersion.(string)
+ require.True(t, ok)
+
+ // We can check if info codes have been set or not
+ _, ok = driverInfo.GetInfoForInfoCode(adbc.InfoCode(10_001))
+ require.True(t, ok)
+ _, ok = driverInfo.GetInfoForInfoCode(adbc.InfoCode(10_002))
+ require.False(t, ok)
+}
diff --git a/go/adbc/driver/internal/driverbase/driver_test.go
b/go/adbc/driver/internal/driverbase/driver_test.go
new file mode 100644
index 000000000..f43a049bb
--- /dev/null
+++ b/go/adbc/driver/internal/driverbase/driver_test.go
@@ -0,0 +1,595 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package driverbase_test
+
+import (
+ "context"
+ "fmt"
+ "testing"
+
+ "golang.org/x/exp/slog"
+
+ "github.com/apache/arrow-adbc/go/adbc"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
+ "github.com/apache/arrow/go/v16/arrow"
+ "github.com/apache/arrow/go/v16/arrow/array"
+ "github.com/apache/arrow/go/v16/arrow/memory"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+)
+
// Option keys used against the mock database in these tests. The tests expect
// SetOptions to accept the recognized key and reject the unrecognized one.
const OptionKeyRecognized, OptionKeyUnrecognized = "recognized", "unrecognized"
+
+// NewDriver creates a new adbc.Driver for testing. In addition to a
memory.Allocator, it takes
+// a slog.Handler to use for all structured logging as well as a useHelpers
flag to determine whether
+// the test should register helper methods or use the default driverbase
implementation.
+func NewDriver(alloc memory.Allocator, handler slog.Handler, useHelpers bool)
adbc.Driver {
+ info := driverbase.DefaultDriverInfo("MockDriver")
+ _ = info.RegisterInfoCode(adbc.InfoCode(10_001), "my custom info")
+ return driverbase.NewDriver(&driverImpl{DriverImplBase:
driverbase.NewDriverImplBase(info, alloc), handler: handler, useHelpers:
useHelpers})
+}
+
+// TestDefaultDriver builds the mock driver with useHelpers=false and checks the
+// behavior driverbase provides when no helper implementations are registered:
+// the autocommit guards on Commit/Rollback, the default GetInfo table,
+// NotImplemented/NotFound errors for unsupported calls and options, closing a
+// connection twice, and the structured log output.
+func TestDefaultDriver(t *testing.T) {
+ var handler MockedHandler
+ handler.On("Handle", mock.Anything, mock.Anything).Return(nil)
+
+ ctx := context.TODO()
+ // The checked allocator fails the test if Arrow memory is still allocated when the test returns.
+ alloc := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer alloc.AssertSize(t, 0)
+
+ drv := NewDriver(alloc, &handler, false) // Do not use helper
implementations; only default behavior
+
+ db, err := drv.NewDatabase(nil)
+ require.NoError(t, err)
+ defer db.Close()
+
+ require.NoError(t, db.SetOptions(map[string]string{OptionKeyRecognized:
"should-pass"}))
+
+ err = db.SetOptions(map[string]string{OptionKeyUnrecognized:
"should-fail"})
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] Unknown database option
'unrecognized'", err.Error())
+
+ cnxn, err := db.Open(ctx)
+ require.NoError(t, err)
+ defer func() {
+ // Cannot close more than once
+ require.NoError(t, cnxn.Close())
+ require.Error(t, cnxn.Close())
+ }()
+
+ err = cnxn.Commit(ctx)
+ require.Error(t, err)
+ require.Equal(t, "Invalid State: [MockDriver] Cannot commit when
autocommit is enabled", err.Error())
+
+ err = cnxn.Rollback(ctx)
+ require.Error(t, err)
+ require.Equal(t, "Invalid State: [MockDriver] Cannot rollback when
autocommit is enabled", err.Error())
+
+ info, err := cnxn.GetInfo(ctx, nil)
+ require.NoError(t, err)
+ getInfoTable := tableFromRecordReader(info)
+ defer getInfoTable.Release()
+
+ // This is what the driverbase provided GetInfo result should look like
out of the box,
+ // with one custom setting registered at initialization
+ expectedGetInfoTable, err := array.TableFromJSON(alloc,
adbc.GetInfoSchema, []string{`[
+ {
+ "info_name": 0,
+ "info_value": [0, "MockDriver"]
+ },
+ {
+ "info_name": 1,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 2,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 100,
+ "info_value": [0, "ADBC MockDriver Driver - Go"]
+ },
+ {
+ "info_name": 101,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 102,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 103,
+ "info_value": [2, 1001000]
+ },
+ {
+ "info_name": 10001,
+ "info_value": [0, "my custom info"]
+ }
+ ]`})
+ require.NoError(t, err)
+ defer expectedGetInfoTable.Release()
+
+ require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable),
"expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable)
+
+ _, err = cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil, nil, nil,
nil)
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] GetObjects",
err.Error())
+
+ _, err = cnxn.GetTableTypes(ctx)
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] GetTableTypes",
err.Error())
+
+ autocommit, err :=
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyAutoCommit)
+ require.NoError(t, err)
+ require.Equal(t, adbc.OptionValueEnabled, autocommit)
+
+ err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyAutoCommit,
"false")
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] Unsupported connection
option 'adbc.connection.autocommit'", err.Error())
+
+ _, err =
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+ require.Error(t, err)
+ require.Equal(t, "Not Found: [MockDriver] Unknown connection option
'adbc.connection.catalog'", err.Error())
+
+ err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog,
"test_catalog")
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] Unknown connection
option 'adbc.connection.catalog'", err.Error())
+
+ // We passed a mock handler into the driver to use for logs, so we can
check actual messages logged
+ expectedLogMessages := []logMessage{
+ {Message: "Opening a new connection", Level: "INFO", Attrs:
map[string]string{"withHelpers": "false"}},
+ }
+
+ // Collect the slog.Record passed as the second argument of every Handle call on the mock.
+ logMessages := make([]logMessage, 0, len(handler.Calls))
+ for _, call := range handler.Calls {
+ sr, ok := call.Arguments.Get(1).(slog.Record)
+ require.True(t, ok)
+ logMessages = append(logMessages, newLogMessage(sr))
+ }
+
+ // Order is not checked: each expected message only has to appear somewhere in the log.
+ for _, expected := range expectedLogMessages {
+ var found bool
+ for _, message := range logMessages {
+ if messagesEqual(message, expected) {
+ found = true
+ break
+ }
+ }
+ require.Truef(t, found, "expected message was never logged:
%v", expected)
+ }
+
+}
+
+// TestCustomizedDriver builds the mock driver with useHelpers=true and checks
+// that the registered helper implementations drive:
+//   - GetInfo (including a dynamically prepared value);
+//   - GetObjects and GetTableTypes;
+//   - autocommit toggling and the current catalog/db schema options.
+// It also checks that the expected debug logs are emitted.
+func TestCustomizedDriver(t *testing.T) {
+ var handler MockedHandler
+ handler.On("Handle", mock.Anything, mock.Anything).Return(nil)
+
+ ctx := context.TODO()
+ // The checked allocator fails the test if Arrow memory is still allocated when the test returns.
+ alloc := memory.NewCheckedAllocator(memory.DefaultAllocator)
+ defer alloc.AssertSize(t, 0)
+
+ drv := NewDriver(alloc, &handler, true) // Use helper implementations
+
+ db, err := drv.NewDatabase(nil)
+ require.NoError(t, err)
+ defer db.Close()
+
+ require.NoError(t, db.SetOptions(map[string]string{OptionKeyRecognized:
"should-pass"}))
+
+ err = db.SetOptions(map[string]string{OptionKeyUnrecognized:
"should-fail"})
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] Unknown database option
'unrecognized'", err.Error())
+
+ cnxn, err := db.Open(ctx)
+ require.NoError(t, err)
+ defer cnxn.Close()
+
+ // Autocommit starts enabled, so both calls are rejected before reaching the driver.
+ err = cnxn.Commit(ctx)
+ require.Error(t, err)
+ require.Equal(t, "Invalid State: [MockDriver] Cannot commit when
autocommit is enabled", err.Error())
+
+ err = cnxn.Rollback(ctx)
+ require.Error(t, err)
+ require.Equal(t, "Invalid State: [MockDriver] Cannot rollback when
autocommit is enabled", err.Error())
+
+ info, err := cnxn.GetInfo(ctx, nil)
+ require.NoError(t, err)
+ getInfoTable := tableFromRecordReader(info)
+ defer getInfoTable.Release()
+
+ // This is the arrow table representation of GetInfo produced by
merging:
+ // - the default DriverInfo set at initialization
+ // - the DriverInfo set once in the NewDriver constructor
+ // - the DriverInfo set dynamically when GetInfo is called by
implementing DriverInfoPreparer interface
+ expectedGetInfoTable, err := array.TableFromJSON(alloc,
adbc.GetInfoSchema, []string{`[
+ {
+ "info_name": 0,
+ "info_value": [0, "MockDriver"]
+ },
+ {
+ "info_name": 1,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 2,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 100,
+ "info_value": [0, "ADBC MockDriver Driver - Go"]
+ },
+ {
+ "info_name": 101,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 102,
+ "info_value": [0, "(unknown or development build)"]
+ },
+ {
+ "info_name": 103,
+ "info_value": [2, 1001000]
+ },
+ {
+ "info_name": 10001,
+ "info_value": [0, "my custom info"]
+ },
+ {
+ "info_name": 10002,
+ "info_value": [0, "this was fetched dynamically"]
+ }
+ ]`})
+ require.NoError(t, err)
+ defer expectedGetInfoTable.Release()
+
+ require.Truef(t, array.TableEqual(expectedGetInfoTable, getInfoTable),
"expected: %s\ngot: %s", expectedGetInfoTable, getInfoTable)
+
+ dbObjects, err := cnxn.GetObjects(ctx, adbc.ObjectDepthAll, nil, nil,
nil, nil, nil)
+ require.NoError(t, err)
+ dbObjectsTable := tableFromRecordReader(dbObjects)
+ defer dbObjectsTable.Release()
+
+ // This is the arrow table representation of the GetObjects output we get by
+ // implementing the simplified db-objects helper interface for GetObjects
+ // (TableTypeLister only backs GetTableTypes, which is checked further below).
+ expectedDbObjectsTable, err := array.TableFromJSON(alloc,
adbc.GetObjectsSchema, []string{`[
+ {
+ "catalog_name": "default",
+ "catalog_db_schemas": [
+ {
+ "db_schema_name": "public",
+ "db_schema_tables": [
+ {
+ "table_name": "foo",
+ "table_type": "TABLE",
+ "table_columns": [],
+ "table_constraints": []
+ }
+ ]
+ },
+ {
+ "db_schema_name": "test",
+ "db_schema_tables": [
+ {
+ "table_name": "bar",
+ "table_type": "TABLE",
+ "table_columns": [],
+ "table_constraints": []
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "catalog_name": "my_db",
+ "catalog_db_schemas": [
+ {
+ "db_schema_name": "public",
+ "db_schema_tables": [
+ {
+ "table_name": "baz",
+ "table_type": "TABLE",
+ "table_columns": [],
+ "table_constraints": []
+ }
+ ]
+ }
+ ]
+ }
+ ]`})
+ require.NoError(t, err)
+ defer expectedDbObjectsTable.Release()
+
+ require.Truef(t, array.TableEqual(expectedDbObjectsTable,
dbObjectsTable), "expected: %s\ngot: %s", expectedDbObjectsTable,
dbObjectsTable)
+
+ tableTypes, err := cnxn.GetTableTypes(ctx)
+ require.NoError(t, err)
+ tableTypeTable := tableFromRecordReader(tableTypes)
+ defer tableTypeTable.Release()
+
+ // This is the arrow table representation of the GetTableTypes output
we get by implementing
+ // the simplified TableTypeLister interface
+ expectedTableTypesTable, err := array.TableFromJSON(alloc,
adbc.TableTypesSchema, []string{`[
+ { "table_type": "TABLE" },
+ { "table_type": "VIEW" }
+ ]`})
+ require.NoError(t, err)
+ defer expectedTableTypesTable.Release()
+
+ require.Truef(t, array.TableEqual(expectedTableTypesTable,
tableTypeTable), "expected: %s\ngot: %s", expectedTableTypesTable,
tableTypeTable)
+
+ autocommit, err :=
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyAutoCommit)
+ require.NoError(t, err)
+ require.Equal(t, adbc.OptionValueEnabled, autocommit)
+
+ // By implementing AutocommitSetter, we are able to successfully toggle
autocommit
+ err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyAutoCommit,
"false")
+ require.NoError(t, err)
+
+ // We haven't implemented Commit, but we get NotImplemented instead of
InvalidState because
+ // Autocommit has been explicitly disabled
+ err = cnxn.Commit(ctx)
+ require.Error(t, err)
+ require.Equal(t, "Not Implemented: [MockDriver] Commit", err.Error())
+
+ // By implementing CurrentNamespacer, we can now get/set the current
catalog/dbschema
+ // Default current(catalog|dbSchema) is driver-specific, but the stub
implementation falls back
+ // to a 'not found' error instead of 'not implemented'
+ _, err =
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+ require.Error(t, err)
+ require.Equal(t, "Not Found: [MockDriver] failed to get current
catalog: current catalog is not set", err.Error())
+
+ err = cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog,
"test_catalog")
+ require.NoError(t, err)
+
+ currentCatalog, err :=
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+ require.NoError(t, err)
+ require.Equal(t, "test_catalog", currentCatalog)
+
+ _, err =
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
+ require.Error(t, err)
+ require.Equal(t, "Not Found: [MockDriver] failed to get current db
schema: current db schema is not set", err.Error())
+
+ err =
cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema,
"test_schema")
+ require.NoError(t, err)
+
+ currentDbSchema, err :=
cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
+ require.NoError(t, err)
+ require.Equal(t, "test_schema", currentDbSchema)
+
+ // We passed a mock handler into the driver to use for logs, so we can
check actual messages logged
+ expectedLogMessages := []logMessage{
+ {Message: "Opening a new connection", Level: "INFO", Attrs:
map[string]string{"withHelpers": "true"}},
+ {Message: "SetAutocommit", Level: "DEBUG", Attrs:
map[string]string{"enabled": "false"}},
+ {Message: "SetCurrentCatalog", Level: "DEBUG", Attrs:
map[string]string{"val": "test_catalog"}},
+ {Message: "SetCurrentDbSchema", Level: "DEBUG", Attrs:
map[string]string{"val": "test_schema"}},
+ }
+
+ // Collect the slog.Record passed as the second argument of every Handle call on the mock.
+ logMessages := make([]logMessage, 0, len(handler.Calls))
+ for _, call := range handler.Calls {
+ sr, ok := call.Arguments.Get(1).(slog.Record)
+ require.True(t, ok)
+ logMessages = append(logMessages, newLogMessage(sr))
+ }
+
+ // Order is not checked: each expected message only has to appear somewhere in the log.
+ for _, expected := range expectedLogMessages {
+ var found bool
+ for _, message := range logMessages {
+ if messagesEqual(message, expected) {
+ found = true
+ break
+ }
+ }
+ require.Truef(t, found, "expected message was never logged:
%v", expected)
+ }
+}
+
+// driverImpl is the mock driver used by these tests. It routes every logger it
+// creates to handler and tells databases whether to register connection helpers.
+type driverImpl struct {
+	driverbase.DriverImplBase
+
+	handler    slog.Handler
+	useHelpers bool
+}
+
+// NewDatabase wraps a mock databaseImpl with driverbase.NewDatabase and
+// points its logger at the driver's handler. It never returns an error.
+// NOTE(review): opts is not applied here (unlike drivers that call
+// SetOptions on construction) — confirm this is intentional for the mock.
+func (drv *driverImpl) NewDatabase(opts map[string]string) (adbc.Database, error) {
+	impl := &databaseImpl{
+		DatabaseImplBase: driverbase.NewDatabaseImplBase(&drv.DriverImplBase),
+		drv:              drv,
+		useHelpers:       drv.useHelpers,
+	}
+	db := driverbase.NewDatabase(impl)
+	db.SetLogger(slog.New(drv.handler))
+	return db, nil
+}
+
+// databaseImpl is the mock database. useHelpers decides whether Open builds
+// connections with the optional driverbase helpers registered.
+type databaseImpl struct {
+	driverbase.DatabaseImplBase
+	drv *driverImpl
+
+	useHelpers bool
+}
+
+// SetOptions implements adbc.Database by applying each option through
+// SetOption and stopping at the first error. Map iteration order is random,
+// so if several options fail, which error is returned can vary between runs.
+func (d *databaseImpl) SetOptions(options map[string]string) error {
+	for k, v := range options {
+		if err := d.SetOption(k, v); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// SetOption only needs to handle the keys this mock recognizes.
+// Any other key falls through to DatabaseImplBase.SetOption and its default
+// failure message.
+func (d *databaseImpl) SetOption(key, value string) error {
+	switch key {
+	case OptionKeyRecognized:
+		_ = value // pretend to recognize the setting
+		return nil
+	}
+	return d.DatabaseImplBase.SetOption(key, value)
+}
+
+// Open logs the request and returns a new mock connection. When useHelpers
+// is set, the connection is built with every helper the mock implements
+// (autocommit, current namespace, table types, driver info, GetObjects
+// enumeration); otherwise it gets driverbase's defaults only.
+func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
+	d.DatabaseImplBase.Logger.Info("Opening a new connection", "withHelpers", d.useHelpers)
+	cnxn := &connectionImpl{ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), db: d}
+	bldr := driverbase.NewConnectionBuilder(cnxn)
+	if d.useHelpers { // this toggles between the NewDefaultDriver and NewCustomizedDriver scenarios
+		return bldr.
+			WithAutocommitSetter(cnxn).
+			WithCurrentNamespacer(cnxn).
+			WithTableTypeLister(cnxn).
+			WithDriverInfoPreparer(cnxn).
+			WithDbObjectsEnumerator(cnxn).
+			Connection(), nil
+	}
+	return bldr.Connection(), nil
+}
+
+// connectionImpl is the mock connection. It keeps the current catalog and
+// db schema in memory and returns fixed metadata for the other helpers.
+type connectionImpl struct {
+	driverbase.ConnectionImplBase
+	db *databaseImpl
+
+	currentCatalog  string
+	currentDbSchema string
+}
+
+// SetAutocommit implements driverbase.AutocommitSetter. The mock only logs
+// the requested value and always succeeds.
+func (c *connectionImpl) SetAutocommit(enabled bool) error {
+	c.Base().Logger.Debug("SetAutocommit", "enabled", enabled)
+	return nil
+}
+
+// GetCurrentCatalog returns the catalog stored by SetCurrentCatalog, or an
+// error while it is still empty.
+func (c *connectionImpl) GetCurrentCatalog() (string, error) {
+	if catalog := c.currentCatalog; catalog != "" {
+		return catalog, nil
+	}
+	return "", fmt.Errorf("current catalog is not set")
+}
+
+// GetCurrentDbSchema returns the schema stored by SetCurrentDbSchema, or an
+// error while it is still empty.
+func (c *connectionImpl) GetCurrentDbSchema() (string, error) {
+	if schema := c.currentDbSchema; schema != "" {
+		return schema, nil
+	}
+	return "", fmt.Errorf("current db schema is not set")
+}
+
+// SetCurrentCatalog logs the new value and stores it on the connection.
+func (c *connectionImpl) SetCurrentCatalog(val string) error {
+	logger := c.Base().Logger
+	logger.Debug("SetCurrentCatalog", "val", val)
+	c.currentCatalog = val
+	return nil
+}
+
+// SetCurrentDbSchema logs the new value and stores it on the connection.
+func (c *connectionImpl) SetCurrentDbSchema(val string) error {
+	logger := c.Base().Logger
+	logger.Debug("SetCurrentDbSchema", "val", val)
+	c.currentDbSchema = val
+	return nil
+}
+
+// ListTableTypes implements driverbase.TableTypeLister with a fixed list.
+func (c *connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) {
+	tableTypes := []string{"TABLE", "VIEW"}
+	return tableTypes, nil
+}
+
+// PrepareDriverInfo registers an extra, driver-specific info code (10_002)
+// whose value stands in for one fetched dynamically. infoCodes is ignored.
+func (c *connectionImpl) PrepareDriverInfo(ctx context.Context, infoCodes []adbc.InfoCode) error {
+	const dynamicInfoCode = adbc.InfoCode(10_002)
+	return c.ConnectionImplBase.DriverInfo.RegisterInfoCode(dynamicInfoCode, "this was fetched dynamically")
+}
+
+// GetObjectsCatalogs returns a fixed catalog list; catalog is ignored.
+func (c *connectionImpl) GetObjectsCatalogs(ctx context.Context, catalog *string) ([]string, error) {
+	catalogs := []string{"default", "my_db"}
+	return catalogs, nil
+}
+
+// GetObjectsDbSchemas returns a fixed catalog -> schemas mapping; all filter
+// arguments are ignored.
+func (c *connectionImpl) GetObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []internal.Metadata) (map[string][]string, error) {
+	schemasByCatalog := map[string][]string{
+		"default": {"public", "test"},
+		"my_db":   {"public"},
+	}
+	return schemasByCatalog, nil
+}
+
+// GetObjectsTables returns one fixed TABLE for each catalog/schema pair; all
+// filter arguments are ignored.
+func (c *connectionImpl) GetObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (map[internal.CatalogAndSchema][]internal.TableInfo, error) {
+	tables := map[internal.CatalogAndSchema][]internal.TableInfo{
+		{Catalog: "default", Schema: "public"}: {{Name: "foo", TableType: "TABLE"}},
+		{Catalog: "default", Schema: "test"}:   {{Name: "bar", TableType: "TABLE"}},
+		{Catalog: "my_db", Schema: "public"}:   {{Name: "baz", TableType: "TABLE"}},
+	}
+	return tables, nil
+}
+
+// MockedHandler is a mock.Mock that implements the slog.Handler interface.
+// It is injected into loggers so tests can assert on the records they emit.
+type MockedHandler struct {
+	mock.Mock
+}
+
+// Enabled reports every level as enabled, so no record is filtered out.
+func (h *MockedHandler) Enabled(ctx context.Context, level slog.Level) bool {
+	return true
+}
+
+// WithAttrs returns the receiver unchanged; the given attrs are discarded.
+func (h *MockedHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
+	return h
+}
+
+// WithGroup returns the receiver unchanged; the group name is discarded.
+func (h *MockedHandler) WithGroup(name string) slog.Handler {
+	return h
+}
+
+// Handle records the call on the mock and returns the error configured for
+// it. Tests later compare only the message fields, which keeps
+// nondeterministic parts of the record (e.g. the timestamp) out of assertions.
+func (h *MockedHandler) Handle(ctx context.Context, r slog.Record) error {
+	return h.Called(ctx, r).Error(0)
+}
+
+// logMessage is a container for the log attributes compared for equality
+// during tests. It intentionally omits timestamps and other sources of
+// nondeterminism.
+type logMessage struct {
+	Message string
+	Level   string
+	Attrs   map[string]string
+}
+
+// newLogMessage constructs a logMessage from a slog.Record, keeping only the
+// message, the level name and each attribute rendered with Value.String().
+func newLogMessage(r slog.Record) logMessage {
+	message := logMessage{Message: r.Message, Level: r.Level.String(), Attrs: make(map[string]string, r.NumAttrs())}
+	r.Attrs(func(a slog.Attr) bool {
+		message.Attrs[a.Key] = a.Value.String()
+		return true // visit every attribute
+	})
+	return message
+}
+
+// messagesEqual reports whether two logMessages have the same message, level
+// and attribute set. The comparison is symmetric.
+func messagesEqual(expected, actual logMessage) bool {
+	if expected.Message != actual.Message || expected.Level != actual.Level {
+		return false
+	}
+	if len(expected.Attrs) != len(actual.Attrs) {
+		return false
+	}
+	for k, want := range expected.Attrs {
+		// Check presence explicitly: a missing key reads back as "", which
+		// would otherwise match an expected empty-string value.
+		got, ok := actual.Attrs[k]
+		if !ok || got != want {
+			return false
+		}
+	}
+	return true
+}
+
+// tableFromRecordReader drains rdr into an arrow.Table and releases the
+// reader. It panics if the reader reports an error, so a failing stream shows
+// up as a test failure instead of a silently truncated table.
+func tableFromRecordReader(rdr array.RecordReader) arrow.Table {
+	defer rdr.Release()
+
+	recs := make([]arrow.Record, 0)
+	// Drop our references once the table has been built. A single deferred
+	// loop replaces one stacked defer per record.
+	defer func() {
+		for _, rec := range recs {
+			rec.Release()
+		}
+	}()
+	for rdr.Next() {
+		rec := rdr.Record()
+		rec.Retain() // Record() is only guaranteed valid until the next Next()
+		recs = append(recs, rec)
+	}
+	if err := rdr.Err(); err != nil {
+		panic(err)
+	}
+	return array.NewTableFromRecords(rdr.Schema(), recs)
+}
diff --git a/go/adbc/driver/driverbase/error.go
b/go/adbc/driver/internal/driverbase/error.go
similarity index 100%
rename from go/adbc/driver/driverbase/error.go
rename to go/adbc/driver/internal/driverbase/error.go
diff --git a/go/adbc/driver/driverbase/logging.go
b/go/adbc/driver/internal/driverbase/logging.go
similarity index 100%
rename from go/adbc/driver/driverbase/logging.go
rename to go/adbc/driver/internal/driverbase/logging.go
diff --git a/go/adbc/driver/snowflake/connection.go
b/go/adbc/driver/snowflake/connection.go
index 825266507..4b023b050 100644
--- a/go/adbc/driver/snowflake/connection.go
+++ b/go/adbc/driver/snowflake/connection.go
@@ -30,6 +30,7 @@ import (
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-adbc/go/adbc/driver/internal"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
"github.com/snowflakedb/gosnowflake"
@@ -50,7 +51,9 @@ type snowflakeConn interface {
QueryArrowStream(context.Context, string, ...driver.NamedValue)
(gosnowflake.ArrowStreamLoader, error)
}
-type cnxn struct {
+type connectionImpl struct {
+ driverbase.ConnectionImplBase
+
cn snowflakeConn
db *databaseImpl
ctor gosnowflake.Connector
@@ -60,6 +63,59 @@ type cnxn struct {
useHighPrecision bool
}
+// ListTableTypes implements driverbase.TableTypeLister. The table types are
+// reported from a fixed list rather than queried from the server.
+func (*connectionImpl) ListTableTypes(ctx context.Context) ([]string, error) {
+	tableTypes := []string{"BASE TABLE", "TEMPORARY TABLE", "VIEW"}
+	return tableTypes, nil
+}
+
+// GetCurrentCatalog implements driverbase.CurrentNamespacer by querying the
+// session's current database.
+func (c *connectionImpl) GetCurrentCatalog() (string, error) {
+	const query = "SELECT CURRENT_DATABASE()"
+	return c.getStringQuery(query)
+}
+
+// GetCurrentDbSchema implements driverbase.CurrentNamespacer by querying the
+// session's current schema.
+func (c *connectionImpl) GetCurrentDbSchema() (string, error) {
+	const query = "SELECT CURRENT_SCHEMA()"
+	return c.getStringQuery(query)
+}
+
+// SetCurrentCatalog implements driverbase.CurrentNamespacer by running
+// USE DATABASE with the name as a bind parameter. The driver error, if any,
+// is returned unwrapped.
+// NOTE(review): Snowflake documents IDENTIFIER(?) for binding object names —
+// confirm a bare ? is accepted by USE.
+func (c *connectionImpl) SetCurrentCatalog(value string) error {
+	args := []driver.NamedValue{{Value: value}}
+	if _, err := c.cn.ExecContext(context.Background(), "USE DATABASE ?", args); err != nil {
+		return err
+	}
+	return nil
+}
+
+// SetCurrentDbSchema implements driverbase.CurrentNamespacer by running
+// USE SCHEMA with the name as a bind parameter. The driver error, if any,
+// is returned unwrapped.
+func (c *connectionImpl) SetCurrentDbSchema(value string) error {
+	args := []driver.NamedValue{{Value: value}}
+	if _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA ?", args); err != nil {
+		return err
+	}
+	return nil
+}
+
+// SetAutocommit implements driverbase.AutocommitSetter.
+//
+// Enabling autocommit first commits the transaction this connection has open,
+// if any. Disabling it opens one with BEGIN if none is active. Either way, the
+// session AUTOCOMMIT parameter is updated afterwards. COMMIT/BEGIN failures
+// are wrapped as StatusInternal; an ALTER SESSION failure is returned as-is.
+func (c *connectionImpl) SetAutocommit(enabled bool) error {
+	ctx := context.Background()
+
+	switch {
+	case enabled && c.activeTransaction:
+		if _, err := c.cn.ExecContext(ctx, "COMMIT", nil); err != nil {
+			return errToAdbcErr(adbc.StatusInternal, err)
+		}
+		c.activeTransaction = false
+	case !enabled && !c.activeTransaction:
+		if _, err := c.cn.ExecContext(ctx, "BEGIN", nil); err != nil {
+			return errToAdbcErr(adbc.StatusInternal, err)
+		}
+		c.activeTransaction = true
+	}
+
+	alter := "ALTER SESSION SET AUTOCOMMIT = false"
+	if enabled {
+		alter = "ALTER SESSION SET AUTOCOMMIT = true"
+	}
+	_, err := c.cn.ExecContext(ctx, alter, nil)
+	return err
+}
+
// Metadata methods
// Generally these methods return an array.RecordReader that
// can be consumed to retrieve metadata about the database as Arrow
@@ -77,80 +133,6 @@ type cnxn struct {
// characters, or "_" to match exactly one character. (See the
// documentation of DatabaseMetaData in JDBC or "Pattern Value Arguments"
// in the ODBC documentation.) Escaping is not currently supported.
-// GetInfo returns metadata about the database/driver.
-//
-// The result is an Arrow dataset with the following schema:
-//
-// Field Name
| Field Type
-// ----------------------------|-----------------------------
-// info_name
| uint32 not null
-// info_value
| INFO_SCHEMA
-//
-// INFO_SCHEMA is a dense union with members:
-//
-// Field Name (Type Code) | Field Type
-// ----------------------------|-----------------------------
-// string_value (0) | utf8
-// bool_value (1) | bool
-// int64_value (2) | int64
-// int32_bitmask (3) | int32
-// string_list (4) |
list<utf8>
-// int32_to_int32_list_map (5) | map<int32, list<int32>>
-//
-// Each metadatum is identified by an integer code. The recognized
-// codes are defined as constants. Codes [0, 10_000) are reserved
-// for ADBC usage. Drivers/vendors will ignore requests for unrecognized
-// codes (the row will be omitted from the result).
-func (c *cnxn) GetInfo(ctx context.Context, infoCodes []adbc.InfoCode)
(array.RecordReader, error) {
- const strValTypeID arrow.UnionTypeCode = 0
- const intValTypeID arrow.UnionTypeCode = 2
-
- if len(infoCodes) == 0 {
- infoCodes = infoSupportedCodes
- }
-
- bldr := array.NewRecordBuilder(c.db.Alloc, adbc.GetInfoSchema)
- defer bldr.Release()
- bldr.Reserve(len(infoCodes))
-
- infoNameBldr := bldr.Field(0).(*array.Uint32Builder)
- infoValueBldr := bldr.Field(1).(*array.DenseUnionBuilder)
- strInfoBldr :=
infoValueBldr.Child(int(strValTypeID)).(*array.StringBuilder)
- intInfoBldr :=
infoValueBldr.Child(int(intValTypeID)).(*array.Int64Builder)
-
- for _, code := range infoCodes {
- switch code {
- case adbc.InfoDriverName:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoDriverName)
- case adbc.InfoDriverVersion:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoDriverVersion)
- case adbc.InfoDriverArrowVersion:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoDriverArrowVersion)
- case adbc.InfoDriverADBCVersion:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(intValTypeID)
- intInfoBldr.Append(adbc.AdbcVersion1_1_0)
- case adbc.InfoVendorName:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.Append(strValTypeID)
- strInfoBldr.Append(infoVendorName)
- default:
- infoNameBldr.Append(uint32(code))
- infoValueBldr.AppendNull()
- }
- }
-
- final := bldr.NewRecord()
- defer final.Release()
- return array.NewRecordReader(adbc.GetInfoSchema, []arrow.Record{final})
-}
-
// GetObjects gets a hierarchical view of all catalogs, database schemas,
// tables, and columns.
//
@@ -238,7 +220,7 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes
[]adbc.InfoCode) (array.Re
//
// All non-empty, non-nil strings should be a search pattern (as described
// earlier).
-func (c *cnxn) GetObjects(ctx context.Context, depth adbc.ObjectDepth, catalog
*string, dbSchema *string, tableName *string, columnName *string, tableType
[]string) (array.RecordReader, error) {
+func (c *connectionImpl) GetObjects(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string) (array.RecordReader, error) {
metadataRecords, err := c.populateMetadata(ctx, depth, catalog,
dbSchema, tableName, columnName, tableType)
if err != nil {
return nil, err
@@ -266,7 +248,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth
adbc.ObjectDepth, catalog *
return g.Finish()
}
-func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords
[]internal.Metadata) (result map[string][]string, err error) {
+func (c *connectionImpl) getObjectsDbSchemas(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords
[]internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
@@ -452,7 +434,7 @@ func toXdbcDataType(dt arrow.DataType) (xdbcType
internal.XdbcDataType) {
}
}
-func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string, metadataRecords []internal.Metadata) (result
internal.SchemaToTableInfo, err error) {
+func (c *connectionImpl) getObjectsTables(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string, metadataRecords []internal.Metadata)
(result internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth ==
adbc.ObjectDepthDBSchemas {
return
}
@@ -524,7 +506,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, depth
adbc.ObjectDepth, cat
return
}
-func (c *cnxn) populateMetadata(ctx context.Context, depth adbc.ObjectDepth,
catalog *string, dbSchema *string, tableName *string, columnName *string,
tableType []string) ([]internal.Metadata, error) {
+func (c *connectionImpl) populateMetadata(ctx context.Context, depth
adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string,
columnName *string, tableType []string) ([]internal.Metadata, error) {
var metadataRecords []internal.Metadata
catalogMetadataRecords, err := c.getCatalogsMetadata(ctx)
if err != nil {
@@ -556,7 +538,7 @@ func (c *cnxn) populateMetadata(ctx context.Context, depth
adbc.ObjectDepth, cat
return metadataRecords, nil
}
-func (c *cnxn) getCatalogsMetadata(ctx context.Context) ([]internal.Metadata,
error) {
+func (c *connectionImpl) getCatalogsMetadata(ctx context.Context)
([]internal.Metadata, error) {
metadataRecords := make([]internal.Metadata, 0)
rows, err := c.sqldb.QueryContext(ctx, prepareCatalogsSQL(), nil)
@@ -585,7 +567,7 @@ func (c *cnxn) getCatalogsMetadata(ctx context.Context)
([]internal.Metadata, er
return metadataRecords, nil
}
-func (c *cnxn) getDbSchemasMetadata(ctx context.Context, matchingCatalogNames
[]string, catalog *string, dbSchema *string) ([]internal.Metadata, error) {
+func (c *connectionImpl) getDbSchemasMetadata(ctx context.Context,
matchingCatalogNames []string, catalog *string, dbSchema *string)
([]internal.Metadata, error) {
var metadataRecords []internal.Metadata
query, queryArgs := prepareDbSchemasSQL(matchingCatalogNames, catalog,
dbSchema)
rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...)
@@ -604,7 +586,7 @@ func (c *cnxn) getDbSchemasMetadata(ctx context.Context,
matchingCatalogNames []
return metadataRecords, nil
}
-func (c *cnxn) getTablesMetadata(ctx context.Context, matchingCatalogNames
[]string, catalog *string, dbSchema *string, tableName *string, tableType
[]string) ([]internal.Metadata, error) {
+func (c *connectionImpl) getTablesMetadata(ctx context.Context,
matchingCatalogNames []string, catalog *string, dbSchema *string, tableName
*string, tableType []string) ([]internal.Metadata, error) {
metadataRecords := make([]internal.Metadata, 0)
query, queryArgs := prepareTablesSQL(matchingCatalogNames, catalog,
dbSchema, tableName, tableType)
rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...)
@@ -623,7 +605,7 @@ func (c *cnxn) getTablesMetadata(ctx context.Context,
matchingCatalogNames []str
return metadataRecords, nil
}
-func (c *cnxn) getColumnsMetadata(ctx context.Context, matchingCatalogNames
[]string, catalog *string, dbSchema *string, tableName *string, columnName
*string, tableType []string) ([]internal.Metadata, error) {
+func (c *connectionImpl) getColumnsMetadata(ctx context.Context,
matchingCatalogNames []string, catalog *string, dbSchema *string, tableName
*string, columnName *string, tableType []string) ([]internal.Metadata, error) {
metadataRecords := make([]internal.Metadata, 0)
query, queryArgs := prepareColumnsSQL(matchingCatalogNames, catalog,
dbSchema, tableName, columnName, tableType)
rows, err := c.sqldb.QueryContext(ctx, query, queryArgs...)
@@ -870,29 +852,7 @@ func descToField(name, typ, isnull, primary string,
comment sql.NullString) (fie
return
}
-func (c *cnxn) GetOption(key string) (string, error) {
- switch key {
- case adbc.OptionKeyAutoCommit:
- if c.activeTransaction {
- // No autocommit
- return adbc.OptionValueDisabled, nil
- } else {
- // Autocommit
- return adbc.OptionValueEnabled, nil
- }
- case adbc.OptionKeyCurrentCatalog:
- return c.getStringQuery("SELECT CURRENT_DATABASE()")
- case adbc.OptionKeyCurrentDbSchema:
- return c.getStringQuery("SELECT CURRENT_SCHEMA()")
- }
-
- return "", adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotFound,
- }
-}
-
-func (c *cnxn) getStringQuery(query string) (string, error) {
+func (c *connectionImpl) getStringQuery(query string) (string, error) {
result, err := c.cn.QueryContext(context.Background(), query, nil)
if err != nil {
return "", errToAdbcErr(adbc.StatusInternal, err)
@@ -928,28 +888,7 @@ func (c *cnxn) getStringQuery(query string) (string,
error) {
return value, nil
}
-func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
- return nil, adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotFound,
- }
-}
-
-func (c *cnxn) GetOptionInt(key string) (int64, error) {
- return 0, adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotFound,
- }
-}
-
-func (c *cnxn) GetOptionDouble(key string) (float64, error) {
- return 0.0, adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotFound,
- }
-}
-
-func (c *cnxn) GetTableSchema(ctx context.Context, catalog *string, dbSchema
*string, tableName string) (*arrow.Schema, error) {
+func (c *connectionImpl) GetTableSchema(ctx context.Context, catalog *string,
dbSchema *string, tableName string) (*arrow.Schema, error) {
tblParts := make([]string, 0, 3)
if catalog != nil {
tblParts = append(tblParts, strconv.Quote(*catalog))
@@ -990,35 +929,11 @@ func (c *cnxn) GetTableSchema(ctx context.Context,
catalog *string, dbSchema *st
return sc, nil
}
-// GetTableTypes returns a list of the table types in the database.
-//
-// The result is an arrow dataset with the following schema:
-//
-// Field Name | Field Type
-// ----------------|--------------
-// table_type | utf8 not null
-func (c *cnxn) GetTableTypes(_ context.Context) (array.RecordReader, error) {
- bldr := array.NewRecordBuilder(c.db.Alloc, adbc.TableTypesSchema)
- defer bldr.Release()
-
- bldr.Field(0).(*array.StringBuilder).AppendValues([]string{"BASE
TABLE", "TEMPORARY TABLE", "VIEW"}, nil)
- final := bldr.NewRecord()
- defer final.Release()
- return array.NewRecordReader(adbc.TableTypesSchema,
[]arrow.Record{final})
-}
-
// Commit commits any pending transactions on this connection, it should
// only be used if autocommit is disabled.
//
// Behavior is undefined if this is mixed with SQL transaction statements.
-func (c *cnxn) Commit(_ context.Context) error {
- if !c.activeTransaction {
- return adbc.Error{
- Msg: "no active transaction, cannot commit",
- Code: adbc.StatusInvalidState,
- }
- }
-
+func (c *connectionImpl) Commit(_ context.Context) error {
_, err := c.cn.ExecContext(context.Background(), "COMMIT", nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
@@ -1032,14 +947,7 @@ func (c *cnxn) Commit(_ context.Context) error {
// is disabled.
//
// Behavior is undefined if this is mixed with SQL transaction statements.
-func (c *cnxn) Rollback(_ context.Context) error {
- if !c.activeTransaction {
- return adbc.Error{
- Msg: "no active transaction, cannot rollback",
- Code: adbc.StatusInvalidState,
- }
- }
-
+func (c *connectionImpl) Rollback(_ context.Context) error {
_, err := c.cn.ExecContext(context.Background(), "ROLLBACK", nil)
if err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
@@ -1050,7 +958,7 @@ func (c *cnxn) Rollback(_ context.Context) error {
}
// NewStatement initializes a new statement object tied to this connection
-func (c *cnxn) NewStatement() (adbc.Statement, error) {
+func (c *connectionImpl) NewStatement() (adbc.Statement, error) {
defaultIngestOptions := DefaultIngestOptions()
return &statement{
alloc: c.db.Alloc,
@@ -1063,7 +971,7 @@ func (c *cnxn) NewStatement() (adbc.Statement, error) {
}
// Close closes this connection and releases any associated resources.
-func (c *cnxn) Close() error {
+func (c *connectionImpl) Close() error {
if c.sqldb == nil || c.cn == nil {
return adbc.Error{Code: adbc.StatusInvalidState}
}
@@ -1083,49 +991,15 @@ func (c *cnxn) Close() error {
// results can then be read independently using the returned RecordReader.
//
// A partition can be retrieved by using ExecutePartitions on a statement.
-func (c *cnxn) ReadPartition(ctx context.Context, serializedPartition []byte)
(array.RecordReader, error) {
+func (c *connectionImpl) ReadPartition(ctx context.Context,
serializedPartition []byte) (array.RecordReader, error) {
return nil, adbc.Error{
Code: adbc.StatusNotImplemented,
Msg: "ReadPartition not yet implemented for snowflake driver",
}
}
-func (c *cnxn) SetOption(key, value string) error {
+func (c *connectionImpl) SetOption(key, value string) error {
switch key {
- case adbc.OptionKeyAutoCommit:
- switch value {
- case adbc.OptionValueEnabled:
- if c.activeTransaction {
- _, err :=
c.cn.ExecContext(context.Background(), "COMMIT", nil)
- if err != nil {
- return
errToAdbcErr(adbc.StatusInternal, err)
- }
- c.activeTransaction = false
- }
- _, err := c.cn.ExecContext(context.Background(), "ALTER
SESSION SET AUTOCOMMIT = true", nil)
- return err
- case adbc.OptionValueDisabled:
- if !c.activeTransaction {
- _, err :=
c.cn.ExecContext(context.Background(), "BEGIN", nil)
- if err != nil {
- return
errToAdbcErr(adbc.StatusInternal, err)
- }
- c.activeTransaction = true
- }
- _, err := c.cn.ExecContext(context.Background(), "ALTER
SESSION SET AUTOCOMMIT = false", nil)
- return err
- default:
- return adbc.Error{
- Msg: "[Snowflake] invalid value for option " +
key + ": " + value,
- Code: adbc.StatusInvalidArgument,
- }
- }
- case adbc.OptionKeyCurrentCatalog:
- _, err := c.cn.ExecContext(context.Background(), "USE DATABASE
?", []driver.NamedValue{{Value: value}})
- return err
- case adbc.OptionKeyCurrentDbSchema:
- _, err := c.cn.ExecContext(context.Background(), "USE SCHEMA
?", []driver.NamedValue{{Value: value}})
- return err
case OptionUseHighPrecision:
// statements will inherit the value of the
OptionUseHighPrecision
// from the connection, but the option can be overridden at the
@@ -1149,24 +1023,3 @@ func (c *cnxn) SetOption(key, value string) error {
}
}
}
-
-func (c *cnxn) SetOptionBytes(key string, value []byte) error {
- return adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (c *cnxn) SetOptionInt(key string, value int64) error {
- return adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
-}
-
-func (c *cnxn) SetOptionDouble(key string, value float64) error {
- return adbc.Error{
- Msg: "[Snowflake] unknown connection option",
- Code: adbc.StatusNotImplemented,
- }
-}
diff --git a/go/adbc/driver/snowflake/driver.go
b/go/adbc/driver/snowflake/driver.go
index 77bdcdaac..124c4d388 100644
--- a/go/adbc/driver/snowflake/driver.go
+++ b/go/adbc/driver/snowflake/driver.go
@@ -20,19 +20,15 @@ package snowflake
import (
"errors"
"runtime/debug"
- "strings"
"github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/snowflakedb/gosnowflake"
"golang.org/x/exp/maps"
)
const (
- infoDriverName = "ADBC Snowflake Driver - Go"
- infoVendorName = "Snowflake"
-
OptionDatabase = "adbc.snowflake.sql.db"
OptionSchema = "adbc.snowflake.sql.schema"
OptionWarehouse = "adbc.snowflake.sql.warehouse"
@@ -119,37 +115,18 @@ const (
)
var (
- infoDriverVersion string
- infoDriverArrowVersion string
- infoSupportedCodes []adbc.InfoCode
+ // infoVendorVersion holds the gosnowflake module version found in the
+ // build info by init; it stays empty when that dependency is not listed.
+ infoVendorVersion string
)
+// init reads the build info and records the gosnowflake dependency version
+// as the vendor version. Deps may be unpopulated (e.g. in tests, see
+// https://github.com/golang/go/issues/33976), leaving infoVendorVersion empty.
func init() {
if info, ok := debug.ReadBuildInfo(); ok {
for _, dep := range info.Deps {
switch {
- case dep.Path ==
"github.com/apache/arrow-adbc/go/adbc/driver/snowflake":
- infoDriverVersion = dep.Version
- case strings.HasPrefix(dep.Path,
"github.com/apache/arrow/go/"):
- infoDriverArrowVersion = dep.Version
+ case dep.Path == "github.com/snowflakedb/gosnowflake":
+ infoVendorVersion = dep.Version
}
}
}
- // XXX: Deps not populated in tests
- // https://github.com/golang/go/issues/33976
- if infoDriverVersion == "" {
- infoDriverVersion = "(unknown or development build)"
- }
- if infoDriverArrowVersion == "" {
- infoDriverArrowVersion = "(unknown or development build)"
- }
-
- infoSupportedCodes = []adbc.InfoCode{
- adbc.InfoDriverName,
- adbc.InfoDriverVersion,
- adbc.InfoDriverArrowVersion,
- adbc.InfoVendorName,
- }
}
func errToAdbcErr(code adbc.Status, err error) error {
@@ -192,13 +169,21 @@ type driverImpl struct {
// NewDriver creates a new Snowflake driver using the given Arrow allocator.
+// Driver info starts from driverbase.DefaultDriverInfo("Snowflake"); when init
+// found a gosnowflake version, it is registered as InfoVendorVersion.
func NewDriver(alloc memory.Allocator) adbc.Driver {
- return driverbase.NewDriver(&driverImpl{DriverImplBase:
driverbase.NewDriverImplBase("Snowflake", alloc)})
+ info := driverbase.DefaultDriverInfo("Snowflake")
+ if infoVendorVersion != "" {
+ if err := info.RegisterInfoCode(adbc.InfoVendorVersion,
infoVendorVersion); err != nil {
+ // NOTE(review): NewDriver has no error return, so a rejected
+ // registration panics; presumably only reachable through a
+ // programming error — confirm.
+ panic(err)
+ }
+ }
+ return driverbase.NewDriver(&driverImpl{DriverImplBase:
driverbase.NewDriverImplBase(info, alloc)})
}
func (d *driverImpl) NewDatabase(opts map[string]string) (adbc.Database,
error) {
opts = maps.Clone(opts)
- db := &databaseImpl{DatabaseImplBase:
driverbase.NewDatabaseImplBase(&d.DriverImplBase),
- useHighPrecision: true}
+ db := &databaseImpl{
+ DatabaseImplBase:
driverbase.NewDatabaseImplBase(&d.DriverImplBase),
+ useHighPrecision: true,
+ }
if err := db.SetOptions(opts); err != nil {
return nil, err
}
diff --git a/go/adbc/driver/snowflake/driver_test.go
b/go/adbc/driver/snowflake/driver_test.go
index 5752ae5ee..3f93dbdb5 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -217,6 +217,10 @@ func (s *SnowflakeQuirks) GetMetadata(code adbc.InfoCode)
interface{} {
return "(unknown or development build)"
case adbc.InfoDriverArrowVersion:
return "(unknown or development build)"
+ case adbc.InfoVendorVersion:
+ return "(unknown or development build)"
+ case adbc.InfoVendorArrowVersion:
+ return "(unknown or development build)"
case adbc.InfoDriverADBCVersion:
return adbc.AdbcVersion1_1_0
case adbc.InfoVendorName:
diff --git a/go/adbc/driver/snowflake/snowflake_database.go
b/go/adbc/driver/snowflake/snowflake_database.go
index 76ab4684b..5c5f32b69 100644
--- a/go/adbc/driver/snowflake/snowflake_database.go
+++ b/go/adbc/driver/snowflake/snowflake_database.go
@@ -32,7 +32,7 @@ import (
"time"
"github.com/apache/arrow-adbc/go/adbc"
- "github.com/apache/arrow-adbc/go/adbc/driver/driverbase"
+ "github.com/apache/arrow-adbc/go/adbc/driver/internal/driverbase"
"github.com/snowflakedb/gosnowflake"
"github.com/youmark/pkcs8"
)
@@ -136,28 +136,7 @@ func (d *databaseImpl) GetOption(key string) (string,
error) {
return *val, nil
}
}
- return "", adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *databaseImpl) GetOptionBytes(key string) ([]byte, error) {
- return nil, adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *databaseImpl) GetOptionInt(key string) (int64, error) {
- return 0, adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
-}
-func (d *databaseImpl) GetOptionDouble(key string) (float64, error) {
- return 0, adbc.Error{
- Msg: fmt.Sprintf("[Snowflake] Unknown database option '%s'",
key),
- Code: adbc.StatusNotFound,
- }
+ return d.DatabaseImplBase.GetOption(key)
}
func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
@@ -176,7 +155,8 @@ func (d *databaseImpl) SetOptions(cnOptions
map[string]string) error {
}
}
- defaultAppName := "[ADBC][Go-" + infoDriverVersion + "]"
+ driverVersion, _ := d.DatabaseImplBase.DriverInfo.GetInfoDriverVersion()
+ defaultAppName := "[ADBC][Go-" + driverVersion + "]"
// set default application name to track
// unless user overrides it
d.cfg.Application = defaultAppName
@@ -464,15 +444,22 @@ func (d *databaseImpl) Open(ctx context.Context)
(adbc.Connection, error) {
return nil, errToAdbcErr(adbc.StatusIO, err)
}
- return &cnxn{
+ conn := &connectionImpl{
cn: cn.(snowflakeConn),
db: d, ctor: connector,
sqldb: sql.OpenDB(connector),
// default enable high precision
// SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled)
to
// get Int64/Float64 instead
- useHighPrecision: d.useHighPrecision,
- }, nil
+ useHighPrecision: d.useHighPrecision,
+ ConnectionImplBase:
driverbase.NewConnectionImplBase(&d.DatabaseImplBase),
+ }
+
+ return driverbase.NewConnectionBuilder(conn).
+ WithAutocommitSetter(conn).
+ WithCurrentNamespacer(conn).
+ WithTableTypeLister(conn).
+ Connection(), nil
}
func (d *databaseImpl) Close() error {
diff --git a/go/adbc/driver/snowflake/statement.go
b/go/adbc/driver/snowflake/statement.go
index 8439ddfcd..3f446662e 100644
--- a/go/adbc/driver/snowflake/statement.go
+++ b/go/adbc/driver/snowflake/statement.go
@@ -42,7 +42,7 @@ const (
)
type statement struct {
- cnxn *cnxn
+ cnxn *connectionImpl
alloc memory.Allocator
queueSize int
prefetchConcurrency int
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index fe30fd7cb..35bdc7081 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -83,6 +83,7 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec //
indirect
github.com/sirupsen/logrus v1.9.3 // indirect
+ github.com/stretchr/objx v0.5.2 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/mod v0.16.0 // indirect
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index a5bea3b7f..cf74b7baa 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -130,6 +130,7 @@ github.com/snowflakedb/gosnowflake v1.8.0
h1:4bQj8eAYGMkou/nICiIEb9jSbBLDDp5cB6J
github.com/snowflakedb/gosnowflake v1.8.0/go.mod
h1:7yyY2MxtDti2eXgtvlZ8QxzCN6KV2B4qb1HuygMI+0U=
github.com/stretchr/objx v0.1.0/go.mod
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
+github.com/stretchr/objx v0.5.2/go.mod
h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.0/go.mod
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0
h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod
h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=