This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new a14aad9  fix(go/adbc/driver/flightsql): use updated authorization header from server (#594)
a14aad9 is described below

commit a14aad9bd7021a2b4a74943185b2a6a11384965f
Author: David Li <[email protected]>
AuthorDate: Sat Apr 22 08:22:14 2023 +0900

    fix(go/adbc/driver/flightsql): use updated authorization header from server (#594)
    
    Fixes #584.
---
 go/adbc/driver/flightsql/flightsql_adbc.go      |  17 ++-
 go/adbc/driver/flightsql/flightsql_adbc_test.go | 131 ++++++++++++++++++++++++
 go/adbc/go.mod                                  |   2 +-
 go/adbc/go.sum                                  |   4 +-
 go/adbc/validation/validation.go                |   4 +-
 5 files changed, 152 insertions(+), 6 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index e78f72a..5ab6d59 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -45,6 +45,7 @@ import (
        "runtime/debug"
        "strconv"
        "strings"
+       "sync"
        "time"
 
        "github.com/apache/arrow-adbc/go/adbc"
@@ -512,14 +513,27 @@ func streamTimeoutInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *gr
 }
 
 type bearerAuthMiddleware struct {
-       hdrs metadata.MD
+       mutex sync.RWMutex
+       hdrs  metadata.MD
 }
 
 func (b *bearerAuthMiddleware) StartCall(ctx context.Context) context.Context {
        md, _ := metadata.FromOutgoingContext(ctx)
+       b.mutex.RLock()
+       defer b.mutex.RUnlock()
        return metadata.NewOutgoingContext(ctx, metadata.Join(md, b.hdrs))
 }
 
+func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md 
metadata.MD) {
+       // apache/arrow-adbc#584
+       headers := md.Get("authorization")
+       if len(headers) > 0 {
+               b.mutex.Lock()
+               defer b.mutex.Unlock()
+               b.hdrs.Set("authorization", headers...)
+       }
+}
+
 func getFlightClient(ctx context.Context, loc string, d *database) 
(*flightsql.Client, error) {
        authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
        middleware := []flight.ClientMiddleware{
@@ -564,6 +578,7 @@ func getFlightClient(ctx context.Context, loc string, d *database) (*flightsql.C
                }
 
                if md, ok := metadata.FromOutgoingContext(ctx); ok {
+                       // No need to worry about lock here since we are sole 
owner
                        authMiddle.hdrs.Set("authorization", 
md.Get("Authorization")[0])
                }
        }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 09e6ac8..b9265ca 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -49,8 +49,10 @@ import (
        "github.com/stretchr/testify/require"
        "github.com/stretchr/testify/suite"
        "google.golang.org/grpc"
+       "google.golang.org/grpc/codes"
        "google.golang.org/grpc/credentials"
        "google.golang.org/grpc/metadata"
+       "google.golang.org/grpc/status"
        "google.golang.org/protobuf/proto"
 )
 
@@ -252,6 +254,7 @@ func TestADBCFlightSQL(t *testing.T) {
 
        suite.Run(t, &DefaultDialOptionsTests{Quirks: q})
        suite.Run(t, &HeaderTests{Quirks: q})
+       suite.Run(t, &AuthnTests{})
        suite.Run(t, &OptionTests{Quirks: q})
        suite.Run(t, &PartitionTests{Quirks: q})
        suite.Run(t, &StatementTests{Quirks: q})
@@ -708,6 +711,134 @@ func (suite *HeaderTests) TestPrepared() {
        suite.Contains(suite.Quirks.middle.recordedHeaders.Get("x-header-two"), 
"value 2")
 }
 
+type AuthnTests struct {
+       suite.Suite
+
+       s    flight.Server
+       db   adbc.Database
+       cnxn adbc.Connection
+}
+
+type AuthnTestServer struct {
+       flightsql.BaseServer
+}
+
+func (server *AuthnTestServer) GetFlightInfoStatement(ctx context.Context, cmd 
flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo, 
error) {
+       md := metadata.MD{}
+       md.Set("authorization", "Bearer final")
+       if err := grpc.SendHeader(ctx, md); err != nil {
+               return nil, err
+       }
+       tkt, _ := flightsql.CreateStatementQueryTicket([]byte{})
+       info := &flight.FlightInfo{
+               FlightDescriptor: desc,
+               Endpoint: []*flight.FlightEndpoint{
+                       {Ticket: &flight.Ticket{Ticket: tkt}},
+               },
+               TotalRecords: -1,
+               TotalBytes:   -1,
+       }
+       return info, nil
+}
+
+func (server *AuthnTestServer) DoGetStatement(ctx context.Context, tkt 
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, 
error) {
+       sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+       rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, 
strings.NewReader(`[{"a": 5}]`))
+       if err != nil {
+               return nil, nil, err
+       }
+
+       ch := make(chan flight.StreamChunk)
+       go func() {
+               defer close(ch)
+               ch <- flight.StreamChunk{
+                       Data: rec,
+                       Desc: nil,
+                       Err:  nil,
+               }
+       }()
+       return sc, ch, nil
+}
+
+func authnTestUnary(ctx context.Context, req interface{}, info 
*grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) 
{
+       md, ok := metadata.FromIncomingContext(ctx)
+       if !ok {
+               return nil, status.Error(codes.InvalidArgument, "Could not get 
metadata")
+       }
+       auth := md.Get("authorization")
+       if len(auth) == 0 {
+               return nil, status.Error(codes.Unauthenticated, "No token")
+       } else if auth[0] != "Bearer initial" {
+               return nil, status.Error(codes.Unauthenticated, "Invalid token 
for unary call: "+auth[0])
+       }
+
+       md.Set("authorization", "Bearer final")
+       ctx = metadata.NewOutgoingContext(ctx, md)
+       return handler(ctx, req)
+}
+
+func authnTestStream(srv interface{}, ss grpc.ServerStream, info 
*grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+       md, ok := metadata.FromIncomingContext(ss.Context())
+       if !ok {
+               return status.Error(codes.InvalidArgument, "Could not get 
metadata")
+       }
+       auth := md.Get("authorization")
+       if len(auth) == 0 {
+               return status.Error(codes.Unauthenticated, "No token")
+       } else if auth[0] != "Bearer final" {
+               return status.Error(codes.Unauthenticated, "Invalid token for 
stream call: "+auth[0])
+       }
+
+       return handler(srv, ss)
+}
+
+func (suite *AuthnTests) SetupSuite() {
+       suite.s = flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+               {Stream: authnTestStream, Unary: authnTestUnary},
+       })
+       
suite.s.RegisterFlightService(flightsql.NewFlightServer(&AuthnTestServer{}))
+       suite.Require().NoError(suite.s.Init("localhost:0"))
+       suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
+       go func() {
+               _ = suite.s.Serve()
+       }()
+
+       uri := "grpc+tcp://" + suite.s.Addr().String()
+       var err error
+       suite.db, err = (driver.Driver{}).NewDatabase(map[string]string{
+               "uri":                            uri,
+               driver.OptionAuthorizationHeader: "Bearer initial",
+       })
+       suite.Require().NoError(err)
+}
+
+func (suite *AuthnTests) SetupTest() {
+       var err error
+       suite.cnxn, err = suite.db.Open(context.Background())
+       suite.Require().NoError(err)
+}
+
+func (suite *AuthnTests) TearDownTest() {
+       suite.Require().NoError(suite.cnxn.Close())
+}
+
+func (suite *AuthnTests) TearDownSuite() {
+       suite.db = nil
+       suite.s.Shutdown()
+}
+
+func (suite *AuthnTests) TestBearerTokenUpdated() {
+       // apache/arrow-adbc#584: when setting the auth header directly, the 
client should use any updated token value from the server if given
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer stmt.Close()
+
+       suite.Require().NoError(stmt.SetSqlQuery("timeout"))
+       reader, _, err := stmt.ExecuteQuery(context.Background())
+       suite.NoError(err)
+       defer reader.Release()
+}
+
 type TimeoutTestServer struct {
        flightsql.BaseServer
 }
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index 6ba0443..b428162 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
 go 1.18
 
 require (
-       github.com/apache/arrow/go/v12 v12.0.0-20230307201612-6fdf1e520a76
+       github.com/apache/arrow/go/v12 v12.0.0-20230421000340-388f3a88c647
        github.com/bluele/gcache v0.0.2
        github.com/stretchr/testify v1.8.1
        golang.org/x/exp v0.0.0-20230206171751-46f607a40771
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index 7091980..3fd31dd 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -1,8 +1,8 @@
 github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU=
 github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
-github.com/apache/arrow/go/v12 v12.0.0-20230307201612-6fdf1e520a76 h1:6O9PR51TZFveY6xYg4RCCj23hIFyL7vR5hav0zc+Vss=
-github.com/apache/arrow/go/v12 v12.0.0-20230307201612-6fdf1e520a76/go.mod h1:p9SbFzxqBIUcxDFFqy4aLDT6RdooPoA18Bru+OL1gbk=
+github.com/apache/arrow/go/v12 v12.0.0-20230421000340-388f3a88c647 h1:qsBSonbDQRwj8HyUeD/NSaA0e2bT4f3kgcqkSqVZzdo=
+github.com/apache/arrow/go/v12 v12.0.0-20230421000340-388f3a88c647/go.mod h1:d+tV/eHZZ7Dz7RPrFKtPK02tpr+c9/PEd/zm8mDS9Vg=
 github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo=
 github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q=
 github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw=
diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go
index e3e8212..c31c377 100644
--- a/go/adbc/validation/validation.go
+++ b/go/adbc/validation/validation.go
@@ -242,11 +242,11 @@ func (c *ConnectionTests) TestMetadataGetTableSchema() {
        c.Require().NoError(err)
 
        expectedSchema := arrow.NewSchema([]arrow.Field{
-               {Name: "ints", Type: arrow.PrimitiveTypes.Int64,
+               {Name: "ints", Type: arrow.PrimitiveTypes.Int64, Nullable: true,
                        Metadata: arrow.MetadataFrom(map[string]string{
                                flightsql.ScaleKey: "15", 
flightsql.IsReadOnlyKey: "0", flightsql.IsAutoIncrementKey: "0",
                                flightsql.TableNameKey: "sample_test", 
flightsql.PrecisionKey: "10"})},
-               {Name: "strings", Type: arrow.BinaryTypes.String,
+               {Name: "strings", Type: arrow.BinaryTypes.String, Nullable: 
true,
                        Metadata: arrow.MetadataFrom(map[string]string{
                                flightsql.ScaleKey: "15", 
flightsql.IsReadOnlyKey: "0", flightsql.IsAutoIncrementKey: "0",
                                flightsql.TableNameKey: "sample_test"})},

Reply via email to