This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 1cbfb2f2 feat(go/adbc/driver/flightsql): support session options 
(#1597)
1cbfb2f2 is described below

commit 1cbfb2f2e457ddde109f5324ed800c99acc6c576
Author: David Li <[email protected]>
AuthorDate: Mon Mar 11 11:06:37 2024 -0400

    feat(go/adbc/driver/flightsql): support session options (#1597)
    
    Fixes #1557.
---
 c/validation/adbc_validation_connection.cc         |   2 +-
 .../driver/flightsql/flightsql_adbc_server_test.go | 217 +++++++++++++++
 go/adbc/driver/flightsql/flightsql_connection.go   | 293 ++++++++++++++++++++-
 go/adbc/driver/flightsql/flightsql_driver.go       |  39 +--
 .../adbc_driver_flightsql/__init__.py              |  12 +
 5 files changed, 538 insertions(+), 25 deletions(-)

diff --git a/c/validation/adbc_validation_connection.cc 
b/c/validation/adbc_validation_connection.cc
index 7a438f10..4ed1d0e0 100644
--- a/c/validation/adbc_validation_connection.cc
+++ b/c/validation/adbc_validation_connection.cc
@@ -151,7 +151,7 @@ void ConnectionTest::TestMetadataCurrentCatalog() {
     ASSERT_THAT(
         AdbcConnectionGetOption(&connection, 
ADBC_CONNECTION_OPTION_CURRENT_CATALOG,
                                 buffer, &buffer_size, &error),
-        IsStatus(ADBC_STATUS_NOT_FOUND));
+        IsStatus(ADBC_STATUS_NOT_FOUND, &error));
   }
 }
 
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go 
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 3b576149..e70694e3 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -21,6 +21,7 @@ package flightsql_test
 
 import (
        "context"
+       "encoding/json"
        "errors"
        "fmt"
        "net"
@@ -41,6 +42,7 @@ import (
        "github.com/apache/arrow/go/v16/arrow/flight"
        "github.com/apache/arrow/go/v16/arrow/flight/flightsql"
        "github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref"
+       flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight"
        "github.com/apache/arrow/go/v16/arrow/memory"
        "github.com/golang/protobuf/ptypes/wrappers"
        "github.com/stretchr/testify/suite"
@@ -134,6 +136,10 @@ func TestMultiTable(t *testing.T) {
        suite.Run(t, &MultiTableTests{})
 }
 
+func TestSessionOptions(t *testing.T) {
+       suite.Run(t, &SessionOptionTests{})
+}
+
 // ---- AuthN Tests --------------------
 
 type AuthnTestServer struct {
@@ -1654,3 +1660,214 @@ func (suite *MultiTableTests) TestGetTableSchema() {
        expectedSchema := arrow.NewSchema([]arrow.Field{{Name: "b", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
        suite.Equal(expectedSchema, actualSchema)
 }
+
+// ---- Session Option Tests --------------------
+
+type SessionOptionTestServer struct {
+       flightsql.BaseServer
+       options map[string]interface{}
+}
+
+func (server *SessionOptionTestServer) GetSessionOptions(ctx context.Context, 
req *flight.GetSessionOptionsRequest) (*flight.GetSessionOptionsResult, error) {
+       options := make(map[string]*flight.SessionOptionValue)
+       for k, v := range server.options {
+               switch s := v.(type) {
+               case bool:
+                       options[k] = &flight.SessionOptionValue{OptionValue: 
&flightproto.SessionOptionValue_BoolValue{BoolValue: s}}
+               case float64:
+                       options[k] = &flight.SessionOptionValue{OptionValue: 
&flightproto.SessionOptionValue_DoubleValue{DoubleValue: s}}
+               case int64:
+                       options[k] = &flight.SessionOptionValue{OptionValue: 
&flightproto.SessionOptionValue_Int64Value{Int64Value: s}}
+               case string:
+                       options[k] = &flight.SessionOptionValue{OptionValue: 
&flightproto.SessionOptionValue_StringValue{StringValue: s}}
+               case []string:
+                       options[k] = &flight.SessionOptionValue{OptionValue: 
&flightproto.SessionOptionValue_StringListValue_{StringListValue: 
&flightproto.SessionOptionValue_StringListValue{Values: s}}}
+               case nil:
+                       options[k] = &flight.SessionOptionValue{}
+               default:
+                       panic("not implemented")
+               }
+       }
+       return &flight.GetSessionOptionsResult{
+               SessionOptions: options,
+       }, nil
+}
+
+func (server *SessionOptionTestServer) SetSessionOptions(ctx context.Context, 
req *flight.SetSessionOptionsRequest) (*flight.SetSessionOptionsResult, error) {
+       errors := map[string]*flightproto.SetSessionOptionsResult_Error{}
+       for k, v := range req.SessionOptions {
+               switch k {
+               case "bad name":
+                       errors[k] = 
&flightproto.SetSessionOptionsResult_Error{Value: 
flightproto.SetSessionOptionsResult_INVALID_NAME}
+                       continue
+               case "bad value":
+                       errors[k] = 
&flightproto.SetSessionOptionsResult_Error{Value: 
flightproto.SetSessionOptionsResult_INVALID_VALUE}
+                       continue
+               case "error":
+                       errors[k] = 
&flightproto.SetSessionOptionsResult_Error{Value: 
flightproto.SetSessionOptionsResult_ERROR}
+                       continue
+               }
+               switch s := v.GetOptionValue().(type) {
+               case *flightproto.SessionOptionValue_BoolValue:
+                       server.options[k] = s.BoolValue
+               case *flightproto.SessionOptionValue_DoubleValue:
+                       server.options[k] = s.DoubleValue
+               case *flightproto.SessionOptionValue_Int64Value:
+                       server.options[k] = s.Int64Value
+               case *flightproto.SessionOptionValue_StringValue:
+                       server.options[k] = s.StringValue
+               case *flightproto.SessionOptionValue_StringListValue_:
+                       server.options[k] = s.StringListValue.Values
+               case nil:
+                       delete(server.options, k)
+               default:
+                       return nil, status.Error(codes.InvalidArgument, 
"invalid option type")
+               }
+       }
+       return &flight.SetSessionOptionsResult{Errors: errors}, nil
+}
+
+func (server *SessionOptionTestServer) CloseSession(ctx context.Context, req 
*flight.CloseSessionRequest) (*flight.CloseSessionResult, error) {
+       return &flight.CloseSessionResult{
+               Status: flight.CloseSessionResultClosed,
+       }, nil
+}
+
+type SessionOptionTests struct {
+       ServerBasedTests
+}
+
+func (suite *SessionOptionTests) SetupSuite() {
+       suite.DoSetupSuite(&SessionOptionTestServer{
+               options: map[string]interface{}{
+                       "string":     "expected",
+                       "bool":       true,
+                       "float64":    float64(1.5),
+                       "int64":      int64(20),
+                       "catalog":    "main",
+                       "schema":     "session",
+                       "stringlist": []string{"a", "b", "c"},
+                       "nilopt":     nil,
+               },
+       }, nil, map[string]string{})
+}
+
+func (suite *SessionOptionTests) TestGetAllOptions() {
+       val, err := 
suite.cnxn.(adbc.GetSetOptions).GetOption(driver.OptionSessionOptions)
+       suite.NoError(err)
+
+       options := make(map[string]interface{})
+       suite.NoError(json.Unmarshal([]byte(val), &options))
+       // XXX: json.Unmarshal decodes numbers into float64 by default. Should we
+       // use an alternate representation? What happens to int64max?
+       suite.Equal(float64(20), options["int64"])
+       suite.Equal("expected", options["string"])
+       // Bit of a hack, but lets servers send "this option exists, but is
+       // not set" by returning a nil/unset value
+       suite.Nil(options["nilopt"])
+}
+
+func (suite *SessionOptionTests) TestGetAllOptionsByte() {
+       val, err := 
suite.cnxn.(adbc.GetSetOptions).GetOptionBytes(driver.OptionSessionOptions)
+       suite.NoError(err)
+
+       options := make(map[string]interface{})
+       // XXX: maybe we can return the underlying proto repr here?
+       suite.NoError(json.Unmarshal(val, &options))
+       suite.Equal(float64(20), options["int64"])
+       suite.Equal("expected", options["string"])
+}
+
+func (suite *SessionOptionTests) TestGetSetCatalog() {
+       val, err := 
suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+       suite.NoError(err)
+       suite.Equal("main", val)
+
+       
suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentCatalog,
 "postgres"))
+       val, err = 
suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentCatalog)
+       suite.NoError(err)
+       suite.Equal("postgres", val)
+}
+
+func (suite *SessionOptionTests) TestGetSetSchema() {
+       val, err := 
suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
+       suite.NoError(err)
+       suite.Equal("session", val)
+
+       
suite.NoError(suite.cnxn.(adbc.GetSetOptions).SetOption(adbc.OptionKeyCurrentDbSchema,
 "public"))
+       val, err = 
suite.cnxn.(adbc.GetSetOptions).GetOption(adbc.OptionKeyCurrentDbSchema)
+       suite.NoError(err)
+       suite.Equal("public", val)
+}
+
+func (suite *SessionOptionTests) TestGetSetBool() {
+       o := suite.cnxn.(adbc.GetSetOptions)
+       val, err := o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool")
+       suite.NoError(err)
+       suite.Equal("true", val)
+
+       suite.NoError(o.SetOption(driver.OptionBoolSessionOptionPrefix+"bool", 
"false"))
+       val, err = o.GetOption(driver.OptionBoolSessionOptionPrefix + "bool")
+       suite.NoError(err)
+       suite.Equal("false", val)
+}
+
+func (suite *SessionOptionTests) TestGetSetFloat64() {
+       o := suite.cnxn.(adbc.GetSetOptions)
+       val, err := o.GetOptionDouble(driver.OptionSessionOptionPrefix + 
"float64")
+       suite.NoError(err)
+       suite.Equal(1.5, val)
+
+       
suite.NoError(o.SetOptionDouble(driver.OptionSessionOptionPrefix+"float64", 
-42.0))
+       val, err = o.GetOptionDouble(driver.OptionSessionOptionPrefix + 
"float64")
+       suite.NoError(err)
+       suite.Equal(-42.0, val)
+}
+
+func (suite *SessionOptionTests) TestGetSetInt64() {
+       o := suite.cnxn.(adbc.GetSetOptions)
+       val, err := o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64")
+       suite.NoError(err)
+       suite.Equal(int64(20), val)
+
+       suite.NoError(o.SetOptionInt(driver.OptionSessionOptionPrefix+"int64", 
128))
+       val, err = o.GetOptionInt(driver.OptionSessionOptionPrefix + "int64")
+       suite.NoError(err)
+       suite.Equal(int64(128), val)
+}
+
+func (suite *SessionOptionTests) TestGetSetString() {
+       o := suite.cnxn.(adbc.GetSetOptions)
+       _, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
+       suite.ErrorContains(err, "unknown session option 'unknown'")
+
+       suite.NoError(o.SetOption(driver.OptionSessionOptionPrefix+"unknown", 
"42"))
+       val, err := o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
+       suite.NoError(err)
+       suite.Equal("42", val)
+
+       
suite.NoError(o.SetOption(driver.OptionEraseSessionOptionPrefix+"unknown", ""))
+       _, err = o.GetOption(driver.OptionSessionOptionPrefix + "unknown")
+       suite.ErrorContains(err, "unknown session option 'unknown'")
+
+       suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad 
name", ""), "Could not set option(s) 'bad name' (invalid name)")
+       suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"bad 
value", ""), "Could not set option(s) 'bad value' (invalid value)")
+       
suite.ErrorContains(o.SetOption(driver.OptionSessionOptionPrefix+"error", ""), 
"Could not set option(s) 'error' (error setting option)")
+}
+
+func (suite *SessionOptionTests) TestGetSetStringList() {
+       o := suite.cnxn.(adbc.GetSetOptions)
+       val, err := o.GetOption(driver.OptionStringListSessionOptionPrefix + 
"stringlist")
+       suite.NoError(err)
+       suite.Equal(`["a","b","c"]`, val)
+
+       
suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist",
 `["foo", "bar"]`))
+       val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + 
"stringlist")
+       suite.NoError(err)
+       suite.Equal(`["foo","bar"]`, val)
+
+       
suite.NoError(o.SetOption(driver.OptionStringListSessionOptionPrefix+"stringlist",
 `[]`))
+       val, err = o.GetOption(driver.OptionStringListSessionOptionPrefix + 
"stringlist")
+       suite.NoError(err)
+       suite.Equal(`[]`, val)
+}
diff --git a/go/adbc/driver/flightsql/flightsql_connection.go 
b/go/adbc/driver/flightsql/flightsql_connection.go
index d0aa0b02..e71ac308 100644
--- a/go/adbc/driver/flightsql/flightsql_connection.go
+++ b/go/adbc/driver/flightsql/flightsql_connection.go
@@ -20,6 +20,7 @@ package flightsql
 import (
        "bytes"
        "context"
+       "encoding/json"
        "fmt"
        "io"
        "math"
@@ -32,6 +33,7 @@ import (
        "github.com/apache/arrow/go/v16/arrow/flight"
        "github.com/apache/arrow/go/v16/arrow/flight/flightsql"
        "github.com/apache/arrow/go/v16/arrow/flight/flightsql/schema_ref"
+       flightproto "github.com/apache/arrow/go/v16/arrow/flight/gen/flight"
        "github.com/apache/arrow/go/v16/arrow/ipc"
        "github.com/bluele/gcache"
        "google.golang.org/grpc"
@@ -95,6 +97,115 @@ func doGet(ctx context.Context, cl *flightsql.Client, 
endpoint *flight.FlightEnd
        return nil, err
 }
 
+func (c *cnxn) getSessionOptions(ctx context.Context) (map[string]interface{}, 
error) {
+       ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
+       var header, trailer metadata.MD
+       rawOptions, err := c.cl.GetSessionOptions(ctx, 
&flight.GetSessionOptionsRequest{}, grpc.Header(&header), 
grpc.Trailer(&trailer), c.timeouts)
+       if err != nil {
+               // We're going to make a bit of a concession to backwards 
compatibility
+               // here and ignore UNIMPLEMENTED or INVALID_ARGUMENT
+               grpcStatus := grpcstatus.Convert(err)
+               if grpcStatus.Code() == grpccodes.InvalidArgument || 
grpcStatus.Code() == grpccodes.Unimplemented {
+                       return map[string]interface{}{}, nil
+               }
+               return nil, adbcFromFlightStatusWithDetails(err, header, 
trailer, "GetSessionOptions")
+       }
+
+       options := make(map[string]interface{}, len(rawOptions.SessionOptions))
+       for k, rawValue := range rawOptions.SessionOptions {
+               switch v := rawValue.OptionValue.(type) {
+               case *flightproto.SessionOptionValue_BoolValue:
+                       options[k] = v.BoolValue
+               case *flightproto.SessionOptionValue_DoubleValue:
+                       options[k] = v.DoubleValue
+               case *flightproto.SessionOptionValue_Int64Value:
+                       options[k] = v.Int64Value
+               case *flightproto.SessionOptionValue_StringValue:
+                       options[k] = v.StringValue
+               case *flightproto.SessionOptionValue_StringListValue_:
+                       if v.StringListValue.Values == nil {
+                               options[k] = make([]string, 0)
+                       } else {
+                               options[k] = v.StringListValue.Values
+                       }
+               case nil:
+                       options[k] = nil
+               default:
+                       return nil, adbc.Error{
+                               Code: adbc.StatusNotImplemented,
+                               Msg:  fmt.Sprintf("[FlightSQL] Unknown session 
option type %#v", rawValue),
+                       }
+               }
+       }
+       return options, nil
+}
+
+func (c *cnxn) setSessionOptions(ctx context.Context, key string, val 
interface{}) error {
+       req := flight.SetSessionOptionsRequest{}
+
+       var err error
+       req.SessionOptions, err = 
flight.NewSessionOptionValues(map[string]any{key: val})
+       if err != nil {
+               return adbc.Error{
+                       Msg:  fmt.Sprintf("[Flight SQL] Invalid session option 
%s=%#v: %s", key, val, err.Error()),
+                       Code: adbc.StatusInvalidArgument,
+               }
+       }
+
+       var header, trailer metadata.MD
+       errors, err := c.cl.SetSessionOptions(ctx, &req, grpc.Header(&header), 
grpc.Trailer(&trailer), c.timeouts)
+       if err != nil {
+               return adbcFromFlightStatusWithDetails(err, header, trailer, 
"GetSessionOptions")
+       }
+       if len(errors.Errors) > 0 {
+               msg := strings.Builder{}
+               fmt.Fprint(&msg, "[Flight SQL] Could not set option(s) ")
+
+               first := true
+               for k, v := range errors.Errors {
+                       if !first {
+                               fmt.Fprint(&msg, ", ")
+                       }
+                       first = false
+
+                       errmsg := "unknown error"
+                       switch v.Value {
+                       case flightproto.SetSessionOptionsResult_INVALID_NAME:
+                               errmsg = "invalid name"
+                       case flightproto.SetSessionOptionsResult_INVALID_VALUE:
+                               errmsg = "invalid value"
+                       case flightproto.SetSessionOptionsResult_ERROR:
+                               errmsg = "error setting option"
+                       }
+                       fmt.Fprintf(&msg, "'%s' (%s)", k, errmsg)
+               }
+
+               return adbc.Error{
+                       Msg:  msg.String(),
+                       Code: adbc.StatusInvalidArgument,
+               }
+       }
+       return nil
+}
+
+func getSessionOption[T any](options map[string]interface{}, key string, 
defaultVal T, valueType string) (T, error) {
+       rawValue, ok := options[key]
+       if !ok {
+               return defaultVal, adbc.Error{
+                       Msg:  fmt.Sprintf("[Flight SQL] unknown session option 
'%s'", key),
+                       Code: adbc.StatusNotFound,
+               }
+       }
+       value, ok := rawValue.(T)
+       if !ok {
+               return defaultVal, adbc.Error{
+                       Msg:  fmt.Sprintf("[Flight SQL] session option %s=%#v 
is not %s value", key, rawValue, valueType),
+                       Code: adbc.StatusNotFound,
+               }
+       }
+       return value, nil
+}
+
 func (c *cnxn) GetOption(key string) (string, error) {
        if strings.HasPrefix(key, OptionRPCCallHeaderPrefix) {
                name := strings.TrimPrefix(key, OptionRPCCallHeaderPrefix)
@@ -124,16 +235,96 @@ func (c *cnxn) GetOption(key string) (string, error) {
                        return adbc.OptionValueEnabled, nil
                }
        case adbc.OptionKeyCurrentCatalog:
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return "", err
+               }
+               if catalog, ok := options["catalog"]; ok {
+                       if val, ok := catalog.(string); ok {
+                               return val, nil
+                       }
+                       return "", adbc.Error{
+                               Msg:  fmt.Sprintf("[FlightSQL] Server returned 
non-string catalog %#v", catalog),
+                               Code: adbc.StatusInternal,
+                       }
+               }
                return "", adbc.Error{
-                       Msg:  "[Flight SQL] current catalog not supported",
+                       Msg:  "[FlightSQL] current catalog not supported",
                        Code: adbc.StatusNotFound,
                }
 
        case adbc.OptionKeyCurrentDbSchema:
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return "", err
+               }
+               if schema, ok := options["schema"]; ok {
+                       if val, ok := schema.(string); ok {
+                               return val, nil
+                       }
+                       return "", adbc.Error{
+                               Msg:  fmt.Sprintf("[FlightSQL] Server returned 
non-string schema %#v", schema),
+                               Code: adbc.StatusInternal,
+                       }
+               }
                return "", adbc.Error{
-                       Msg:  "[Flight SQL] current schema not supported",
+                       Msg:  "[FlightSQL] current schema not supported",
                        Code: adbc.StatusNotFound,
                }
+       case OptionSessionOptions:
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return "", err
+               }
+               encoded, err := json.Marshal(options)
+               if err != nil {
+                       return "", adbc.Error{
+                               Msg:  fmt.Sprintf("[Flight SQL] Could not 
encode option values: %s", err.Error()),
+                               Code: adbc.StatusInternal,
+                       }
+               }
+               return string(encoded), nil
+       }
+       switch {
+       case strings.HasPrefix(key, OptionSessionOptionPrefix):
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return "", err
+               }
+               name := key[len(OptionSessionOptionPrefix):]
+               return getSessionOption(options, name, "", "a string")
+       case strings.HasPrefix(key, OptionBoolSessionOptionPrefix):
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return "", err
+               }
+               name := key[len(OptionBoolSessionOptionPrefix):]
+               v, err := getSessionOption(options, name, false, "a boolean")
+               if err != nil {
+                       return "", err
+               }
+               if v {
+                       return adbc.OptionValueEnabled, nil
+               }
+               return adbc.OptionValueDisabled, nil
+       case strings.HasPrefix(key, OptionStringListSessionOptionPrefix):
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return "", err
+               }
+               name := key[len(OptionStringListSessionOptionPrefix):]
+               v, err := getSessionOption[[]string](options, name, nil, "a 
string list")
+               if err != nil {
+                       return "", err
+               }
+               encoded, err := json.Marshal(v)
+               if err != nil {
+                       return "", adbc.Error{
+                               Msg:  fmt.Sprintf("[Flight SQL] Could not 
encode option value: %s", err.Error()),
+                               Code: adbc.StatusInternal,
+                       }
+               }
+               return string(encoded), nil
        }
 
        return "", adbc.Error{
@@ -143,6 +334,22 @@ func (c *cnxn) GetOption(key string) (string, error) {
 }
 
 func (c *cnxn) GetOptionBytes(key string) ([]byte, error) {
+       switch key {
+       case OptionSessionOptions:
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return nil, err
+               }
+               encoded, err := json.Marshal(options)
+               if err != nil {
+                       return nil, adbc.Error{
+                               Msg:  fmt.Sprintf("[Flight SQL] Could not 
encode option values: %s", err.Error()),
+                               Code: adbc.StatusInternal,
+                       }
+               }
+               return encoded, nil
+       }
+
        return nil, adbc.Error{
                Msg:  "[Flight SQL] unknown connection option",
                Code: adbc.StatusNotFound,
@@ -162,6 +369,14 @@ func (c *cnxn) GetOptionInt(key string) (int64, error) {
                }
                return int64(val), nil
        }
+       if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return 0, err
+               }
+               name := key[len(OptionSessionOptionPrefix):]
+               return getSessionOption(options, name, int64(0), "an integer")
+       }
 
        return 0, adbc.Error{
                Msg:  "[Flight SQL] unknown connection option",
@@ -178,6 +393,14 @@ func (c *cnxn) GetOptionDouble(key string) (float64, 
error) {
        case OptionTimeoutUpdate:
                return c.timeouts.updateTimeout.Seconds(), nil
        }
+       if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+               options, err := c.getSessionOptions(context.Background())
+               if err != nil {
+                       return 0, err
+               }
+               name := key[len(OptionSessionOptionPrefix):]
+               return getSessionOption(options, name, float64(0.0), "a 
floating-point")
+       }
 
        return 0.0, adbc.Error{
                Msg:  "[Flight SQL] unknown connection option",
@@ -245,12 +468,47 @@ func (c *cnxn) SetOption(key, value string) error {
                        }
                }
                return nil
+       case adbc.OptionKeyCurrentCatalog:
+               return c.setSessionOptions(context.Background(), "catalog", 
value)
+       case adbc.OptionKeyCurrentDbSchema:
+               return c.setSessionOptions(context.Background(), "schema", 
value)
+       }
 
-       default:
-               return adbc.Error{
-                       Msg:  "[Flight SQL] unknown connection option",
-                       Code: adbc.StatusNotImplemented,
+       switch {
+       case strings.HasPrefix(key, OptionSessionOptionPrefix):
+               name := key[len(OptionSessionOptionPrefix):]
+               return c.setSessionOptions(context.Background(), name, value)
+       case strings.HasPrefix(key, OptionBoolSessionOptionPrefix):
+               name := key[len(OptionBoolSessionOptionPrefix):]
+               switch value {
+               case adbc.OptionValueEnabled:
+                       return c.setSessionOptions(context.Background(), name, 
true)
+               case adbc.OptionValueDisabled:
+                       return c.setSessionOptions(context.Background(), name, 
false)
+               default:
+                       return adbc.Error{
+                               Msg:  fmt.Sprintf("[Flight SQL] invalid boolean 
session option value %s=%s", name, value),
+                               Code: adbc.StatusNotImplemented,
+                       }
                }
+       case strings.HasPrefix(key, OptionStringListSessionOptionPrefix):
+               name := key[len(OptionStringListSessionOptionPrefix):]
+               stringlist := make([]string, 0)
+               if err := json.Unmarshal([]byte(value), &stringlist); err != 
nil {
+                       return adbc.Error{
+                               Msg:  fmt.Sprintf("[Flight SQL] invalid string 
list session option value %s=%s: %s", name, value, err.Error()),
+                               Code: adbc.StatusNotImplemented,
+                       }
+               }
+               return c.setSessionOptions(context.Background(), name, 
stringlist)
+       case strings.HasPrefix(key, OptionEraseSessionOptionPrefix):
+               name := key[len(OptionEraseSessionOptionPrefix):]
+               return c.setSessionOptions(context.Background(), name, nil)
+       }
+
+       return adbc.Error{
+               Msg:  "[Flight SQL] unknown connection option",
+               Code: adbc.StatusNotImplemented,
        }
 }
 
@@ -266,6 +524,10 @@ func (c *cnxn) SetOptionInt(key string, value int64) error 
{
        case OptionTimeoutFetch, OptionTimeoutQuery, OptionTimeoutUpdate:
                return c.timeouts.setTimeout(key, float64(value))
        }
+       if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+               name := key[len(OptionSessionOptionPrefix):]
+               return c.setSessionOptions(context.Background(), name, value)
+       }
 
        return adbc.Error{
                Msg:  "[Flight SQL] unknown connection option",
@@ -282,6 +544,10 @@ func (c *cnxn) SetOptionDouble(key string, value float64) 
error {
        case OptionTimeoutUpdate:
                return c.timeouts.setTimeout(key, value)
        }
+       if strings.HasPrefix(key, OptionSessionOptionPrefix) {
+               name := key[len(OptionSessionOptionPrefix):]
+               return c.setSessionOptions(context.Background(), name, value)
+       }
 
        return adbc.Error{
                Msg:  "[Flight SQL] unknown connection option",
@@ -937,7 +1203,20 @@ func (c *cnxn) Close() error {
                }
        }
 
-       err := c.cl.Close()
+       ctx := metadata.NewOutgoingContext(context.Background(), c.hdrs)
+       var header, trailer metadata.MD
+       _, err := c.cl.CloseSession(ctx, &flight.CloseSessionRequest{}, 
grpc.Header(&header), grpc.Trailer(&trailer), c.timeouts)
+       if err != nil {
+               grpcStatus := grpcstatus.Convert(err)
+               // Ignore unimplemented
+               if grpcStatus.Code() != grpccodes.Unimplemented {
+                       // Ignore the error since server may not support it and 
may not properly return UNIMPLEMENTED
+                       // 
TODO(https://github.com/apache/arrow-adbc/issues/1243): log a proper warning
+                       c.db.Logger.Debug("failed to close session", "error", 
err.Error())
+               }
+       }
+
+       err = c.cl.Close()
        c.cl = nil
        return adbcFromFlightStatus(err, "Close")
 }
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go 
b/go/adbc/driver/flightsql/flightsql_driver.go
index df1ae688..d437f082 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -45,23 +45,28 @@ import (
 )
 
 const (
-       OptionAuthority           = "adbc.flight.sql.client_option.authority"
-       OptionMTLSCertChain       = 
"adbc.flight.sql.client_option.mtls_cert_chain"
-       OptionMTLSPrivateKey      = 
"adbc.flight.sql.client_option.mtls_private_key"
-       OptionSSLOverrideHostname = 
"adbc.flight.sql.client_option.tls_override_hostname"
-       OptionSSLSkipVerify       = 
"adbc.flight.sql.client_option.tls_skip_verify"
-       OptionSSLRootCerts        = 
"adbc.flight.sql.client_option.tls_root_certs"
-       OptionWithBlock           = "adbc.flight.sql.client_option.with_block"
-       OptionWithMaxMsgSize      = 
"adbc.flight.sql.client_option.with_max_msg_size"
-       OptionAuthorizationHeader = "adbc.flight.sql.authorization_header"
-       OptionTimeoutConnect      = 
"adbc.flight.sql.rpc.timeout_seconds.connect"
-       OptionTimeoutFetch        = "adbc.flight.sql.rpc.timeout_seconds.fetch"
-       OptionTimeoutQuery        = "adbc.flight.sql.rpc.timeout_seconds.query"
-       OptionTimeoutUpdate       = "adbc.flight.sql.rpc.timeout_seconds.update"
-       OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
-       OptionCookieMiddleware    = "adbc.flight.sql.rpc.with_cookie_middleware"
-       OptionLastFlightInfo      = 
"adbc.flight.sql.statement.exec.last_flight_info"
-       infoDriverName            = "ADBC Flight SQL Driver - Go"
+       OptionAuthority                     = 
"adbc.flight.sql.client_option.authority"
+       OptionMTLSCertChain                 = 
"adbc.flight.sql.client_option.mtls_cert_chain"
+       OptionMTLSPrivateKey                = 
"adbc.flight.sql.client_option.mtls_private_key"
+       OptionSSLOverrideHostname           = 
"adbc.flight.sql.client_option.tls_override_hostname"
+       OptionSSLSkipVerify                 = 
"adbc.flight.sql.client_option.tls_skip_verify"
+       OptionSSLRootCerts                  = 
"adbc.flight.sql.client_option.tls_root_certs"
+       OptionWithBlock                     = 
"adbc.flight.sql.client_option.with_block"
+       OptionWithMaxMsgSize                = 
"adbc.flight.sql.client_option.with_max_msg_size"
+       OptionAuthorizationHeader           = 
"adbc.flight.sql.authorization_header"
+       OptionTimeoutConnect                = 
"adbc.flight.sql.rpc.timeout_seconds.connect"
+       OptionTimeoutFetch                  = 
"adbc.flight.sql.rpc.timeout_seconds.fetch"
+       OptionTimeoutQuery                  = 
"adbc.flight.sql.rpc.timeout_seconds.query"
+       OptionTimeoutUpdate                 = 
"adbc.flight.sql.rpc.timeout_seconds.update"
+       OptionRPCCallHeaderPrefix           = "adbc.flight.sql.rpc.call_header."
+       OptionCookieMiddleware              = 
"adbc.flight.sql.rpc.with_cookie_middleware"
+       OptionSessionOptions                = "adbc.flight.sql.session.options"
+       OptionSessionOptionPrefix           = "adbc.flight.sql.session.option."
+       OptionEraseSessionOptionPrefix      = 
"adbc.flight.sql.session.optionerase."
+       OptionBoolSessionOptionPrefix       = 
"adbc.flight.sql.session.optionbool."
+       OptionStringListSessionOptionPrefix = 
"adbc.flight.sql.session.optionstringlist."
+       OptionLastFlightInfo                = 
"adbc.flight.sql.statement.exec.last_flight_info"
+       infoDriverName                      = "ADBC Flight SQL Driver - Go"
 )
 
 var (
diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py 
b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
index e50f7c5a..7d45adf0 100644
--- a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
+++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
@@ -87,6 +87,18 @@ class ConnectionOptions(enum.Enum):
     #:
     #: Overrides any headers set via the equivalent database option.
     RPC_CALL_HEADER_PREFIX = DatabaseOptions.RPC_CALL_HEADER_PREFIX.value
+    #: Get all session options as a JSON key-value blob.
+    OPTION_SESSION_OPTIONS = "adbc.flight.sql.session.options"
+    #: Get or set a session option.
+    OPTION_SESSION_OPTION_PREFIX = "adbc.flight.sql.session.option."
+    #: Erase a session option (use "" as the value).
+    OPTION_ERASE_SESSION_OPTION_PREFIX = "adbc.flight.sql.session.optionerase."
+    #: Get or set a boolean valued session option.
+    OPTION_BOOL_SESSION_OPTION_PREFIX = "adbc.flight.sql.session.optionbool."
+    #: Get or set a string-list-valued session option as a JSON array.
+    OPTION_STRING_LIST_SESSION_OPTION_PREFIX = (
+        "adbc.flight.sql.session.optionstringlist."
+    )
     #: Set a timeout on calls that fetch data (in floating-point seconds).
     #:
     #: This corresponds to Flight RPC DoGet calls.

Reply via email to