This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new a14aad9 fix(go/adbc/driver/flightsql): use updated authorization header from server (#594)
a14aad9 is described below
commit a14aad9bd7021a2b4a74943185b2a6a11384965f
Author: David Li <[email protected]>
AuthorDate: Sat Apr 22 08:22:14 2023 +0900
fix(go/adbc/driver/flightsql): use updated authorization header from server (#594)
Fixes #584.
---
go/adbc/driver/flightsql/flightsql_adbc.go | 17 ++-
go/adbc/driver/flightsql/flightsql_adbc_test.go | 131 ++++++++++++++++++++++++
go/adbc/go.mod | 2 +-
go/adbc/go.sum | 4 +-
go/adbc/validation/validation.go | 4 +-
5 files changed, 152 insertions(+), 6 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go
b/go/adbc/driver/flightsql/flightsql_adbc.go
index e78f72a..5ab6d59 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -45,6 +45,7 @@ import (
"runtime/debug"
"strconv"
"strings"
+ "sync"
"time"
"github.com/apache/arrow-adbc/go/adbc"
@@ -512,14 +513,27 @@ func streamTimeoutInterceptor(ctx context.Context, desc
*grpc.StreamDesc, cc *gr
}
type bearerAuthMiddleware struct {
- hdrs metadata.MD
+ mutex sync.RWMutex
+ hdrs metadata.MD
}
func (b *bearerAuthMiddleware) StartCall(ctx context.Context) context.Context {
md, _ := metadata.FromOutgoingContext(ctx)
+ b.mutex.RLock()
+ defer b.mutex.RUnlock()
return metadata.NewOutgoingContext(ctx, metadata.Join(md, b.hdrs))
}
+func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md
metadata.MD) {
+ // apache/arrow-adbc#584
+ headers := md.Get("authorization")
+ if len(headers) > 0 {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ b.hdrs.Set("authorization", headers...)
+ }
+}
+
func getFlightClient(ctx context.Context, loc string, d *database)
(*flightsql.Client, error) {
authMiddle := &bearerAuthMiddleware{hdrs: d.hdrs.Copy()}
middleware := []flight.ClientMiddleware{
@@ -564,6 +578,7 @@ func getFlightClient(ctx context.Context, loc string, d
*database) (*flightsql.C
}
if md, ok := metadata.FromOutgoingContext(ctx); ok {
+ // No need to worry about lock here since we are sole owner
authMiddle.hdrs.Set("authorization",
md.Get("Authorization")[0])
}
}
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go
b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 09e6ac8..b9265ca 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -49,8 +49,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
+ "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
@@ -252,6 +254,7 @@ func TestADBCFlightSQL(t *testing.T) {
suite.Run(t, &DefaultDialOptionsTests{Quirks: q})
suite.Run(t, &HeaderTests{Quirks: q})
+ suite.Run(t, &AuthnTests{})
suite.Run(t, &OptionTests{Quirks: q})
suite.Run(t, &PartitionTests{Quirks: q})
suite.Run(t, &StatementTests{Quirks: q})
@@ -708,6 +711,134 @@ func (suite *HeaderTests) TestPrepared() {
suite.Contains(suite.Quirks.middle.recordedHeaders.Get("x-header-two"),
"value 2")
}
+type AuthnTests struct {
+ suite.Suite
+
+ s flight.Server
+ db adbc.Database
+ cnxn adbc.Connection
+}
+
+type AuthnTestServer struct {
+ flightsql.BaseServer
+}
+
+func (server *AuthnTestServer) GetFlightInfoStatement(ctx context.Context, cmd
flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.FlightInfo,
error) {
+ md := metadata.MD{}
+ md.Set("authorization", "Bearer final")
+ if err := grpc.SendHeader(ctx, md); err != nil {
+ return nil, err
+ }
+ tkt, _ := flightsql.CreateStatementQueryTicket([]byte{})
+ info := &flight.FlightInfo{
+ FlightDescriptor: desc,
+ Endpoint: []*flight.FlightEndpoint{
+ {Ticket: &flight.Ticket{Ticket: tkt}},
+ },
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }
+ return info, nil
+}
+
+func (server *AuthnTestServer) DoGetStatement(ctx context.Context, tkt
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk,
error) {
+ sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type:
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+ rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc,
strings.NewReader(`[{"a": 5}]`))
+ if err != nil {
+ return nil, nil, err
+ }
+
+ ch := make(chan flight.StreamChunk)
+ go func() {
+ defer close(ch)
+ ch <- flight.StreamChunk{
+ Data: rec,
+ Desc: nil,
+ Err: nil,
+ }
+ }()
+ return sc, ch, nil
+}
+
+func authnTestUnary(ctx context.Context, req interface{}, info
*grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error)
{
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok {
+ return nil, status.Error(codes.InvalidArgument, "Could not get
metadata")
+ }
+ auth := md.Get("authorization")
+ if len(auth) == 0 {
+ return nil, status.Error(codes.Unauthenticated, "No token")
+ } else if auth[0] != "Bearer initial" {
+ return nil, status.Error(codes.Unauthenticated, "Invalid token
for unary call: "+auth[0])
+ }
+
+ md.Set("authorization", "Bearer final")
+ ctx = metadata.NewOutgoingContext(ctx, md)
+ return handler(ctx, req)
+}
+
+func authnTestStream(srv interface{}, ss grpc.ServerStream, info
*grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+ md, ok := metadata.FromIncomingContext(ss.Context())
+ if !ok {
+ return status.Error(codes.InvalidArgument, "Could not get
metadata")
+ }
+ auth := md.Get("authorization")
+ if len(auth) == 0 {
+ return status.Error(codes.Unauthenticated, "No token")
+ } else if auth[0] != "Bearer final" {
+ return status.Error(codes.Unauthenticated, "Invalid token for
stream call: "+auth[0])
+ }
+
+ return handler(srv, ss)
+}
+
+func (suite *AuthnTests) SetupSuite() {
+ suite.s = flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+ {Stream: authnTestStream, Unary: authnTestUnary},
+ })
+
suite.s.RegisterFlightService(flightsql.NewFlightServer(&AuthnTestServer{}))
+ suite.Require().NoError(suite.s.Init("localhost:0"))
+ suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
+ go func() {
+ _ = suite.s.Serve()
+ }()
+
+ uri := "grpc+tcp://" + suite.s.Addr().String()
+ var err error
+ suite.db, err = (driver.Driver{}).NewDatabase(map[string]string{
+ "uri": uri,
+ driver.OptionAuthorizationHeader: "Bearer initial",
+ })
+ suite.Require().NoError(err)
+}
+
+func (suite *AuthnTests) SetupTest() {
+ var err error
+ suite.cnxn, err = suite.db.Open(context.Background())
+ suite.Require().NoError(err)
+}
+
+func (suite *AuthnTests) TearDownTest() {
+ suite.Require().NoError(suite.cnxn.Close())
+}
+
+func (suite *AuthnTests) TearDownSuite() {
+ suite.db = nil
+ suite.s.Shutdown()
+}
+
+func (suite *AuthnTests) TestBearerTokenUpdated() {
+ // apache/arrow-adbc#584: when setting the auth header directly, the client should use any updated token value from the server if given
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer stmt.Close()
+
+ suite.Require().NoError(stmt.SetSqlQuery("timeout"))
+ reader, _, err := stmt.ExecuteQuery(context.Background())
+ suite.NoError(err)
+ defer reader.Release()
+}
+
type TimeoutTestServer struct {
flightsql.BaseServer
}
diff --git a/go/adbc/go.mod b/go/adbc/go.mod
index 6ba0443..b428162 100644
--- a/go/adbc/go.mod
+++ b/go/adbc/go.mod
@@ -20,7 +20,7 @@ module github.com/apache/arrow-adbc/go/adbc
go 1.18
require (
- github.com/apache/arrow/go/v12 v12.0.0-20230307201612-6fdf1e520a76
+ github.com/apache/arrow/go/v12 v12.0.0-20230421000340-388f3a88c647
github.com/bluele/gcache v0.0.2
github.com/stretchr/testify v1.8.1
golang.org/x/exp v0.0.0-20230206171751-46f607a40771
diff --git a/go/adbc/go.sum b/go/adbc/go.sum
index 7091980..3fd31dd 100644
--- a/go/adbc/go.sum
+++ b/go/adbc/go.sum
@@ -1,8 +1,8 @@
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c
h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU=
github.com/andybalholm/brotli v1.0.4
h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod
h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
-github.com/apache/arrow/go/v12 v12.0.0-20230307201612-6fdf1e520a76
h1:6O9PR51TZFveY6xYg4RCCj23hIFyL7vR5hav0zc+Vss=
-github.com/apache/arrow/go/v12 v12.0.0-20230307201612-6fdf1e520a76/go.mod
h1:p9SbFzxqBIUcxDFFqy4aLDT6RdooPoA18Bru+OL1gbk=
+github.com/apache/arrow/go/v12 v12.0.0-20230421000340-388f3a88c647
h1:qsBSonbDQRwj8HyUeD/NSaA0e2bT4f3kgcqkSqVZzdo=
+github.com/apache/arrow/go/v12 v12.0.0-20230421000340-388f3a88c647/go.mod
h1:d+tV/eHZZ7Dz7RPrFKtPK02tpr+c9/PEd/zm8mDS9Vg=
github.com/apache/thrift v0.17.0
h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo=
github.com/apache/thrift v0.17.0/go.mod
h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q=
github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw=
diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go
index e3e8212..c31c377 100644
--- a/go/adbc/validation/validation.go
+++ b/go/adbc/validation/validation.go
@@ -242,11 +242,11 @@ func (c *ConnectionTests) TestMetadataGetTableSchema() {
c.Require().NoError(err)
expectedSchema := arrow.NewSchema([]arrow.Field{
- {Name: "ints", Type: arrow.PrimitiveTypes.Int64,
+ {Name: "ints", Type: arrow.PrimitiveTypes.Int64, Nullable: true,
Metadata: arrow.MetadataFrom(map[string]string{
flightsql.ScaleKey: "15",
flightsql.IsReadOnlyKey: "0", flightsql.IsAutoIncrementKey: "0",
flightsql.TableNameKey: "sample_test",
flightsql.PrecisionKey: "10"})},
- {Name: "strings", Type: arrow.BinaryTypes.String,
+ {Name: "strings", Type: arrow.BinaryTypes.String, Nullable:
true,
Metadata: arrow.MetadataFrom(map[string]string{
flightsql.ScaleKey: "15",
flightsql.IsReadOnlyKey: "0", flightsql.IsAutoIncrementKey: "0",
flightsql.TableNameKey: "sample_test"})},