This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new c40e658fbb GH-34332: [Go][FlightRPC] Add driver for `database/sql` 
framework (#34331)
c40e658fbb is described below

commit c40e658fbbd6201132c4378eb0fefb746ff5915f
Author: Sven Rebhan <[email protected]>
AuthorDate: Tue Apr 11 17:45:58 2023 +0200

    GH-34332: [Go][FlightRPC] Add driver for `database/sql` framework (#34331)
    
    ### Rationale for this change
    
    Go's `database/sql` framework is well known, offers conveniences such as 
connection pooling, and is easy to use. Using FlightSQL through this framework 
is therefore a good starting point for users performing simple queries, inserts, 
etc.
    
    ### What changes are included in this PR?
    
    This PR adds a `database/sql/driver` implementation currently supporting 
`sqlite` and `InfluxData IOx` (query only). Unit tests are added using the 
SQLite server example implementation, and the driver and its settings are 
documented.
    
    ### Are these changes tested?
    
    Yes, a test suite is added for the driver. Furthermore, the IOx backend is 
tested against a real local instance using [this 
code](https://github.com/srebhan/go-flightsql-example).
    
    ### Are there any user-facing changes?
    
    This PR does not contain breaking changes. All modifications to the 
FlightSQL client code are transparent to the user.
    * Closes: #34332
    
    Authored-by: Sven Rebhan <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 go/arrow/flight/client.go                       |   6 +-
 go/arrow/flight/flightsql/client.go             |  10 +-
 go/arrow/flight/flightsql/driver/README.md      | 151 +++++
 go/arrow/flight/flightsql/driver/config.go      | 125 ++++
 go/arrow/flight/flightsql/driver/driver.go      | 492 ++++++++++++++
 go/arrow/flight/flightsql/driver/driver_test.go | 816 ++++++++++++++++++++++++
 go/arrow/flight/flightsql/driver/utils.go       | 272 ++++++++
 7 files changed, 1869 insertions(+), 3 deletions(-)

diff --git a/go/arrow/flight/client.go b/go/arrow/flight/client.go
index 5ad3c9be07..da6b60c89b 100644
--- a/go/arrow/flight/client.go
+++ b/go/arrow/flight/client.go
@@ -271,6 +271,10 @@ func NewFlightClient(addr string, auth ClientAuthHandler, 
opts ...grpc.DialOptio
 // being the inner most wrapper around the actual call. It also passes along 
the dialoptions passed in such
 // as TLS certs and so on.
 func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware 
[]ClientMiddleware, opts ...grpc.DialOption) (Client, error) {
+       return NewClientWithMiddlewareCtx(context.Background(), addr, auth, 
middleware, opts...)
+}
+
+func NewClientWithMiddlewareCtx(ctx context.Context, addr string, auth 
ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) 
(Client, error) {
        unary := make([]grpc.UnaryClientInterceptor, 0, len(middleware))
        stream := make([]grpc.StreamClientInterceptor, 0, len(middleware))
        if auth != nil {
@@ -288,7 +292,7 @@ func NewClientWithMiddleware(addr string, auth 
ClientAuthHandler, middleware []C
                }
        }
        opts = append(opts, grpc.WithChainUnaryInterceptor(unary...), 
grpc.WithChainStreamInterceptor(stream...))
-       conn, err := grpc.Dial(addr, opts...)
+       conn, err := grpc.DialContext(ctx, addr, opts...)
        if err != nil {
                return nil, err
        }
diff --git a/go/arrow/flight/flightsql/client.go 
b/go/arrow/flight/flightsql/client.go
index a73fc4657c..a148f83e96 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -39,7 +39,11 @@ import (
 // its arguments to flight.NewClientWithMiddleware to create the
 // underlying Flight Client.
 func NewClient(addr string, auth flight.ClientAuthHandler, middleware 
[]flight.ClientMiddleware, opts ...grpc.DialOption) (*Client, error) {
-       cl, err := flight.NewClientWithMiddleware(addr, auth, middleware, 
opts...)
+       return NewClientCtx(context.Background(), addr, auth, middleware, 
opts...)
+}
+
+func NewClientCtx(ctx context.Context, addr string, auth 
flight.ClientAuthHandler, middleware []flight.ClientMiddleware, opts 
...grpc.DialOption) (*Client, error) {
+       cl, err := flight.NewClientWithMiddlewareCtx(ctx, addr, auth, 
middleware, opts...)
        if err != nil {
                return nil, err
        }
@@ -1110,7 +1114,9 @@ func (p *PreparedStatement) clearParameters() {
 func (p *PreparedStatement) SetParameters(binding arrow.Record) {
        p.clearParameters()
        p.paramBinding = binding
-       p.paramBinding.Retain()
+       if p.paramBinding != nil {
+               p.paramBinding.Retain()
+       }
 }
 
 // SetRecordReader takes a RecordReader to send as the parameter bindings when
diff --git a/go/arrow/flight/flightsql/driver/README.md 
b/go/arrow/flight/flightsql/driver/README.md
new file mode 100644
index 0000000000..cfb33ba2c6
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/README.md
@@ -0,0 +1,151 @@
+<!---
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+# FlightSQL driver
+
+A FlightSQL driver for Go's [database/sql](https://golang.org/pkg/database/sql/)
+package. This driver is a lightweight wrapper around the FlightSQL client in
+pure Go. It provides all the advantages of a `database/sql` driver, such as
+automatic connection pooling and transactions, combined with ease of use (see
+[Usage](#usage)).
+
+---------------------------------------
+
+* [Prerequisites](#prerequisites)
+* [Usage](#usage)
+* [Data Source Name (DSN)](#data-source-name-dsn)
+* [Driver config usage](#driver-config-usage)
+* [TLS setup](#tls-setup)
+
+---------------------------------------
+
+## Prerequisites
+
+* Go 1.19+
+* Installation via `go get -u github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver`
+* Backend speaking FlightSQL
+* Backend speaking FlightSQL
+
+---------------------------------------
+
+## Usage
+
+The _Go FlightSQL driver_ is an implementation of Go's `database/sql/driver`
+interface for use with the [`database/sql`](https://golang.org/pkg/database/sql/)
+framework. The driver is registered as `flightsql` and configured using a
+[data-source name (DSN)](#data-source-name-dsn).
+
+A basic example using a SQLite backend looks like this:
+
+```go
+import (
+    "database/sql"
+
+    _ "github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver"
+)
+
+// Open the connection to an SQLite backend
+db, err := sql.Open("flightsql", "flightsql://localhost:12345?timeout=5s")
+if err != nil {
+    panic(err)
+}
+// Make sure we close the connection to the database
+defer db.Close()
+
+// Use the connection e.g. for querying
+rows, err := db.Query("SELECT * FROM mytable")
+if err != nil {
+    panic(err)
+}
+// ...
+```
+
+## Data Source Name (DSN)
+
+A Data Source Name has the following format:
+
+```text
+flightsql://[user[:password]@]<address>[:port][?param1=value1&...&paramN=valueN]
+```
+
+The data-source name (DSN) requires the `address` of the backend with an
+optional port setting. The `user` and `password` parameters are passed to the
+backend as gRPC Basic-Auth headers. If your backend requires token-based
+authentication, please use the `token` parameter (see
+[common parameters](#common-parameters) below).
+
+**Please note**: All parameters are case-sensitive!
+
+Instead of specifying the DSN directly, you can use the `DriverConfig`
+structure to generate the DSN string. See the
+[Driver config usage section](#driver-config-usage) for details.
+
+### Common parameters
+
+The following common parameters are supported:
+
+#### `token`
+
+The `token` parameter can be used to specify the token for token-based
+authentication. The value is passed on to the backend as a GRPC Bearer-Auth
+header.
+
+#### `timeout`
+
+The `timeout` parameter can be set using a duration string e.g. `timeout=5s`
+to limit the maximum time an operation can take. This prevents calls that wait
+forever, e.g. if the backend is down or a query is taking very long. When
+not set, the driver will use an _infinite_ timeout.
+
+## Driver config usage
+
+Instead of specifying the DSN directly, you can fill the `DriverConfig`
+structure and generate the DSN from it. Here is an example:
+
+```golang
+package main
+
+import (
+    "database/sql"
+    "log"
+    "time"
+
+    "github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver"
+)
+
+func main() {
+    config := driver.DriverConfig{
+        Address: "localhost:12345",
+        Token:   "your token",
+        Timeout: 10 * time.Second,
+        Params: map[string]string{
+            "my-custom-parameter": "foobar",
+        },
+    }
+    db, err := sql.Open("flightsql", config.DSN())
+    if err != nil {
+        log.Fatalf("open failed: %v", err)
+    }
+    defer db.Close()
+
+    ...
+}
+```
+
+## TLS setup
+
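+TLS cannot be configured through the DSN yet. To use TLS, set the `TLSEnabled`
+and `TLSConfig` fields of a `DriverConfig`, apply it to a `driver.Connector`
+via its `Configure` method and open the database with `sql.OpenDB`.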
diff --git a/go/arrow/flight/flightsql/driver/config.go 
b/go/arrow/flight/flightsql/driver/config.go
new file mode 100644
index 0000000000..d4a785dc6b
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/config.go
@@ -0,0 +1,125 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package driver
+
+import (
+       "crypto/tls"
+       "fmt"
+       "net/url"
+       "time"
+)
+
// DriverConfig holds the settings used to connect to a FlightSQL backend.
// It can be parsed from a data source name (DSN) with NewDriverConfigFromDSN
// and converted back into one with DSN.
type DriverConfig struct {
	// Address is the backend address in "host[:port]" form.
	Address string
	// Username and Password are optional credentials for basic authentication.
	Username string
	Password string
	// Token is an optional token for token-based authentication.
	Token string
	// Timeout bounds each operation; a non-positive value disables it.
	Timeout time.Duration
	// Params holds all additional, backend-specific DSN parameters.
	Params map[string]string

	// TLSEnabled switches the transport to TLS using TLSConfig. The TLS
	// settings are not part of the DSN and have to be set directly.
	TLSEnabled bool
	TLSConfig  *tls.Config
}

// NewDriverConfigFromDSN parses a data source name of the form
//
//	flightsql://[user[:password]@]<address>[:port][/][?param1=value1&...]
//
// The "token" and "timeout" parameters populate the corresponding fields;
// every other parameter is stored in Params. Each parameter may appear at
// most once. An error is returned for a scheme other than "flightsql", a
// missing address, a non-empty path, a malformed query string, or a timeout
// that is not a valid non-negative duration.
func NewDriverConfigFromDSN(dsn string) (*DriverConfig, error) {
	u, err := url.Parse(dsn)
	if err != nil {
		return nil, err
	}

	// Sanity checks on the given connection string.
	if u.Scheme != "flightsql" {
		return nil, fmt.Errorf("invalid scheme %q", u.Scheme)
	}
	if u.Host == "" {
		// An address is mandatory for reaching the backend at all.
		return nil, fmt.Errorf("missing address")
	}
	// A trailing slash ("flightsql://host:port/") carries no information and
	// is tolerated; any real path is not supported.
	if u.Path != "" && u.Path != "/" {
		return nil, fmt.Errorf("unexpected path %q", u.Path)
	}

	// Parse the query explicitly: url.URL.Query silently drops malformed
	// pairs, which would hide e.g. a broken token parameter.
	query, err := url.ParseQuery(u.RawQuery)
	if err != nil {
		return nil, fmt.Errorf("invalid parameters: %w", err)
	}

	config := &DriverConfig{
		Address: u.Host,
		Params:  make(map[string]string),
	}
	if u.User != nil {
		config.Username = u.User.Username()
		// Password stays empty if none was given.
		config.Password, _ = u.User.Password()
	}

	for key, values := range query {
		// We only support single instances of each parameter.
		if len(values) > 1 {
			return nil, fmt.Errorf("too many values for %q", key)
		}
		var v string
		if len(values) > 0 {
			v = values[0]
		}

		switch key {
		case "token":
			config.Token = v
		case "timeout":
			timeout, err := time.ParseDuration(v)
			if err != nil {
				return nil, fmt.Errorf("invalid timeout %q: %w", v, err)
			}
			// A negative value would silently behave like "no timeout".
			if timeout < 0 {
				return nil, fmt.Errorf("invalid timeout %q: must not be negative", v)
			}
			config.Timeout = timeout
		default:
			config.Params[key] = v
		}
	}

	return config, nil
}

// DSN renders the configuration as a data source name understood by
// NewDriverConfigFromDSN. TLS settings are not encoded, and a non-positive
// Timeout is omitted. The Token and Timeout fields take precedence over
// Params entries with the same key.
func (config *DriverConfig) DSN() string {
	u := url.URL{
		Scheme: "flightsql",
		Host:   config.Address,
	}
	if config.Username != "" {
		if config.Password == "" {
			u.User = url.User(config.Username)
		} else {
			u.User = url.UserPassword(config.Username, config.Password)
		}
	}

	// Set the parameters
	values := url.Values{}
	if config.Token != "" {
		values.Set("token", config.Token)
	}
	if config.Timeout > 0 {
		values.Set("timeout", config.Timeout.String())
	}
	for k, v := range config.Params {
		// Emitting a key twice would produce a DSN that NewDriverConfigFromDSN
		// rejects ("too many values"), so dedicated fields win.
		if values.Has(k) {
			continue
		}
		values.Set(k, v)
	}

	// Only add a query string if there are parameters at all.
	if len(values) > 0 {
		u.RawQuery = values.Encode()
	}

	return u.String()
}
diff --git a/go/arrow/flight/flightsql/driver/driver.go 
b/go/arrow/flight/flightsql/driver/driver.go
new file mode 100644
index 0000000000..970d7a4dfe
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/driver.go
@@ -0,0 +1,492 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package driver
+
+import (
+       "context"
+       "database/sql"
+       "database/sql/driver"
+       "errors"
+       "fmt"
+       "io"
+       "sort"
+       "time"
+
+       "github.com/apache/arrow/go/v12/arrow"
+       "github.com/apache/arrow/go/v12/arrow/array"
+       "github.com/apache/arrow/go/v12/arrow/flight/flightsql"
+       "github.com/apache/arrow/go/v12/arrow/memory"
+
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/credentials"
+       "google.golang.org/grpc/credentials/insecure"
+)
+
// ErrNotSupported is returned when the requested information or feature is
// not available from this driver.
var ErrNotSupported = errors.New("not supported")

// ErrOutOfRange is returned when an index exceeds the available data.
var ErrOutOfRange = errors.New("index out of range")

// ErrTransactionInProgress is returned when closing a connection while a
// transaction is still open.
var ErrTransactionInProgress = errors.New("transaction still in progress")
+
+type Rows struct {
+       schema        *arrow.Schema
+       records       []arrow.Record
+       currentRecord int
+       currentRow    int
+}
+
+// Columns returns the names of the columns.
+func (r *Rows) Columns() []string {
+       if len(r.records) == 0 {
+               return nil
+       }
+
+       // All records have the same columns
+       var cols []string
+       for _, c := range r.schema.Fields() {
+               cols = append(cols, c.Name)
+       }
+
+       return cols
+}
+
+// Close closes the rows iterator.
+func (r *Rows) Close() error {
+       for _, rec := range r.records {
+               rec.Release()
+       }
+       r.currentRecord = 0
+       r.currentRow = 0
+
+       return nil
+}
+
+// Next is called to populate the next row of data into
+// the provided slice. The provided slice will be the same
+// size as the Columns() are wide.
+//
+// Next should return io.EOF when there are no more rows.
+//
+// The dest should not be written to outside of Next. Care
+// should be taken when closing Rows not to modify
+// a buffer held in dest.
+func (r *Rows) Next(dest []driver.Value) error {
+       if r.currentRecord >= len(r.records) {
+               return io.EOF
+       }
+       record := r.records[r.currentRecord]
+
+       if int64(r.currentRow) >= record.NumRows() {
+               return ErrOutOfRange
+       }
+
+       for i, arr := range record.Columns() {
+               v, err := fromArrowType(arr, r.currentRow)
+               if err != nil {
+                       return err
+               }
+               dest[i] = v
+       }
+
+       r.currentRow++
+       if int64(r.currentRow) >= record.NumRows() {
+               r.currentRecord++
+               r.currentRow = 0
+       }
+
+       return nil
+}
+
+type Result struct {
+       affected   int64
+       lastinsert int64
+}
+
+// LastInsertId returns the database's auto-generated ID after, for example,
+// an INSERT into a table with primary key.
+func (r *Result) LastInsertId() (int64, error) {
+       if r.lastinsert < 0 {
+               return -1, ErrNotSupported
+       }
+       return r.lastinsert, nil
+}
+
+// RowsAffected returns the number of rows affected by the query.
+func (r *Result) RowsAffected() (int64, error) {
+       if r.affected < 0 {
+               return -1, ErrNotSupported
+       }
+       return r.affected, nil
+}
+
+type Stmt struct {
+       stmt   *flightsql.PreparedStatement
+       client *flightsql.Client
+
+       timeout time.Duration
+}
+
+// Close closes the statement.
+func (s *Stmt) Close() error {
+       ctx := context.Background()
+       if s.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, s.timeout)
+               defer cancel()
+       }
+
+       return s.stmt.Close(ctx)
+}
+
+// NumInput returns the number of placeholder parameters.
+func (s *Stmt) NumInput() int {
+       schema := s.stmt.ParameterSchema()
+       if schema == nil {
+               // NumInput may also return -1, if the driver doesn't know its 
number
+               // of placeholders. In that case, the sql package will not 
sanity check
+               // Exec or Query argument counts.
+               return -1
+       }
+
+       // If NumInput returns >= 0, the sql package will sanity check argument
+       // counts from callers and return errors to the caller before the
+       // statement's Exec or Query methods are called.
+       return len(schema.Fields())
+}
+
+// Exec executes a query that doesn't return rows, such
+// as an INSERT or UPDATE.
+func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) {
+       var params []driver.NamedValue
+       for i, arg := range args {
+               params = append(params, driver.NamedValue{
+                       Ordinal: i,
+                       Value:   arg,
+               })
+       }
+
+       return s.ExecContext(context.Background(), params)
+}
+
+// ExecContext executes a query that doesn't return rows, such as an INSERT or 
UPDATE.
+func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) 
(driver.Result, error) {
+       if err := s.setParameters(args); err != nil {
+               return nil, err
+       }
+
+       if _, set := ctx.Deadline(); !set && s.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, s.timeout)
+               defer cancel()
+       }
+
+       n, err := s.stmt.ExecuteUpdate(ctx)
+       if err != nil {
+               return nil, err
+       }
+
+       return &Result{affected: n, lastinsert: -1}, nil
+}
+
+// Query executes a query that may return rows, such as a SELECT.
+func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
+       var params []driver.NamedValue
+       for i, arg := range args {
+               params = append(params, driver.NamedValue{
+                       Ordinal: i,
+                       Value:   arg,
+               })
+       }
+
+       return s.QueryContext(context.Background(), params)
+}
+
+// QueryContext executes a query that may return rows, such as a SELECT.
+func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) 
(driver.Rows, error) {
+       if err := s.setParameters(args); err != nil {
+               return nil, err
+       }
+
+       if _, set := ctx.Deadline(); !set && s.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, s.timeout)
+               defer cancel()
+       }
+
+       info, err := s.stmt.Execute(ctx)
+       if err != nil {
+               return nil, err
+       }
+
+       rows := Rows{}
+       for _, endpoint := range info.Endpoint {
+               reader, err := s.client.DoGet(ctx, endpoint.GetTicket())
+               if err != nil {
+                       return nil, fmt.Errorf("getting ticket failed: %w", err)
+               }
+
+               rows.schema = reader.Schema()
+               for reader.Next() {
+                       record := reader.Record()
+                       record.Retain()
+                       rows.records = append(rows.records, record)
+
+               }
+               if err := reader.Err(); err != nil {
+                       return &rows, err
+               }
+       }
+
+       return &rows, nil
+}
+
+func (s *Stmt) setParameters(args []driver.NamedValue) error {
+       if len(args) == 0 {
+               s.stmt.SetParameters(nil)
+               return nil
+       }
+
+       sort.SliceStable(args, func(i, j int) bool {
+               return args[i].Ordinal < args[j].Ordinal
+       })
+
+       schema := s.stmt.ParameterSchema()
+       if schema == nil {
+               var fields []arrow.Field
+               for _, arg := range args {
+                       dt, err := toArrowDataType(arg.Value)
+                       if err != nil {
+                               return fmt.Errorf("schema: %w", err)
+                       }
+                       fields = append(fields, arrow.Field{
+                               Name: arg.Name,
+                               Type: dt,
+                       })
+               }
+               schema = arrow.NewSchema(fields, nil)
+       }
+
+       recBuilder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+       defer recBuilder.Release()
+
+       for i, arg := range args {
+               fieldBuilder := recBuilder.Field(i)
+               if err := setFieldValue(fieldBuilder, arg.Value); err != nil {
+                       return err
+               }
+       }
+
+       rec := recBuilder.NewRecord()
+       defer rec.Release()
+
+       s.stmt.SetParameters(rec)
+
+       return nil
+}
+
+type Tx struct {
+       tx      *flightsql.Txn
+       timeout time.Duration
+}
+
+func (t *Tx) Commit() error {
+       ctx := context.Background()
+       if t.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, t.timeout)
+               defer cancel()
+       }
+
+       return t.tx.Commit(ctx)
+}
+
+func (t *Tx) Rollback() error {
+       ctx := context.Background()
+       if t.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, t.timeout)
+               defer cancel()
+       }
+
+       return t.tx.Rollback(ctx)
+}
+
+type Driver struct{}
+
+// Open returns a new connection to the database.
+func (d *Driver) Open(name string) (driver.Conn, error) {
+       c, err := d.OpenConnector(name)
+       if err != nil {
+               return nil, err
+       }
+
+       return c.Connect(context.Background())
+}
+
+// OpenConnector must parse the name in the same format that Driver.Open
+// parses the name parameter.
+func (d *Driver) OpenConnector(name string) (driver.Connector, error) {
+       config, err := NewDriverConfigFromDSN(name)
+       if err != nil {
+               return nil, err
+       }
+
+       c := &Connector{}
+       if err := c.Configure(config); err != nil {
+               return nil, err
+       }
+
+       return c, nil
+}
+
+type Connector struct {
+       addr    string
+       timeout time.Duration
+       options []grpc.DialOption
+}
+
+// Configure the driver with the corresponding config
+func (c *Connector) Configure(config *DriverConfig) error {
+       // Set the driver properties
+       c.addr = config.Address
+       c.timeout = config.Timeout
+       c.options = []grpc.DialOption{grpc.WithBlock()}
+
+       // Create GRPC options necessary for the backend
+       var transportCreds credentials.TransportCredentials
+       if !config.TLSEnabled {
+               transportCreds = insecure.NewCredentials()
+       } else {
+               transportCreds = credentials.NewTLS(config.TLSConfig)
+       }
+       c.options = append(c.options, 
grpc.WithTransportCredentials(transportCreds))
+
+       // Set authentication credentials
+       rpcCreds := grpcCredentials{
+               username: config.Username,
+               password: config.Password,
+               token:    config.Token,
+               params:   config.Params,
+       }
+       c.options = append(c.options, grpc.WithPerRPCCredentials(rpcCreds))
+
+       return nil
+}
+
+// Connect returns a connection to the database.
+func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
+       if _, set := ctx.Deadline(); !set && c.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, c.timeout)
+               defer cancel()
+       }
+
+       client, err := flightsql.NewClientCtx(ctx, c.addr, nil, nil, 
c.options...)
+       if err != nil {
+               return nil, err
+       }
+
+       return &Connection{
+               client:  client,
+               timeout: c.timeout,
+       }, nil
+}
+
+// Driver returns the underlying Driver of the Connector,
+// mainly to maintain compatibility with the Driver method
+// on sql.DB.
+func (c *Connector) Driver() driver.Driver {
+       return &Driver{}
+}
+
+type Connection struct {
+       client *flightsql.Client
+       txn    *flightsql.Txn
+
+       timeout time.Duration
+}
+
+// Prepare returns a prepared statement, bound to this connection.
+func (c *Connection) Prepare(query string) (driver.Stmt, error) {
+       return c.PrepareContext(context.Background(), query)
+}
+
+// PrepareContext returns a prepared statement, bound to this connection.
+// context is for the preparation of the statement,
+// it must not store the context within the statement itself.
+func (c *Connection) PrepareContext(ctx context.Context, query string) 
(driver.Stmt, error) {
+       if _, set := ctx.Deadline(); !set && c.timeout > 0 {
+               var cancel context.CancelFunc
+               ctx, cancel = context.WithTimeout(ctx, c.timeout)
+               defer cancel()
+       }
+
+       var err error
+       var stmt *flightsql.PreparedStatement
+       if c.txn != nil && c.txn.ID().IsValid() {
+               stmt, err = c.txn.Prepare(ctx, query)
+       } else {
+               stmt, err = c.client.Prepare(ctx, query)
+               c.txn = nil
+       }
+       if err != nil {
+               return nil, err
+       }
+
+       return &Stmt{
+               stmt:    stmt,
+               client:  c.client,
+               timeout: c.timeout,
+       }, nil
+}
+
+// Close invalidates and potentially stops any current
+// prepared statements and transactions, marking this
+// connection as no longer in use.
+func (c *Connection) Close() error {
+       if c.txn != nil && c.txn.ID().IsValid() {
+               return ErrTransactionInProgress
+       }
+
+       if c.client == nil {
+               return nil
+       }
+
+       err := c.client.Close()
+       c.client = nil
+
+       return err
+}
+
+// Begin starts and returns a new transaction.
+func (c *Connection) Begin() (driver.Tx, error) {
+       return c.BeginTx(context.Background(), sql.TxOptions{})
+}
+
+func (c *Connection) BeginTx(ctx context.Context, opts sql.TxOptions) 
(driver.Tx, error) {
+       tx, err := c.client.BeginTransaction(ctx)
+       if err != nil {
+               return nil, err
+       }
+       c.txn = tx
+
+       return &Tx{tx: tx, timeout: c.timeout}, nil
+}
+
+// Register the driver on load.
+func init() {
+       sql.Register("flightsql", &Driver{})
+}
diff --git a/go/arrow/flight/flightsql/driver/driver_test.go 
b/go/arrow/flight/flightsql/driver/driver_test.go
new file mode 100644
index 0000000000..60cfb32364
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/driver_test.go
@@ -0,0 +1,816 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build go1.18
+// +build go1.18
+
+package driver_test
+
+import (
+       "context"
+       "database/sql"
+       "errors"
+       "fmt"
+       "os"
+       "strings"
+       "sync"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/require"
+       "github.com/stretchr/testify/suite"
+
+       "github.com/apache/arrow/go/v12/arrow"
+       "github.com/apache/arrow/go/v12/arrow/array"
+       "github.com/apache/arrow/go/v12/arrow/flight"
+       "github.com/apache/arrow/go/v12/arrow/flight/flightsql"
+       "github.com/apache/arrow/go/v12/arrow/flight/flightsql/driver"
+       "github.com/apache/arrow/go/v12/arrow/flight/flightsql/example"
+       "github.com/apache/arrow/go/v12/arrow/memory"
+)
+
+// defaultTableName is used when a suite leaves SqlTestSuite.TableName empty.
+const defaultTableName = "drivertest"
+
+// defaultStatements holds the SQL templates used by the suite, keyed by
+// purpose. Every template takes the table name as its first fmt argument.
+// Entries already present in SqlTestSuite.Statements take precedence.
+var defaultStatements = map[string]string{
+       // Schema setup.
+       "create table": `
+CREATE TABLE %s (
+  id INTEGER PRIMARY KEY AUTOINCREMENT,
+  name varchar(100),
+  value int
+);`,
+
+       // Data setup; further arguments are the row name and the row value.
+       "insert": `INSERT INTO %s (name, value) VALUES ('%s', %d);`,
+
+       // Queries. "constraint query" takes the substring to search for as its
+       // second argument; "placeholder query" binds the pattern at execution.
+       "query":             `SELECT * FROM %s;`,
+       "constraint query":  `SELECT * FROM %s WHERE name LIKE '%%%s%%'`,
+       "placeholder query": `SELECT * FROM %s WHERE name LIKE ?`,
+}
+
+// SqlTestSuite runs the driver tests against a pluggable Flight SQL backend.
+// The backend is supplied through the server hooks and its SQL dialect
+// through Statements.
+type SqlTestSuite struct {
+       suite.Suite
+
+       // Config is the base client configuration. Each test copies it and sets
+       // the address of the server it created.
+       Config driver.DriverConfig
+
+       // TableName is the table used by the tests; SetupSuite falls back to
+       // defaultTableName when it is empty.
+       TableName string
+
+       // Statements maps statement purposes to SQL templates. SetupSuite adds
+       // the entry from defaultStatements for every missing key.
+       Statements map[string]string
+
+       // Backend hooks: createServer returns an initialized server and its
+       // address, startServer runs it (tests call it on a separate goroutine),
+       // and stopServer shuts it down.
+       createServer func() (flight.Server, string, error)
+       startServer  func(flight.Server) error
+       stopServer   func(flight.Server)
+}
+
+// SetupSuite fills in defaults for TableName and Statements and verifies
+// that every statement the tests rely on is available.
+func (s *SqlTestSuite) SetupSuite() {
+       if len(s.TableName) == 0 {
+               s.TableName = defaultTableName
+       }
+
+       if s.Statements == nil {
+               s.Statements = make(map[string]string)
+       }
+       // Only add defaults for missing keys, so statements supplied by the
+       // user or a backend-specific suite win.
+       for purpose, template := range defaultStatements {
+               if _, present := s.Statements[purpose]; present {
+                       continue
+               }
+               s.Statements[purpose] = template
+       }
+
+       required := []string{"create table", "insert", "query", "constraint query", "placeholder query"}
+       for _, purpose := range required {
+               require.Contains(s.T(), s.Statements, purpose)
+       }
+}
+
+// TestOpenClose checks that a database handle can be opened with the suite's
+// DSN and closed again without error.
+func (s *SqlTestSuite) TestOpenClose() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       require.NoError(t, db.Close())
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestCreateTable checks that creating a table reports zero affected rows and
+// that LastInsertId is reported as unsupported.
+func (s *SqlTestSuite) TestCreateTable() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       result, err := db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       affected, err := result.RowsAffected()
+       require.NoError(t, err)
+       require.Equal(t, int64(0), affected)
+
+       last, err := result.LastInsertId()
+       require.ErrorIs(t, err, driver.ErrNotSupported)
+       require.Equal(t, int64(-1), last)
+
+       require.NoError(t, db.Close())
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestInsert creates the table, inserts several rows with one multi-statement
+// Exec and checks the reported number of affected rows.
+func (s *SqlTestSuite) TestInsert() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Create the table
+       _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       values := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       var stmts []string
+       for k, v := range values {
+               stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+       }
+       // All inserts are sent as a single Exec; the test expects an
+       // affected-row count of 1 for that batch.
+       result, err := db.Exec(strings.Join(stmts, "\n"))
+       require.NoError(t, err)
+
+       affected, err := result.RowsAffected()
+       require.NoError(t, err)
+       require.Equal(t, int64(1), affected)
+
+       require.NoError(t, db.Close())
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestQuery inserts rows and reads them back with a plain (unprepared) query.
+func (s *SqlTestSuite) TestQuery() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Create the table
+       _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       expected := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       var stmts []string
+       for k, v := range expected {
+               stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+       }
+       _, err = db.Exec(strings.Join(stmts, "\n"))
+       require.NoError(t, err)
+
+       rows, err := db.Query(fmt.Sprintf(s.Statements["query"], s.TableName))
+       require.NoError(t, err)
+       defer rows.Close()
+
+       // Check result; rows.Err reports errors that ended iteration early.
+       actual := make(map[string]int, len(expected))
+       for rows.Next() {
+               var name string
+               var id, value int
+               require.NoError(t, rows.Scan(&id, &name, &value))
+               actual[name] = value
+       }
+       require.NoError(t, rows.Err())
+       require.NoError(t, db.Close())
+       require.EqualValues(t, expected, actual)
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestPreparedQuery inserts rows and reads them back through a prepared
+// statement without parameters.
+func (s *SqlTestSuite) TestPreparedQuery() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Create the table
+       _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       expected := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       var stmts []string
+       for k, v := range expected {
+               stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+       }
+       _, err = db.Exec(strings.Join(stmts, "\n"))
+       require.NoError(t, err)
+
+       // Do query
+       stmt, err := db.Prepare(fmt.Sprintf(s.Statements["query"], s.TableName))
+       require.NoError(t, err)
+
+       rows, err := stmt.Query()
+       require.NoError(t, err)
+       defer rows.Close()
+
+       // Check result; rows.Err reports errors that ended iteration early.
+       actual := make(map[string]int, len(expected))
+       for rows.Next() {
+               var name string
+               var id, value int
+               require.NoError(t, rows.Scan(&id, &name, &value))
+               actual[name] = value
+       }
+       require.NoError(t, rows.Err())
+       require.NoError(t, db.Close())
+       require.EqualValues(t, expected, actual)
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestPreparedQueryWithConstraint prepares a query whose filter is embedded
+// in the SQL text and checks that only the matching rows are returned.
+func (s *SqlTestSuite) TestPreparedQueryWithConstraint() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Create the table
+       _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       data := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       var stmts []string
+       for k, v := range data {
+               stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+       }
+       _, err = db.Exec(strings.Join(stmts, "\n"))
+       require.NoError(t, err)
+
+       // Do query
+       stmt, err := db.Prepare(fmt.Sprintf(s.Statements["constraint query"], s.TableName, "one"))
+       require.NoError(t, err)
+
+       rows, err := stmt.Query()
+       require.NoError(t, err)
+       defer rows.Close()
+
+       // Check result; rows.Err reports errors that ended iteration early.
+       expected := map[string]int{
+               "one":       1,
+               "minus one": -1,
+       }
+       actual := make(map[string]int, len(expected))
+       for rows.Next() {
+               var name string
+               var id, value int
+               require.NoError(t, rows.Scan(&id, &name, &value))
+               actual[name] = value
+       }
+       require.NoError(t, rows.Err())
+       require.NoError(t, db.Close())
+       require.EqualValues(t, expected, actual)
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestPreparedQueryWithPlaceholder prepares a query with a bind parameter and
+// checks that only the rows matching the bound LIKE pattern are returned.
+func (s *SqlTestSuite) TestPreparedQueryWithPlaceholder() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Create the table
+       _, err = db.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       data := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       var stmts []string
+       for k, v := range data {
+               stmts = append(stmts, fmt.Sprintf(s.Statements["insert"], s.TableName, k, v))
+       }
+       _, err = db.Exec(strings.Join(stmts, "\n"))
+       require.NoError(t, err)
+
+       // Do query
+       query := fmt.Sprintf(s.Statements["placeholder query"], s.TableName)
+       stmt, err := db.Prepare(query)
+       require.NoError(t, err)
+
+       // The pattern is bound as a parameter and never passes through
+       // fmt.Sprintf, so '%' must not be doubled here.
+       params := []interface{}{"%one%"}
+       rows, err := stmt.Query(params...)
+       require.NoError(t, err)
+       defer rows.Close()
+
+       // Check result; rows.Err reports errors that ended iteration early.
+       expected := map[string]int{
+               "one":       1,
+               "minus one": -1,
+       }
+       actual := make(map[string]int, len(expected))
+       for rows.Next() {
+               var name string
+               var id, value int
+               require.NoError(t, rows.Scan(&id, &name, &value))
+               actual[name] = value
+       }
+       require.NoError(t, rows.Err())
+       require.NoError(t, db.Close())
+       require.EqualValues(t, expected, actual)
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestTxRollback creates a table and inserts rows inside a transaction, rolls
+// the transaction back and checks that no user table remains.
+func (s *SqlTestSuite) TestTxRollback() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       tx, err := db.Begin()
+       require.NoError(t, err)
+
+       // Create the table
+       _, err = tx.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       data := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       for k, v := range data {
+               stmt := fmt.Sprintf(s.Statements["insert"], s.TableName, k, v)
+               _, err = tx.Exec(stmt)
+               require.NoError(t, err)
+       }
+
+       // Rollback the transaction
+       require.NoError(t, tx.Rollback())
+
+       // Check result (the schema query below is SQLite specific).
+       tbls := `SELECT name FROM sqlite_schema WHERE type ='table' AND name NOT LIKE 'sqlite_%';`
+       rows, err := db.Query(tbls)
+       require.NoError(t, err)
+       defer rows.Close()
+       count := 0
+       for rows.Next() {
+               count++
+       }
+       require.NoError(t, rows.Err())
+       require.Equal(t, 0, count)
+       require.NoError(t, db.Close())
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+// TestTxCommit creates a table and inserts rows inside a transaction, commits
+// it and checks that both the table and the data are visible afterwards.
+func (s *SqlTestSuite) TestTxCommit() {
+       t := s.T()
+
+       // Create and start the server. The serve error is checked on the test
+       // goroutine after shutdown, because require (t.FailNow) must not be
+       // called from other goroutines.
+       server, addr, err := s.createServer()
+       require.NoError(t, err)
+
+       var wg sync.WaitGroup
+       var serveErr error
+       wg.Add(1)
+       go func() {
+               defer wg.Done()
+               serveErr = s.startServer(server)
+       }()
+       defer s.stopServer(server)
+       time.Sleep(100 * time.Millisecond)
+
+       // Configure client
+       cfg := s.Config
+       cfg.Address = addr
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       tx, err := db.Begin()
+       require.NoError(t, err)
+
+       // Create the table
+       _, err = tx.Exec(fmt.Sprintf(s.Statements["create table"], s.TableName))
+       require.NoError(t, err)
+
+       // Insert data
+       data := map[string]int{
+               "zero":      0,
+               "one":       1,
+               "minus one": -1,
+               "twelve":    12,
+       }
+       for k, v := range data {
+               stmt := fmt.Sprintf(s.Statements["insert"], s.TableName, k, v)
+               _, err = tx.Exec(stmt)
+               require.NoError(t, err)
+       }
+
+       // Commit the transaction
+       require.NoError(t, tx.Commit())
+
+       // Check if the table exists (the schema query is SQLite specific).
+       tbls := `SELECT name FROM sqlite_schema WHERE type ='table' AND name NOT LIKE 'sqlite_%';`
+       rows, err := db.Query(tbls)
+       require.NoError(t, err)
+       defer rows.Close()
+
+       var tables []string
+       for rows.Next() {
+               var name string
+               require.NoError(t, rows.Scan(&name))
+               tables = append(tables, name)
+       }
+       require.NoError(t, rows.Err())
+       // Use the configured table name rather than the default.
+       require.Contains(t, tables, s.TableName)
+
+       // Check the actual data
+       stmt, err := db.Prepare(fmt.Sprintf(s.Statements["query"], s.TableName))
+       require.NoError(t, err)
+
+       rows, err = stmt.Query()
+       require.NoError(t, err)
+       defer rows.Close()
+
+       // Check result
+       actual := make(map[string]int, len(data))
+       for rows.Next() {
+               var name string
+               var id, value int
+               require.NoError(t, rows.Scan(&id, &name, &value))
+               actual[name] = value
+       }
+       require.NoError(t, rows.Err())
+       require.NoError(t, db.Close())
+       require.EqualValues(t, data, actual)
+
+       // Tear-down server
+       s.stopServer(server)
+       wg.Wait()
+       require.NoError(t, serveErr)
+}
+
+/*** BACKEND tests ***/
+
+// TestSqliteBackend runs SqlTestSuite against the example SQLite Flight SQL
+// server. Every test gets a fresh in-memory database.
+func TestSqliteBackend(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       s := &SqlTestSuite{
+               Config: driver.DriverConfig{
+                       Timeout: 5 * time.Second,
+               },
+       }
+
+       s.createServer = func() (flight.Server, string, error) {
+               server := flight.NewServerWithMiddleware(nil)
+
+               // Setup the SQLite backend
+               db, err := sql.Open("sqlite", ":memory:")
+               if err != nil {
+                       return nil, "", err
+               }
+               sqliteServer, err := example.NewSQLiteFlightSQLServer(db)
+               if err != nil {
+                       // Best effort cleanup; the setup error is what matters.
+                       _ = db.Close()
+                       return nil, "", err
+               }
+               sqliteServer.Alloc = mem
+
+               // Connect the FlightSQL frontend to the backend
+               server.RegisterFlightService(flightsql.NewFlightServer(sqliteServer))
+               if err := server.Init("localhost:0"); err != nil {
+                       // Best effort cleanup; the init error is what matters.
+                       _ = db.Close()
+                       return nil, "", err
+               }
+               // os.Kill cannot be caught by a program, so only os.Interrupt
+               // can trigger a shutdown.
+               server.SetShutdownOnSignals(os.Interrupt)
+               return server, server.Addr().String(), nil
+       }
+       s.startServer = func(server flight.Server) error { return server.Serve() }
+       s.stopServer = func(server flight.Server) { server.Shutdown() }
+
+       suite.Run(t, s)
+}
+
+func TestPreparedStatementSchema(t *testing.T) {
+       // Setup the expected test
+       backend := &MockServer{
+               PreparedStatementParameterSchema: 
arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}, Nullable: false}}, 
nil),
+               DataSchema: arrow.NewSchema([]arrow.Field{
+                       {Name: "time", Type: &arrow.Time64Type{Unit: 
arrow.Nanosecond}, Nullable: true},
+                       {Name: "value", Type: &arrow.Int64Type{}, Nullable: 
false},
+               }, nil),
+               Data: "[]",
+       }
+
+       // Instantiate a mock server
+       server := flight.NewServerWithMiddleware(nil)
+       server.RegisterFlightService(flightsql.NewFlightServer(backend))
+       require.NoError(t, server.Init("localhost:0"))
+       server.SetShutdownOnSignals(os.Interrupt, os.Kill)
+       go server.Serve()
+       defer server.Shutdown()
+
+       // Configure client
+       cfg := driver.DriverConfig{
+               Timeout: 5 * time.Second,
+               Address: server.Addr().String(),
+       }
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Do query
+       stmt, err := db.Prepare("SELECT * FROM foo WHERE name LIKE ?")
+       require.NoError(t, err)
+
+       _, err = stmt.Query()
+       require.ErrorContains(t, err, "expected 1 arguments, got 0")
+
+       // Test for error issues by driver
+       _, err = stmt.Query(23)
+       require.ErrorContains(t, err, "invalid value type int64 for builder 
*array.StringBuilder")
+
+       rows, err := stmt.Query("master")
+       require.NoError(t, err)
+       require.NotNil(t, rows)
+}
+
+// TestPreparedStatementNoSchema checks the behavior when the server does not
+// advertise a parameter schema: the driver does not validate arguments, and
+// mismatches are reported by the server instead.
+func TestPreparedStatementNoSchema(t *testing.T) {
+       // Setup the expected test
+       backend := &MockServer{
+               DataSchema: arrow.NewSchema([]arrow.Field{
+                       {Name: "time", Type: &arrow.Time64Type{Unit: arrow.Nanosecond}, Nullable: true},
+                       {Name: "value", Type: &arrow.Int64Type{}, Nullable: false},
+               }, nil),
+               Data:                            "[]",
+               ExpectedPreparedStatementSchema: arrow.NewSchema([]arrow.Field{{Type: &arrow.StringType{}, Nullable: false}}, nil),
+       }
+
+       // Instantiate a mock server. os.Kill cannot be caught, so only
+       // os.Interrupt is registered.
+       server := flight.NewServerWithMiddleware(nil)
+       server.RegisterFlightService(flightsql.NewFlightServer(backend))
+       require.NoError(t, server.Init("localhost:0"))
+       server.SetShutdownOnSignals(os.Interrupt)
+       go server.Serve()
+       defer server.Shutdown()
+
+       // Configure client
+       cfg := driver.DriverConfig{
+               Timeout: 5 * time.Second,
+               Address: server.Addr().String(),
+       }
+       db, err := sql.Open("flightsql", cfg.DSN())
+       require.NoError(t, err)
+       defer db.Close()
+
+       // Do query
+       stmt, err := db.Prepare("SELECT * FROM foo WHERE name LIKE ?")
+       require.NoError(t, err)
+
+       // Without a parameter schema the driver cannot reject a missing
+       // argument, so the query goes through.
+       noArgRows, err := stmt.Query()
+       require.NoError(t, err, "query without arguments should succeed when no parameter schema is advertised")
+       defer noArgRows.Close()
+
+       // Test for error issued by server due to missing parameter schema
+       _, err = stmt.Query(23)
+       require.ErrorContains(t, err, "parameter schema: unexpected")
+
+       rows, err := stmt.Query("master")
+       require.NoError(t, err)
+       require.NotNil(t, rows)
+       defer rows.Close()
+}
+
+// MockServer is a minimal Flight SQL backend for driver tests. It answers
+// prepared-statement requests with fixed schemas and serves Data for every
+// statement ticket.
+type MockServer struct {
+       flightsql.BaseServer
+
+       // DataSchema is the result-set schema; Data holds the result rows as
+       // JSON and is decoded with DataSchema.
+       DataSchema *arrow.Schema
+       Data       string
+
+       // PreparedStatementParameterSchema is returned as the parameter schema
+       // of every prepared statement and, unless ExpectedPreparedStatementSchema
+       // is set, bound parameters are checked against it.
+       PreparedStatementParameterSchema *arrow.Schema
+
+       // PreparedStatementError, when non-empty, makes CreatePreparedStatement
+       // fail with this message.
+       PreparedStatementError string
+
+       // ExpectedPreparedStatementSchema, when set, is the schema bound
+       // parameters must match exactly in DoPutPreparedStatementQuery.
+       ExpectedPreparedStatementSchema *arrow.Schema
+}
+
+func (s *MockServer) CreatePreparedStatement(ctx context.Context, req 
flightsql.ActionCreatePreparedStatementRequest) 
(flightsql.ActionCreatePreparedStatementResult, error) {
+       if s.PreparedStatementError != "" {
+               return flightsql.ActionCreatePreparedStatementResult{}, 
errors.New(s.PreparedStatementError)
+       }
+       return flightsql.ActionCreatePreparedStatementResult{
+               Handle:          []byte("prepared"),
+               DatasetSchema:   s.DataSchema,
+               ParameterSchema: s.PreparedStatementParameterSchema,
+       }, nil
+}
+
+// DoPutPreparedStatementQuery validates the schema of the bound parameters
+// and otherwise ignores them. If ExpectedPreparedStatementSchema is set, a
+// mismatch yields "parameter schema: unexpected". Otherwise a mismatch with a
+// non-nil PreparedStatementParameterSchema yields an error wrapping
+// arrow.ErrInvalid.
+func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) error {
+       switch {
+       case s.ExpectedPreparedStatementSchema != nil:
+               if s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
+                       return nil
+               }
+               return errors.New("parameter schema: unexpected")
+       case s.PreparedStatementParameterSchema == nil:
+               return nil
+       case s.PreparedStatementParameterSchema.Equal(r.Schema()):
+               return nil
+       default:
+               return fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
+       }
+}
+
+// DoGetStatement decodes Data with DataSchema into a single record and
+// streams it back, regardless of the ticket.
+func (s *MockServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+       record, _, err := array.RecordFromJSON(memory.DefaultAllocator, s.DataSchema, strings.NewReader(s.Data))
+       if err != nil {
+               return nil, nil, err
+       }
+
+       // With a buffer of one, the single chunk is queued without a helper
+       // goroutine. An unbuffered send would leave a goroutine blocked forever
+       // if the consumer never reads the stream.
+       chunk := make(chan flight.StreamChunk, 1)
+       chunk <- flight.StreamChunk{Data: record}
+       close(chunk)
+       return s.DataSchema, chunk, nil
+}
+
+// GetFlightInfoPreparedStatement returns a FlightInfo with a single endpoint.
+// The endpoint's ticket is a statement-query ticket wrapping the prepared
+// statement's handle. TotalRecords and TotalBytes are set to -1.
+func (s *MockServer) GetFlightInfoPreparedStatement(ctx context.Context, stmt flightsql.PreparedStatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
+       ticket, err := flightsql.CreateStatementQueryTicket(stmt.GetPreparedStatementHandle())
+       if err != nil {
+               return nil, err
+       }
+
+       endpoint := &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
+       info := &flight.FlightInfo{
+               FlightDescriptor: desc,
+               Endpoint:         []*flight.FlightEndpoint{endpoint},
+               TotalRecords:     -1,
+               TotalBytes:       -1,
+       }
+       return info, nil
+}
diff --git a/go/arrow/flight/flightsql/driver/utils.go 
b/go/arrow/flight/flightsql/driver/utils.go
new file mode 100644
index 0000000000..f8f1a0e86a
--- /dev/null
+++ b/go/arrow/flight/flightsql/driver/utils.go
@@ -0,0 +1,272 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package driver
+
+import (
+       "context"
+       "encoding/base64"
+       "fmt"
+       "time"
+
+       "github.com/apache/arrow/go/v12/arrow"
+       "github.com/apache/arrow/go/v12/arrow/array"
+)
+
+// *** GRPC helpers ***
// grpcCredentials attaches authentication and extra metadata to every gRPC
// call made by the driver. It implements gRPC's per-RPC credentials interface
// through GetRequestMetadata and RequireTransportSecurity.
type grpcCredentials struct {
	username string            // user for HTTP basic authentication
	password string            // password paired with username
	token    string            // bearer token; takes precedence over username/password
	params   map[string]string // additional metadata entries sent with each request
}

// GetRequestMetadata builds the metadata map for a single request.
//
// If a token is set, an "authorization: Bearer <token>" entry is added.
// Otherwise, if a username is set, an "authorization: Basic <base64(user:pass)>"
// entry is added. The entries from params are copied afterwards, so a param
// with the same key overrides the generated authorization value. The error
// result is always nil.
func (c grpcCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	metadata := make(map[string]string, len(c.params)+1)

	if c.token != "" {
		metadata["authorization"] = "Bearer " + c.token
	} else if c.username != "" {
		userPass := c.username + ":" + c.password
		metadata["authorization"] = "Basic " + base64.StdEncoding.EncodeToString([]byte(userPass))
	}

	for key, value := range c.params {
		metadata[key] = value
	}
	return metadata, nil
}

// RequireTransportSecurity reports whether these credentials must only be
// sent over a secure transport. This is the case whenever a secret (a token
// or a username/password pair) would be included in the metadata.
func (c grpcCredentials) RequireTransportSecurity() bool {
	return !(c.token == "" && c.username == "")
}
+
+// *** Type conversions ***
+func fromArrowType(arr arrow.Array, idx int) (interface{}, error) {
+       switch c := arr.(type) {
+       case *array.Boolean:
+               return c.Value(idx), nil
+       case *array.Float16:
+               return float64(c.Value(idx).Float32()), nil
+       case *array.Float32:
+               return float64(c.Value(idx)), nil
+       case *array.Float64:
+               return c.Value(idx), nil
+       case *array.Int8:
+               return int64(c.Value(idx)), nil
+       case *array.Int16:
+               return int64(c.Value(idx)), nil
+       case *array.Int32:
+               return int64(c.Value(idx)), nil
+       case *array.Int64:
+               return c.Value(idx), nil
+       case *array.String:
+               return c.Value(idx), nil
+       case *array.Time32:
+               dt, ok := arr.DataType().(*arrow.Time32Type)
+               if !ok {
+                       return nil, fmt.Errorf("datatype %T not matching 
time32", arr.DataType())
+               }
+               v := c.Value(idx)
+               return v.ToTime(dt.TimeUnit()), nil
+       case *array.Time64:
+               dt, ok := arr.DataType().(*arrow.Time64Type)
+               if !ok {
+                       return nil, fmt.Errorf("datatype %T not matching 
time64", arr.DataType())
+               }
+               v := c.Value(idx)
+               return v.ToTime(dt.TimeUnit()), nil
+       case *array.Timestamp:
+               dt, ok := arr.DataType().(*arrow.TimestampType)
+               if !ok {
+                       return nil, fmt.Errorf("datatype %T not matching 
timestamp", arr.DataType())
+               }
+               v := c.Value(idx)
+               return v.ToTime(dt.TimeUnit()), nil
+       }
+
+       return nil, fmt.Errorf("type %T: %w", arr, ErrNotSupported)
+}
+
+func toArrowDataType(value interface{}) (arrow.DataType, error) {
+       switch value.(type) {
+       case bool:
+               return &arrow.BooleanType{}, nil
+       case float32:
+               return &arrow.Float32Type{}, nil
+       case float64:
+               return &arrow.Float64Type{}, nil
+       case int8:
+               return &arrow.Int8Type{}, nil
+       case int16:
+               return &arrow.Int16Type{}, nil
+       case int32:
+               return &arrow.Int32Type{}, nil
+       case int64:
+               return &arrow.Int64Type{}, nil
+       case uint8:
+               return &arrow.Uint8Type{}, nil
+       case uint16:
+               return &arrow.Uint16Type{}, nil
+       case uint32:
+               return &arrow.Uint32Type{}, nil
+       case uint64:
+               return &arrow.Uint64Type{}, nil
+       case string:
+               return &arrow.StringType{}, nil
+       case time.Time:
+               return &arrow.Time64Type{Unit: arrow.Nanosecond}, nil
+       }
+       return nil, fmt.Errorf("type %T: %w", value, ErrNotSupported)
+}
+
+// *** Field builder versions ***
+func setFieldValue(builder array.Builder, arg interface{}) error {
+       switch b := builder.(type) {
+       case *array.BooleanBuilder:
+               switch v := arg.(type) {
+               case bool:
+                       b.Append(v)
+               case []bool:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Float32Builder:
+               switch v := arg.(type) {
+               case float32:
+                       b.Append(v)
+               case []float32:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Float64Builder:
+               switch v := arg.(type) {
+               case float64:
+                       b.Append(v)
+               case []float64:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Int8Builder:
+               switch v := arg.(type) {
+               case int8:
+                       b.Append(v)
+               case []int8:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Int16Builder:
+               switch v := arg.(type) {
+               case int16:
+                       b.Append(v)
+               case []int16:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Int32Builder:
+               switch v := arg.(type) {
+               case int32:
+                       b.Append(v)
+               case []int32:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Int64Builder:
+               switch v := arg.(type) {
+               case int64:
+                       b.Append(v)
+               case []int64:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Uint8Builder:
+               switch v := arg.(type) {
+               case uint8:
+                       b.Append(v)
+               case []uint8:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Uint16Builder:
+               switch v := arg.(type) {
+               case uint16:
+                       b.Append(v)
+               case []uint16:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Uint32Builder:
+               switch v := arg.(type) {
+               case uint32:
+                       b.Append(v)
+               case []uint32:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Uint64Builder:
+               switch v := arg.(type) {
+               case uint64:
+                       b.Append(v)
+               case []uint64:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.StringBuilder:
+               switch v := arg.(type) {
+               case string:
+                       b.Append(v)
+               case []string:
+                       b.AppendValues(v, nil)
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       case *array.Time64Builder:
+               switch v := arg.(type) {
+               case int64:
+                       b.Append(arrow.Time64(v))
+               case []int64:
+                       for _, x := range v {
+                               b.Append(arrow.Time64(x))
+                       }
+               case uint64:
+                       b.Append(arrow.Time64(v))
+               case []uint64:
+                       for _, x := range v {
+                               b.Append(arrow.Time64(x))
+                       }
+               case time.Time:
+                       b.Append(arrow.Time64(v.Nanosecond()))
+               default:
+                       return fmt.Errorf("invalid value type %T for builder 
%T", arg, builder)
+               }
+       default:
+               return fmt.Errorf("unknown builder type %T", builder)
+       }
+       return nil
+}

Reply via email to