This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 2ea2fcb8a feat(go/adbc/driver/flightsql): Add OAuth Support to Flight Client (#2651)
2ea2fcb8a is described below
commit 2ea2fcb8ae6062b3a0599ef9f9eeda207e16e4eb
Author: Hélder Gregório <[email protected]>
AuthorDate: Thu Apr 17 14:32:22 2025 +0000
feat(go/adbc/driver/flightsql): Add OAuth Support to Flight Client (#2651)
## Description
This pull request adds OAuth 2.0 support to the Flight SQL client in the
Go driver. It adds support for OAuth access tokens and implements the
token exchange and client credentials OAuth flows.
## Related Issues
- Closes #[2650](https://github.com/apache/arrow-adbc/issues/2650)
## Changes Made
1. Added `token` as a database option
1. Added support for [Token Exchange](https://datatracker.ietf.org/doc/html/rfc8693).
If configured, `token` gets exchanged and the result is added to the
`Authorization` header as a `Bearer` token.
1. Added support for [Client Credentials](https://datatracker.ietf.org/doc/html/rfc6749#section-4.4).
If configured, `client_id` and `client_secret` are used to obtain an
access token that is added to the `Authorization` header as a `Bearer`
token.
1. Added new driver options that allow third-party applications to
configure OAuth flows (listed in the table below).
1. Added tests
## OAuth 2.0 Configuration Options
| Option | Description |
|--------|-------------|
| `adbc.flight.sql.oauth.flow` | Specifies the OAuth 2.0 flow type to use. Possible values: `client_credentials`, `token_exchange` |
| `adbc.flight.sql.oauth.client_id` | Unique identifier issued to the client application by the authorization server |
| `adbc.flight.sql.oauth.client_secret` | Secret associated with the `client_id`. Used to authenticate the client application to the authorization server |
| `adbc.flight.sql.oauth.token_uri` | The endpoint URL where the client application requests tokens from the authorization server |
| `adbc.flight.sql.oauth.scope` | Space-separated list of permissions that the client is requesting access to (e.g. `"read.all offline_access"`) |
| `adbc.flight.sql.oauth.exchange.subject_token` | The security token that the client application wants to exchange |
| `adbc.flight.sql.oauth.exchange.subject_token_type` | Identifier for the type of the subject token. See the list below for supported token types. |
| `adbc.flight.sql.oauth.exchange.actor_token` | A security token that represents the identity of the acting party |
| `adbc.flight.sql.oauth.exchange.actor_token_type` | Identifier for the type of the actor token. Required when `actor_token` is set. See the list below for supported token types. |
| `adbc.flight.sql.oauth.exchange.aud` | The intended audience for the requested security token |
| `adbc.flight.sql.oauth.exchange.resource` | The resource server where the client intends to use the requested security token |
| `adbc.flight.sql.oauth.exchange.scope` | Specific permissions requested for the new token |
| `adbc.flight.sql.oauth.exchange.requested_token_type` | The type of token the client wants to receive in exchange. See the list below for supported token types. |
**Supported token types:**
* `urn:ietf:params:oauth:token-type:access_token`
* `urn:ietf:params:oauth:token-type:refresh_token`
* `urn:ietf:params:oauth:token-type:id_token`
* `urn:ietf:params:oauth:token-type:saml1`
* `urn:ietf:params:oauth:token-type:saml2`
* `urn:ietf:params:oauth:token-type:jwt`
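As a configuration sketch, the options in the table are plain string key/value pairs. The tests in this change pass them to `SetOptions` on an ADBC database alongside a `uri`. The helpers below (`clientCredentialsOptions`, `tokenExchangeOptions`) are hypothetical, the endpoints in `main` are placeholders, and the keys are copied as string literals so the sketch compiles without the driver module.

```go
package main

import (
	"fmt"
	"sort"
)

// Option keys copied from the table above. The Go driver exports matching
// constants (e.g. OptionKeyOauthFlow); plain strings keep this sketch
// free of external modules.
const (
	optFlow             = "adbc.flight.sql.oauth.flow"
	optClientID         = "adbc.flight.sql.oauth.client_id"
	optClientSecret     = "adbc.flight.sql.oauth.client_secret"
	optTokenURI         = "adbc.flight.sql.oauth.token_uri"
	optScope            = "adbc.flight.sql.oauth.scope"
	optSubjectToken     = "adbc.flight.sql.oauth.exchange.subject_token"
	optSubjectTokenType = "adbc.flight.sql.oauth.exchange.subject_token_type"
)

// clientCredentialsOptions builds database options for the client_credentials
// flow; scope is optional and omitted when empty.
func clientCredentialsOptions(uri, clientID, clientSecret, tokenURI, scope string) map[string]string {
	opts := map[string]string{
		"uri":           uri,
		optFlow:         "client_credentials",
		optClientID:     clientID,
		optClientSecret: clientSecret,
		optTokenURI:     tokenURI,
	}
	if scope != "" {
		opts[optScope] = scope
	}
	return opts
}

// tokenExchangeOptions builds database options for the token_exchange flow
// with only the required keys set.
func tokenExchangeOptions(uri, subjectToken, subjectTokenType, tokenURI string) map[string]string {
	return map[string]string{
		"uri":               uri,
		optFlow:             "token_exchange",
		optSubjectToken:     subjectToken,
		optSubjectTokenType: subjectTokenType,
		optTokenURI:         tokenURI,
	}
}

func main() {
	// Hypothetical endpoints; substitute a real Flight SQL server and IdP.
	opts := clientCredentialsOptions(
		"grpc+tls://flightsql.example.com:443",
		"my-client", "my-secret",
		"https://auth.example.com/oauth2/token",
		"read.all offline_access",
	)
	keys := make([]string, 0, len(opts))
	for k := range opts {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	// Print only the keys so the client secret is not echoed.
	for _, k := range keys {
		fmt.Println(k)
	}
}
```

According to the `SetOptions` changes below, setting the flow requests a token immediately, so the token endpoint must be reachable at configuration time. Combining these options with an explicit `authorization` header is rejected with an authentication-conflict error.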
---
docs/source/driver/flight_sql.rst | 65 +++-
.../driver/flightsql/flightsql_adbc_server_test.go | 375 ++++++++++++++++++++-
go/adbc/driver/flightsql/flightsql_database.go | 74 +++-
go/adbc/driver/flightsql/flightsql_driver.go | 17 +
go/adbc/driver/flightsql/flightsql_oauth.go | 151 +++++++++
5 files changed, 650 insertions(+), 32 deletions(-)
diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst
index 983a6162c..d3a588be5 100644
--- a/docs/source/driver/flight_sql.rst
+++ b/docs/source/driver/flight_sql.rst
@@ -159,6 +159,12 @@ few optional authentication schemes:
header will then be sent back as the ``authorization`` header on all
future requests.
+- OAuth 2.0 authentication flows.
+
+ The client provides :ref:`configurations <oauth-configurations>` to allow client application to obtain access
+ tokens from an authorization server. The obtained token is then used
+ on the ``authorization`` header on all future requests.
+
Bulk Ingestion
--------------
@@ -246,10 +252,67 @@ to :c:struct:`AdbcDatabase`, :c:struct:`AdbcConnection`, and
Add the header ``<HEADER NAME>`` to outgoing requests with the given
value.
- Python: :attr:`adbc_driver_flightsql.ConnectionOptions.RPC_CALL_HEADER_PREFIX`
+ Python: :attr:`adbc_driver_flightsql.ConnectionOptions.RPC_CALL_HEADER_PREFIX`
.. warning:: Header names must be in all lowercase.
+
+OAuth 2.0 Options
+-----------------------
+.. _oauth-configurations:
+
+Supported configurations to obtain tokens using OAuth 2.0 authentication flows.
+
+``adbc.flight.sql.oauth.flow``
+ Specifies the OAuth 2.0 flow type to use. Possible values: ``client_credentials``, ``token_exchange``
+
+``adbc.flight.sql.oauth.client_id``
+ Unique identifier issued to the client application by the authorization server
+
+``adbc.flight.sql.oauth.client_secret``
+ Secret associated to the client_id. Used to authenticate the client application to the authorization server
+
+``adbc.flight.sql.oauth.token_uri``
+ The endpoint URL where the client application requests tokens from the authorization server
+
+``adbc.flight.sql.oauth.scope``
+ Space-separated list of permissions that the client is requesting access to (e.g ``"read.all offline_access"``)
+
+``adbc.flight.sql.oauth.exchange.subject_token``
+ The security token that the client application wants to exchange
+
+``adbc.flight.sql.oauth.exchange.subject_token_type``
+ Identifier for the type of the subject token.
+ Check list below for supported token types.
+
+``adbc.flight.sql.oauth.exchange.actor_token``
+ A security token that represents the identity of the acting party
+
+``adbc.flight.sql.oauth.exchange.actor_token_type``
+ Identifier for the type of the actor token.
+ Check list below for supported token types.
+``adbc.flight.sql.oauth.exchange.aud``
+ The intended audience for the requested security token
+
+``adbc.flight.sql.oauth.exchange.resource``
+ The resource server where the client intends to use the requested security token
+
+``adbc.flight.sql.oauth.exchange.scope``
+ Specific permissions requested for the new token
+
+``adbc.flight.sql.oauth.exchange.requested_token_type``
+ The type of token the client wants to receive in exchange.
+ Check list below for supported token types.
+
+
+Supported token types:
+ - ``urn:ietf:params:oauth:token-type:access_token``
+ - ``urn:ietf:params:oauth:token-type:refresh_token``
+ - ``urn:ietf:params:oauth:token-type:id_token``
+ - ``urn:ietf:params:oauth:token-type:saml1``
+ - ``urn:ietf:params:oauth:token-type:saml2``
+ - ``urn:ietf:params:oauth:token-type:jwt``
+
Distributed Result Sets
-----------------------
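As far as this diff shows, the driver forwards the `*_token_type` values to the token endpoint without checking them against the documented list, so a typo would only surface as an error from the authorization server. A client application could pre-validate them; the sketch below is a hypothetical check (`checkTokenTypes` is not a driver API) built from the six URNs listed in the docs above, which come from RFC 8693 (the `jwt` type from RFC 7519).

```go
package main

import "fmt"

// Token-type URNs the documentation above lists as supported.
var supportedTokenTypes = map[string]bool{
	"urn:ietf:params:oauth:token-type:access_token":  true,
	"urn:ietf:params:oauth:token-type:refresh_token": true,
	"urn:ietf:params:oauth:token-type:id_token":      true,
	"urn:ietf:params:oauth:token-type:saml1":         true,
	"urn:ietf:params:oauth:token-type:saml2":         true,
	"urn:ietf:params:oauth:token-type:jwt":           true,
}

// Option keys whose values must be one of the URNs above.
var tokenTypeOptions = []string{
	"adbc.flight.sql.oauth.exchange.subject_token_type",
	"adbc.flight.sql.oauth.exchange.actor_token_type",
	"adbc.flight.sql.oauth.exchange.requested_token_type",
}

// checkTokenTypes (hypothetical) reports the first token-type option, in
// tokenTypeOptions order, whose value is not a supported URN. Options that
// are absent are ignored.
func checkTokenTypes(opts map[string]string) error {
	for _, key := range tokenTypeOptions {
		if v, ok := opts[key]; ok && !supportedTokenTypes[v] {
			return fmt.Errorf("%s: unsupported token type %q", key, v)
		}
	}
	return nil
}

func main() {
	good := map[string]string{
		"adbc.flight.sql.oauth.exchange.subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
	}
	bad := map[string]string{
		"adbc.flight.sql.oauth.exchange.requested_token_type": "jwt",
	}
	fmt.Println(checkTokenTypes(good))
	fmt.Println(checkTokenTypes(bad))
}
```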
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index c8a59d72a..f2cd0060c 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -20,11 +20,21 @@
package flightsql_test
import (
+ "bytes"
"context"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
"encoding/json"
+ "encoding/pem"
"errors"
"fmt"
+ "math/big"
"net"
+ "net/http"
+ "net/http/httptest"
"net/textproto"
"os"
"strconv"
@@ -50,6 +60,7 @@ import (
"golang.org/x/exp/maps"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
+ "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
@@ -69,16 +80,14 @@ type ServerBasedTests struct {
}
func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, dbArgs map[string]string, dialOpts ...grpc.DialOption) {
- suite.s = flight.NewServerWithMiddleware(srvMiddleware)
- suite.s.RegisterFlightService(flightsql.NewFlightServer(srv))
- suite.Require().NoError(suite.s.Init("localhost:0"))
- suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
- go func() {
- _ = suite.s.Serve()
- }()
+ suite.setupFlightServer(srv, srvMiddleware)
- uri := "grpc+tcp://" + suite.s.Addr().String()
+ suite.setupDatabase(dbArgs, dialOpts...)
+}
+
+func (suite *ServerBasedTests) setupDatabase(dbArgs map[string]string, dialOpts ...grpc.DialOption) {
var err error
+ uri := "grpc+tcp://" + suite.s.Addr().String()
args := map[string]string{
"uri": uri,
@@ -88,6 +97,16 @@ func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware
suite.Require().NoError(err)
}
+func (suite *ServerBasedTests) setupFlightServer(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, srvOpts ...grpc.ServerOption) {
+ suite.s = flight.NewServerWithMiddleware(srvMiddleware, srvOpts...)
+ suite.s.RegisterFlightService(flightsql.NewFlightServer(srv))
+ suite.Require().NoError(suite.s.Init("localhost:0"))
+ suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
+ go func() {
+ _ = suite.s.Serve()
+ }()
+}
+
func (suite *ServerBasedTests) SetupTest() {
var err error
suite.cnxn, err = suite.db.Open(context.Background())
@@ -104,6 +123,59 @@ func (suite *ServerBasedTests) TearDownSuite() {
suite.s.Shutdown()
}
+func (suite *ServerBasedTests) generateCertOption() grpc.ServerOption {
+ // Generate a self-signed certificate in-process for testing
+ privKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ suite.Require().NoError(err)
+ certTemplate := x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ Organization: []string{"Unit Tests Incorporated"},
+ },
+ IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
+ NotBefore: time.Now(),
+ NotAfter: time.Now().Add(time.Hour),
+ KeyUsage: x509.KeyUsageKeyEncipherment,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ }
+ certDer, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &privKey.PublicKey, privKey)
+ suite.Require().NoError(err)
+ buffer := &bytes.Buffer{}
+ suite.Require().NoError(pem.Encode(buffer, &pem.Block{Type: "CERTIFICATE", Bytes: certDer}))
+ certBytes := make([]byte, buffer.Len())
+ copy(certBytes, buffer.Bytes())
+ buffer.Reset()
+ suite.Require().NoError(pem.Encode(buffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}))
+ keyBytes := make([]byte, buffer.Len())
+ copy(keyBytes, buffer.Bytes())
+
+ cert, err := tls.X509KeyPair(certBytes, keyBytes)
+ suite.Require().NoError(err)
+
+ suite.Require().NoError(err)
+ tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
+ tlsCreds := credentials.NewTLS(tlsConfig)
+
+ return grpc.Creds(tlsCreds)
+}
+
+func (suite *ServerBasedTests) openAndExecuteQuery(query string) {
+ var err error
+ suite.cnxn, err = suite.db.Open(context.Background())
+ suite.Require().NoError(err)
+ defer suite.cnxn.Close()
+
+ stmt, err := suite.cnxn.NewStatement()
+ suite.Require().NoError(err)
+ defer stmt.Close()
+
+ suite.Require().NoError(stmt.SetSqlQuery(query))
+ reader, _, err := stmt.ExecuteQuery(context.Background())
+ suite.NoError(err)
+ defer reader.Release()
+}
+
// ---- Tests --------------------
func TestAuthn(t *testing.T) {
@@ -150,6 +222,10 @@ func TestGetObjects(t *testing.T) {
suite.Run(t, &GetObjectsTests{})
}
+func TestOauth(t *testing.T) {
+ suite.Run(t, &OAuthTests{})
+}
+
// ---- AuthN Tests --------------------
type AuthnTestServer struct {
@@ -230,23 +306,288 @@ type AuthnTests struct {
}
func (suite *AuthnTests) SetupSuite() {
- suite.DoSetupSuite(&AuthnTestServer{}, []flight.ServerMiddleware{
+ suite.setupFlightServer(&AuthnTestServer{}, []flight.ServerMiddleware{
{Stream: authnTestStream, Unary: authnTestUnary},
- }, map[string]string{
- driver.OptionAuthorizationHeader: "Bearer initial",
})
}
+func (suite *AuthnTests) SetupTest() {
+ suite.setupDatabase(map[string]string{
+ "uri": "grpc+tcp://" + suite.s.Addr().String(),
+ })
+}
+
+func (suite *AuthnTests) TearDownTest() {
+ suite.NoError(suite.db.Close())
+ suite.db = nil
+}
+
+func (suite *AuthnTests) TearDownSuite() {
+ suite.s.Shutdown()
+}
+
func (suite *AuthnTests) TestBearerTokenUpdated() {
+ err := suite.db.SetOptions(map[string]string{
+ driver.OptionAuthorizationHeader: "Bearer initial",
+ })
+ suite.Require().NoError(err)
+
// apache/arrow-adbc#584: when setting the auth header directly, the client should use any updated token value from the server if given
- stmt, err := suite.cnxn.NewStatement()
+
+ suite.openAndExecuteQuery("a-query")
+}
+
+type OAuthTests struct {
+ ServerBasedTests
+
+ oauthServer *httptest.Server
+ mockOAuthServer *MockOAuthServer
+}
+
+// MockOAuthServer simulates an OAuth 2.0 server for testing
+type MockOAuthServer struct {
+ // Track calls to validate server behavior
+ clientCredentialsCalls int
+ tokenExchangeCalls int
+}
+
+func (m *MockOAuthServer) handleTokenRequest(w http.ResponseWriter, r *http.Request) {
+ // Parse the form to get the request parameters
+ if err := r.ParseForm(); err != nil {
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+ return
+ }
+
+ grantType := r.FormValue("grant_type")
+
+ switch grantType {
+ case "client_credentials":
+ m.clientCredentialsCalls++
+ // Validate client credentials
+ clientID := r.FormValue("client_id")
+ clientSecret := r.FormValue("client_secret")
+
+ if clientID == "test-client" && clientSecret == "test-secret" {
+ // Return a valid token response
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{
+ "access_token": "test-client-token",
+ "token_type": "bearer",
+ "expires_in": 3600
+ }`))
+
+ return
+ }
+
+ case "urn:ietf:params:oauth:grant-type:token-exchange":
+ m.tokenExchangeCalls++
+ // Validate token exchange parameters
+ subjectToken := r.FormValue("subject_token")
+ subjectTokenType := r.FormValue("subject_token_type")
+
+ if subjectToken == "test-subject-token" &&
+ subjectTokenType == "urn:ietf:params:oauth:token-type:jwt" {
+ // Return a valid token response
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{
+ "access_token": "test-exchanged-token",
+ "token_type": "bearer",
+ "expires_in": 3600
+ }`))
+ return
+ }
+ }
+
+ // Default: return error for invalid request
+ http.Error(w, "Invalid request", http.StatusBadRequest)
+}
+
+func oauthTestUnary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+ md, ok := metadata.FromIncomingContext(ctx)
+ if !ok {
+ return nil, status.Error(codes.InvalidArgument, "Could not get metadata")
+ }
+ auth := md.Get("authorization")
+ if len(auth) == 0 {
+ return nil, status.Error(codes.Unauthenticated, "No token")
+ } else if auth[0] != "Bearer test-exchanged-token" && auth[0] != "Bearer test-client-token" {
+ return nil, status.Error(codes.Unauthenticated, "Invalid token for unary call: "+auth[0])
+ }
+
+ md.Set("authorization", "Bearer final")
+ ctx = metadata.NewOutgoingContext(ctx, md)
+ return handler(ctx, req)
+}
+
+func (suite *OAuthTests) SetupSuite() {
+ suite.mockOAuthServer = &MockOAuthServer{}
+ suite.oauthServer = httptest.NewServer(http.HandlerFunc(suite.mockOAuthServer.handleTokenRequest))
+
+ suite.setupFlightServer(&AuthnTestServer{}, []flight.ServerMiddleware{
+ {Unary: oauthTestUnary},
+ }, suite.generateCertOption())
+}
+
+func (suite *OAuthTests) TearDownSuite() {
+ suite.oauthServer.Close()
+ suite.s.Shutdown()
+}
+
+func (suite *OAuthTests) SetupTest() {
+ suite.setupDatabase(map[string]string{
+ "uri": "grpc+tls://" + suite.s.Addr().String(),
+ })
+}
+
+func (suite *OAuthTests) TearDownTest() {
+ suite.NoError(suite.db.Close())
+ suite.db = nil
+}
+
+func (suite *OAuthTests) TestTokenExchangeFlow() {
+ err := suite.db.SetOptions(map[string]string{
+ driver.OptionKeyOauthFlow: driver.TokenExchange,
+ driver.OptionKeySubjectToken: "test-subject-token",
+ driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ driver.OptionSSLSkipVerify: adbc.OptionValueEnabled,
+ })
suite.Require().NoError(err)
- defer stmt.Close()
- suite.Require().NoError(stmt.SetSqlQuery("timeout"))
- reader, _, err := stmt.ExecuteQuery(context.Background())
- suite.NoError(err)
- defer reader.Release()
+ suite.openAndExecuteQuery("a-query")
+ suite.Equal(1, suite.mockOAuthServer.tokenExchangeCalls, "Token exchange flow should be called once")
+}
+
+func (suite *OAuthTests) TestClientCredentialsFlow() {
+ err := suite.db.SetOptions(map[string]string{
+ driver.OptionKeyOauthFlow: driver.ClientCredentials,
+ driver.OptionKeyClientId: "test-client",
+ driver.OptionKeyClientSecret: "test-secret",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ driver.OptionSSLSkipVerify: adbc.OptionValueEnabled,
+ })
+ suite.Require().NoError(err)
+
+ suite.cnxn, err = suite.db.Open(context.Background())
+ suite.Require().NoError(err)
+ defer suite.cnxn.Close()
+
+ suite.openAndExecuteQuery("a-query")
+ // golang/oauth2 tries to call the token endpoint sending the client credentials in the authentication header,
+ // if it fails, it retries sending the client credentials in the request body.
+ // See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
+ suite.Equal(2, suite.mockOAuthServer.clientCredentialsCalls, "Client credentials flow should be called once")
+}
+
+func (suite *OAuthTests) TestFailOauthWithTokenSet() {
+ err := suite.db.SetOptions(map[string]string{
+ driver.OptionAuthorizationHeader: "Bearer test-client-token",
+ driver.OptionKeyOauthFlow: driver.ClientCredentials,
+ driver.OptionKeyClientId: "test-client",
+ driver.OptionKeyClientSecret: "test-secret",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ })
+ suite.Error(err, "Expected error for missing parameters")
+ suite.Contains(err.Error(), "Authentication conflict: Use either Authorization header OR username/password parameter")
+}
+
+func (suite *OAuthTests) TestMissingRequiredParamsTokenExchange() {
+ testCases := []struct {
+ name string
+ options map[string]string
+ expectedErrorMsg string
+ }{
+ {
+ name: "Missing token",
+ options: map[string]string{
+ driver.OptionKeyOauthFlow: driver.TokenExchange,
+ driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ },
+ expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.exchange.subject_token",
+ },
+ {
+ name: "Missing subject token type",
+ options: map[string]string{
+ driver.OptionKeyOauthFlow: driver.TokenExchange,
+ driver.OptionKeySubjectToken: "test-subject-token",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ },
+ expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.exchange.subject_token_type",
+ },
+ {
+ name: "Missing token URI",
+ options: map[string]string{
+ driver.OptionKeyOauthFlow: driver.TokenExchange,
+ driver.OptionKeySubjectToken: "test-subject-token",
+ driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
+ },
+ expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.token_uri",
+ },
+ }
+
+ for _, tc := range testCases {
+ suite.Run(tc.name, func() {
+ // We need to set options with the driver's SetOptions method
+ err := suite.db.SetOptions(tc.options)
+ suite.Error(err, "Expected error for missing parameters")
+ suite.Contains(err.Error(), tc.expectedErrorMsg)
+ })
+ }
+}
+func (suite *OAuthTests) TestMissingRequiredParamsClientCredentials() {
+ testCases := []struct {
+ name string
+ options map[string]string
+ expectedErrorMsg string
+ }{
+ {
+ name: "Missing client ID",
+ options: map[string]string{
+ driver.OptionKeyOauthFlow: driver.ClientCredentials,
+ driver.OptionKeyClientSecret: "test-secret",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ },
+ expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.client_id",
+ },
+ {
+ name: "Missing client secret",
+ options: map[string]string{
+ driver.OptionKeyOauthFlow: driver.ClientCredentials,
+ driver.OptionKeyClientId: "test-client",
+ driver.OptionKeyTokenURI: suite.oauthServer.URL,
+ },
+ expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.client_secret",
+ },
+ {
+ name: "Missing token URI",
+ options: map[string]string{
+ driver.OptionKeyOauthFlow: driver.ClientCredentials,
+ driver.OptionKeyClientId: "test-client",
+ driver.OptionKeyClientSecret: "test-secret",
+ },
+ expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.token_uri",
+ },
+ }
+
+ for _, tc := range testCases {
+ suite.Run(tc.name, func() {
+ // We need to set options with the driver's SetOptions method
+ err := suite.db.SetOptions(tc.options)
+ suite.Error(err, "Expected error for missing parameters")
+ suite.Contains(err.Error(), tc.expectedErrorMsg)
+ })
+ }
+}
+
+func (suite *OAuthTests) TestInvalidOAuthFlow() {
+ err := suite.db.SetOptions(map[string]string{
+ driver.OptionKeyOauthFlow: "invalid-flow",
+ driver.OptionKeySubjectToken: "test-token",
+ })
+
+ suite.Error(err, "Expected error for invalid OAuth flow")
+ suite.Contains(err.Error(), "Not Implemented: oauth flow not implemented: invalid-flow")
}
// ---- Grpc Dialer Options Tests --------------
diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go
index bbbcbbf06..e45eb4d5d 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -68,6 +68,7 @@ type databaseImpl struct {
enableCookies bool
options map[string]string
userDialOpts []grpc.DialOption
+ oauthToken credentials.PerRPCCredentials
}
func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
@@ -146,10 +147,12 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
delete(cnOptions, OptionAuthorizationHeader)
}
+ const authConflictError = "Authentication conflict: Use either Authorization header OR username/password parameter"
+
if u, ok := cnOptions[adbc.OptionKeyUsername]; ok {
if d.hdrs.Len() > 0 {
return adbc.Error{
- Msg: "Authorization header already provided, do not provide user/pass also",
+ Msg: authConflictError,
Code: adbc.StatusInvalidArgument,
}
}
@@ -160,7 +163,7 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
if p, ok := cnOptions[adbc.OptionKeyPassword]; ok {
if d.hdrs.Len() > 0 {
return adbc.Error{
- Msg: "Authorization header already provided, do not provide user/pass also",
+ Msg: authConflictError,
Code: adbc.StatusInvalidArgument,
}
}
@@ -168,6 +171,33 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
delete(cnOptions, adbc.OptionKeyPassword)
}
+ if flow, ok := cnOptions[OptionKeyOauthFlow]; ok {
+ if d.hdrs.Len() > 0 {
+ return adbc.Error{
+ Msg: authConflictError,
+ Code: adbc.StatusInvalidArgument,
+ }
+ }
+
+ var err error
+ switch flow {
+ case ClientCredentials:
+ d.oauthToken, err = newClientCredentials(cnOptions)
+ case TokenExchange:
+ d.oauthToken, err = newTokenExchangeFlow(cnOptions)
+ default:
+ return adbc.Error{
+ Msg: fmt.Sprintf("oauth flow not implemented: %s", flow),
+ Code: adbc.StatusNotImplemented,
+ }
+ }
+
+ if err != nil {
+ return err
+ }
+ delete(cnOptions, OptionKeyOauthFlow)
+ }
+
var err error
if tv, ok := cnOptions[OptionTimeoutFetch]; ok {
if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil {
@@ -374,6 +404,10 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion))
dialOpts = append(dialOpts, d.userDialOpts...)
+ if d.oauthToken != nil {
+ dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(d.oauthToken))
+ }
+
d.Logger.DebugContext(ctx, "new client", "location", loc)
cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
if err != nil {
@@ -384,24 +418,30 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
}
cl.Alloc = d.Alloc
+ // Authorization header is already set, continue
if len(authMiddle.hdrs.Get("authorization")) > 0 {
d.Logger.DebugContext(ctx, "reusing auth token", "location", loc)
- } else {
- if d.user != "" || d.pass != "" {
- var header, trailer metadata.MD
- ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
- if err != nil {
- return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken")
- }
+ return cl, nil
+ }
- if md, ok := metadata.FromOutgoingContext(ctx); ok {
- authMiddle.mutex.Lock()
- defer authMiddle.mutex.Unlock()
- authMiddle.hdrs.Set("authorization", md.Get("Authorization")[0])
- }
+ var authValue string
+
+ if d.user != "" || d.pass != "" {
+ var header, trailer metadata.MD
+ ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
+ if err != nil {
+ return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken")
+ }
+
+ if md, ok := metadata.FromOutgoingContext(ctx); ok {
+ authValue = md.Get("Authorization")[0]
}
}
+ if authValue != "" {
+ authMiddle.SetHeader(authValue)
+ }
+
return cl, nil
}
@@ -526,3 +566,9 @@ func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md metadata.
b.hdrs.Set("authorization", headers...)
}
}
+
+func (b *bearerAuthMiddleware) SetHeader(authValue string) {
+ b.mutex.Lock()
+ defer b.mutex.Unlock()
+ b.hdrs.Set("authorization", authValue)
+}
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 9e517c716..ff1e74bb5 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -66,6 +66,23 @@ const (
OptionStringListSessionOptionPrefix = "adbc.flight.sql.session.optionstringlist."
OptionLastFlightInfo = "adbc.flight.sql.statement.exec.last_flight_info"
infoDriverName = "ADBC Flight SQL Driver - Go"
+
+ // Oauth2 options
+ OptionKeyOauthFlow = "adbc.flight.sql.oauth.flow"
+ OptionKeyAuthURI = "adbc.flight.sql.oauth.auth_uri"
+ OptionKeyTokenURI = "adbc.flight.sql.oauth.token_uri"
+ OptionKeyRedirectURI = "adbc.flight.sql.oauth.redirect_uri"
+ OptionKeyScope = "adbc.flight.sql.oauth.scope"
+ OptionKeyClientId = "adbc.flight.sql.oauth.client_id"
+ OptionKeyClientSecret = "adbc.flight.sql.oauth.client_secret"
+ OptionKeySubjectToken = "adbc.flight.sql.oauth.exchange.subject_token"
+ OptionKeySubjectTokenType = "adbc.flight.sql.oauth.exchange.subject_token_type"
+ OptionKeyActorToken = "adbc.flight.sql.oauth.exchange.actor_token"
+ OptionKeyActorTokenType = "adbc.flight.sql.oauth.exchange.actor_token_type"
+ OptionKeyReqTokenType = "adbc.flight.sql.oauth.exchange.requested_token_type"
+ OptionKeyExchangeScope = "adbc.flight.sql.oauth.exchange.scope"
+ OptionKeyExchangeAud = "adbc.flight.sql.oauth.exchange.aud"
+ OptionKeyExchangeResource = "adbc.flight.sql.oauth.exchange.resource"
)
var errNoTransactionSupport = adbc.Error{
diff --git a/go/adbc/driver/flightsql/flightsql_oauth.go b/go/adbc/driver/flightsql/flightsql_oauth.go
new file mode 100644
index 000000000..707590a0d
--- /dev/null
+++ b/go/adbc/driver/flightsql/flightsql_oauth.go
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package flightsql
+
+import (
+ "context"
+ "fmt"
+
+ "golang.org/x/oauth2"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/oauth"
+)
+
+const (
+ ClientCredentials = "client_credentials"
+ TokenExchange = "token_exchange"
+)
+
+type oAuthOption struct {
+ isRequired bool
+ oAuthKey string
+}
+
+var (
+ clientCredentialsParams = map[string]oAuthOption{
+ OptionKeyClientId: {true, "client_id"},
+ OptionKeyClientSecret: {true, "client_secret"},
+ OptionKeyTokenURI: {true, "token_uri"},
+ OptionKeyScope: {false, "scope"},
+ }
+
+ tokenExchangParams = map[string]oAuthOption{
+ OptionKeySubjectToken: {true, "subject_token"},
+ OptionKeySubjectTokenType: {true, "subject_token_type"},
+ OptionKeyReqTokenType: {false, "requested_token_type"},
+ OptionKeyExchangeAud: {false, "audience"},
+ OptionKeyExchangeResource: {false, "resource"},
+ OptionKeyExchangeScope: {false, "scope"},
+ }
+)
+
+func parseOAuthOptions(options map[string]string, paramMap map[string]oAuthOption, flowName string) (map[string]string, error) {
+ params := map[string]string{}
+
+ for key, param := range paramMap {
+ if value, ok := options[key]; ok {
+ params[key] = value
+ delete(options, key)
+ } else if param.isRequired {
+ return nil, fmt.Errorf("%s grant requires %s", flowName, key)
+ }
+ }
+
+ return params, nil
+}
+
+func exchangeToken(conf *oauth2.Config, codeOptions []oauth2.AuthCodeOption) (credentials.PerRPCCredentials, error) {
+ ctx := context.Background()
+ tok, err := conf.Exchange(ctx, "", codeOptions...)
+ if err != nil {
+ return nil, err
+ }
+ return &oauth.TokenSource{TokenSource: conf.TokenSource(ctx, tok)}, nil
+}
+
+func newClientCredentials(options map[string]string) (credentials.PerRPCCredentials, error) {
+ codeOptions := []oauth2.AuthCodeOption{
+ // Required value for client credentials requests as specified in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2
+ oauth2.SetAuthURLParam("grant_type", "client_credentials"),
+ }
+
+ params, err := parseOAuthOptions(options, clientCredentialsParams, "client credentials")
+ if err != nil {
+ return nil, err
+ }
+
+ conf := &oauth2.Config{
+ ClientID: params[OptionKeyClientId],
+ ClientSecret: params[OptionKeyClientSecret],
+ Endpoint: oauth2.Endpoint{
+ TokenURL: params[OptionKeyTokenURI],
+ },
+ }
+
+ if scopes, ok := params[OptionKeyScope]; ok {
+ conf.Scopes = []string{scopes}
+ }
+
+ return exchangeToken(conf, codeOptions)
+}
+
+func newTokenExchangeFlow(options map[string]string) (credentials.PerRPCCredentials, error) {
+ tokenURI, ok := options[OptionKeyTokenURI]
+ if !ok {
+ return nil, fmt.Errorf("token exchange grant requires %s", OptionKeyTokenURI)
+ }
+ delete(options, OptionKeyTokenURI)
+
+ conf := &oauth2.Config{
+ Endpoint: oauth2.Endpoint{
+ TokenURL: tokenURI,
+ },
+ }
+
+ codeOptions := []oauth2.AuthCodeOption{
+ // Required value for token exchange requests as specified in https://datatracker.ietf.org/doc/html/rfc8693#name-request
+ oauth2.SetAuthURLParam("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"),
+ }
+
+ params, err := parseOAuthOptions(options, tokenExchangParams, "token exchange")
+ if err != nil {
+ return nil, err
+ }
+
+ for key, param := range tokenExchangParams {
+ if value, ok := params[key]; ok {
+ codeOptions = append(codeOptions, oauth2.SetAuthURLParam(param.oAuthKey, value))
+ }
+ }
+
+ // actor token and actor token type are optional
+ // but if one is present, the other must be present
+ if actor, ok := options[OptionKeyActorToken]; ok {
+ codeOptions = append(codeOptions, oauth2.SetAuthURLParam("actor_token", actor))
+ delete(options, OptionKeyActorToken)
+ if actorTokenType, ok := options[OptionKeyActorTokenType]; ok {
+ codeOptions = append(codeOptions, oauth2.SetAuthURLParam("actor_token_type", actorTokenType))
+ delete(options, OptionKeyActorTokenType)
+ } else {
+ return nil, fmt.Errorf("token exchange grant requires %s when %s is provided",
+ OptionKeyActorTokenType, OptionKeyActorToken)
+ }
+ }
+
+ return exchangeToken(conf, codeOptions)
+}