This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ea2fcb8a feat(go/adbc/driver/flightsql): Add OAuth Support to Flight Client (#2651)
2ea2fcb8a is described below

commit 2ea2fcb8ae6062b3a0599ef9f9eeda207e16e4eb
Author: Hélder Gregório <[email protected]>
AuthorDate: Thu Apr 17 14:32:22 2025 +0000

    feat(go/adbc/driver/flightsql): Add OAuth Support to Flight Client (#2651)
    
    ## Description
    
    This pull request introduces OAuth support to the Flight client in the
    Go driver. The changes add support for OAuth access tokens and implement
    the token exchange and client credentials OAuth flows.
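    
    Both flows end the same way: the JSON response from the token endpoint
    is turned into a `Bearer` credential. A minimal stdlib-only Go sketch of
    that step follows. `tokenResponse` and `bearerHeader` are hypothetical
    names; the driver itself delegates this work to golang/oauth2 and gRPC
    per-RPC credentials rather than to code like this.
    
    ```go
    package main
    
    import (
        "encoding/json"
        "errors"
        "strings"
    )
    
    // tokenResponse mirrors the successful access token response fields
    // from RFC 6749, section 5.1 (hypothetical type for this sketch).
    type tokenResponse struct {
        AccessToken string `json:"access_token"`
        TokenType   string `json:"token_type"`
        ExpiresIn   int    `json:"expires_in"`
    }
    
    // bearerHeader decodes a token endpoint response body and returns the
    // value to send in the "authorization" header of each request.
    func bearerHeader(body []byte) (string, error) {
        var tr tokenResponse
        if err := json.Unmarshal(body, &tr); err != nil {
            return "", err
        }
        if tr.AccessToken == "" {
            return "", errors.New("token response has no access_token")
        }
        // token_type is case-insensitive per RFC 6749; this sketch tolerates
        // a missing value but rejects anything other than "bearer".
        if tr.TokenType != "" && !strings.EqualFold(tr.TokenType, "bearer") {
            return "", errors.New("unsupported token_type: " + tr.TokenType)
        }
        return "Bearer " + tr.AccessToken, nil
    }
    ```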
    
    ## Related Issues
    
    - Closes #[2650](https://github.com/apache/arrow-adbc/issues/2650)
    
    ## Changes Made
    1. Added `token` as a database option
    1. Added support for [Token
    Exchange](https://datatracker.ietf.org/doc/html/rfc8693). If configured,
    `token` gets exchanged and the result is added to the `Authorization`
    header as a `Bearer` token
    1. Added support for [Client
    Credentials](https://datatracker.ietf.org/doc/html/rfc6749#section-4.4).
    If configured, `client_id` and `client_secret` are used to obtain an
    access token that is added to the `Authorization` header as a `Bearer`
    token
    1. Added new driver options to allow third-party applications to
    configure OAuth flows (see the configuration options table below)
    1. Added tests
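    
    For the token exchange flow, RFC 8693 defines the request as a
    form-encoded POST to the token endpoint. The sketch below builds that
    body with the standard library only. `tokenExchangeForm` is a
    hypothetical helper, and mapping the `adbc.flight.sql.oauth.exchange.*`
    options onto these form fields is an assumption based on their names,
    not code taken from this change.
    
    ```go
    package main
    
    import (
        "errors"
        "net/url"
    )
    
    // grantTypeTokenExchange is the grant_type value defined by RFC 8693.
    const grantTypeTokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange"
    
    // tokenExchangeForm builds the form-encoded body of an RFC 8693 token
    // exchange request (hypothetical helper). Optional parameters are only
    // included when non-empty; resource and requested_token_type are left
    // out to keep the sketch short.
    func tokenExchangeForm(subjectToken, subjectTokenType, actorToken, actorTokenType, audience, scope string) (url.Values, error) {
        if subjectToken == "" || subjectTokenType == "" {
            return nil, errors.New("subject_token and subject_token_type are required")
        }
        // RFC 8693, section 2.1: actor_token_type is required when actor_token
        // is present and must not be sent otherwise.
        if (actorToken == "") != (actorTokenType == "") {
            return nil, errors.New("actor_token and actor_token_type must be set together")
        }
        form := url.Values{}
        form.Set("grant_type", grantTypeTokenExchange)
        form.Set("subject_token", subjectToken)
        form.Set("subject_token_type", subjectTokenType)
        optional := map[string]string{
            "actor_token":      actorToken,
            "actor_token_type": actorTokenType,
            "audience":         audience,
            "scope":            scope,
        }
        for key, value := range optional {
            if value != "" {
                form.Set(key, value)
            }
        }
        return form, nil
    }
    ```
    
    `form.Encode()` yields the POST body. The mock token server in the tests
    checks `grant_type`, `subject_token` and `subject_token_type` on exactly
    this kind of request.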
    
    ## OAuth 2.0 Configuration Options
    
    | Option | Description |
    |--------|-------------|
    | `adbc.flight.sql.oauth.flow` | Specifies the OAuth 2.0 flow type to use. Possible values: `client_credentials`, `token_exchange` |
    | `adbc.flight.sql.oauth.client_id` | Unique identifier issued to the client application by the authorization server |
    | `adbc.flight.sql.oauth.client_secret` | Secret associated with the `client_id`. Used to authenticate the client application to the authorization server |
    | `adbc.flight.sql.oauth.token_uri` | The endpoint URL where the client application requests tokens from the authorization server |
    | `adbc.flight.sql.oauth.scope` | Space-separated list of permissions that the client is requesting access to (e.g. `"read.all offline_access"`) |
    | `adbc.flight.sql.oauth.exchange.subject_token` | The security token that the client application wants to exchange |
    | `adbc.flight.sql.oauth.exchange.subject_token_type` | Identifier for the type of the subject token. See the list below for supported token types. |
    | `adbc.flight.sql.oauth.exchange.actor_token` | A security token that represents the identity of the acting party |
    | `adbc.flight.sql.oauth.exchange.actor_token_type` | Identifier for the type of the actor token. See the list below for supported token types. |
    | `adbc.flight.sql.oauth.exchange.aud` | The intended audience for the requested security token |
    | `adbc.flight.sql.oauth.exchange.resource` | The resource server where the client intends to use the requested security token |
    | `adbc.flight.sql.oauth.exchange.scope` | Specific permissions requested for the new token |
    | `adbc.flight.sql.oauth.exchange.requested_token_type` | The type of token the client wants to receive in exchange. See the list below for supported token types. |
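    
    To make the client credentials options concrete, the sketch below
    performs the RFC 6749 section 4.4 request with only the Go standard
    library. `fetchClientCredentialsToken` and `newMockTokenServer` are
    hypothetical; the driver itself relies on golang/oauth2. The `basicAuth`
    switch reproduces the two ways a client can present its credentials,
    which is why `TestClientCredentialsFlow` expects two token requests when
    the server only accepts credentials in the request body.
    
    ```go
    package main
    
    import (
        "encoding/json"
        "fmt"
        "net/http"
        "net/http/httptest"
        "net/url"
        "strings"
    )
    
    // fetchClientCredentialsToken sends an RFC 6749 section 4.4 request to
    // tokenURI and returns the access token (hypothetical helper). With
    // basicAuth the client authenticates via HTTP Basic; otherwise client_id
    // and client_secret go in the form body (RFC 6749, section 2.3.1).
    func fetchClientCredentialsToken(tokenURI, clientID, clientSecret, scope string, basicAuth bool) (string, error) {
        form := url.Values{"grant_type": {"client_credentials"}}
        if scope != "" {
            form.Set("scope", scope) // space-separated, e.g. "read.all offline_access"
        }
        if !basicAuth {
            form.Set("client_id", clientID)
            form.Set("client_secret", clientSecret)
        }
        req, err := http.NewRequest(http.MethodPost, tokenURI, strings.NewReader(form.Encode()))
        if err != nil {
            return "", err
        }
        req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
        if basicAuth {
            // RFC 6749 form-encodes the credentials before Basic encoding.
            req.SetBasicAuth(url.QueryEscape(clientID), url.QueryEscape(clientSecret))
        }
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            return "", err
        }
        defer resp.Body.Close()
        if resp.StatusCode != http.StatusOK {
            return "", fmt.Errorf("token endpoint returned %s", resp.Status)
        }
        var body struct {
            AccessToken string `json:"access_token"`
        }
        if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
            return "", err
        }
        return body.AccessToken, nil
    }
    
    // newMockTokenServer mimics the test's MockOAuthServer: it only accepts
    // client credentials sent in the form body.
    func newMockTokenServer() *httptest.Server {
        return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            if err := r.ParseForm(); err == nil &&
                r.FormValue("grant_type") == "client_credentials" &&
                r.FormValue("client_id") == "test-client" &&
                r.FormValue("client_secret") == "test-secret" {
                w.Header().Set("Content-Type", "application/json")
                _, _ = w.Write([]byte(`{"access_token": "test-client-token", "token_type": "bearer"}`))
                return
            }
            http.Error(w, "Invalid request", http.StatusBadRequest)
        }))
    }
    ```
    
    Against `newMockTokenServer()`, calling the helper with `basicAuth` set
    to `true` fails with `400 Bad Request`, while `false` returns
    `test-client-token`; golang/oauth2 performs the same fallback on its own.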
    
    **Supported token types:**
    * `urn:ietf:params:oauth:token-type:access_token`
    * `urn:ietf:params:oauth:token-type:refresh_token`
    * `urn:ietf:params:oauth:token-type:id_token`
    * `urn:ietf:params:oauth:token-type:saml1`
    * `urn:ietf:params:oauth:token-type:saml2`
    * `urn:ietf:params:oauth:token-type:jwt`
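    
    Putting the table and the token type list together, a token exchange
    configuration is just a map of string options that would be passed to
    the database's `SetOptions`. The sketch below assembles one and rejects
    subject token types outside the list above. `tokenExchangeOptions` and
    the up-front validation are assumptions of this sketch; whether the
    driver itself rejects unknown token types is not shown in this change.
    
    ```go
    package main
    
    import "fmt"
    
    // supportedTokenTypes lists the token type URIs documented above
    // (see RFC 8693, section 3).
    var supportedTokenTypes = map[string]bool{
        "urn:ietf:params:oauth:token-type:access_token":  true,
        "urn:ietf:params:oauth:token-type:refresh_token": true,
        "urn:ietf:params:oauth:token-type:id_token":      true,
        "urn:ietf:params:oauth:token-type:saml1":         true,
        "urn:ietf:params:oauth:token-type:saml2":         true,
        "urn:ietf:params:oauth:token-type:jwt":           true,
    }
    
    // tokenExchangeOptions returns database options for the token_exchange
    // flow using the option keys from the table above (hypothetical helper).
    func tokenExchangeOptions(tokenURI, subjectToken, subjectTokenType string) (map[string]string, error) {
        if !supportedTokenTypes[subjectTokenType] {
            return nil, fmt.Errorf("unsupported subject token type: %q", subjectTokenType)
        }
        return map[string]string{
            "adbc.flight.sql.oauth.flow":                        "token_exchange",
            "adbc.flight.sql.oauth.token_uri":                   tokenURI,
            "adbc.flight.sql.oauth.exchange.subject_token":      subjectToken,
            "adbc.flight.sql.oauth.exchange.subject_token_type": subjectTokenType,
        }, nil
    }
    ```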
---
 docs/source/driver/flight_sql.rst                  |  65 +++-
 .../driver/flightsql/flightsql_adbc_server_test.go | 375 ++++++++++++++++++++-
 go/adbc/driver/flightsql/flightsql_database.go     |  74 +++-
 go/adbc/driver/flightsql/flightsql_driver.go       |  17 +
 go/adbc/driver/flightsql/flightsql_oauth.go        | 151 +++++++++
 5 files changed, 650 insertions(+), 32 deletions(-)

diff --git a/docs/source/driver/flight_sql.rst b/docs/source/driver/flight_sql.rst
index 983a6162c..d3a588be5 100644
--- a/docs/source/driver/flight_sql.rst
+++ b/docs/source/driver/flight_sql.rst
@@ -159,6 +159,12 @@ few optional authentication schemes:
   header will then be sent back as the ``authorization`` header on all
   future requests.
 
+- OAuth 2.0 authentication flows.
+
+  The driver accepts :ref:`configuration options <oauth-configurations>` that allow
+  the client application to obtain access tokens from an authorization server.
+  The obtained token is then sent in the ``authorization`` header on all future requests.
+
 Bulk Ingestion
 --------------
 
@@ -246,10 +252,67 @@ to :c:struct:`AdbcDatabase`, :c:struct:`AdbcConnection`, and
   Add the header ``<HEADER NAME>`` to outgoing requests with the given
   value.
 
-    Python: :attr:`adbc_driver_flightsql.ConnectionOptions.RPC_CALL_HEADER_PREFIX`
+  Python: :attr:`adbc_driver_flightsql.ConnectionOptions.RPC_CALL_HEADER_PREFIX`
 
   .. warning:: Header names must be in all lowercase.
 
+.. _oauth-configurations:
+
+OAuth 2.0 Options
+-----------------
+
+The following options configure how the client obtains tokens through OAuth 2.0 authentication flows.
+
+``adbc.flight.sql.oauth.flow``
+  Specifies the OAuth 2.0 flow type to use. Possible values: ``client_credentials``, ``token_exchange``
+
+``adbc.flight.sql.oauth.client_id``
+  Unique identifier issued to the client application by the authorization server
+
+``adbc.flight.sql.oauth.client_secret``
+  Secret associated with the ``client_id``. Used to authenticate the client application to the authorization server
+
+``adbc.flight.sql.oauth.token_uri``
+  The endpoint URL where the client application requests tokens from the authorization server
+
+``adbc.flight.sql.oauth.scope``
+  Space-separated list of permissions that the client is requesting access to (e.g. ``"read.all offline_access"``)
+
+``adbc.flight.sql.oauth.exchange.subject_token``
+  The security token that the client application wants to exchange
+
+``adbc.flight.sql.oauth.exchange.subject_token_type``
+  Identifier for the type of the subject token.
+  See the list below for supported token types.
+
+``adbc.flight.sql.oauth.exchange.actor_token``
+  A security token that represents the identity of the acting party
+
+``adbc.flight.sql.oauth.exchange.actor_token_type``
+  Identifier for the type of the actor token.
+  See the list below for supported token types.
+``adbc.flight.sql.oauth.exchange.aud``
+  The intended audience for the requested security token
+
+``adbc.flight.sql.oauth.exchange.resource``
+  The resource server where the client intends to use the requested security token
+
+``adbc.flight.sql.oauth.exchange.scope``
+  Specific permissions requested for the new token
+
+``adbc.flight.sql.oauth.exchange.requested_token_type``
+  The type of token the client wants to receive in exchange.
+  See the list below for supported token types.
+
+
+Supported token types:
+  - ``urn:ietf:params:oauth:token-type:access_token``
+  - ``urn:ietf:params:oauth:token-type:refresh_token``
+  - ``urn:ietf:params:oauth:token-type:id_token``
+  - ``urn:ietf:params:oauth:token-type:saml1``
+  - ``urn:ietf:params:oauth:token-type:saml2``
+  - ``urn:ietf:params:oauth:token-type:jwt``
+
 Distributed Result Sets
 -----------------------
 
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index c8a59d72a..f2cd0060c 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -20,11 +20,21 @@
 package flightsql_test
 
 import (
+       "bytes"
        "context"
+       "crypto/rand"
+       "crypto/rsa"
+       "crypto/tls"
+       "crypto/x509"
+       "crypto/x509/pkix"
        "encoding/json"
+       "encoding/pem"
        "errors"
        "fmt"
+       "math/big"
        "net"
+       "net/http"
+       "net/http/httptest"
        "net/textproto"
        "os"
        "strconv"
@@ -50,6 +60,7 @@ import (
        "golang.org/x/exp/maps"
        "google.golang.org/grpc"
        "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/credentials"
        "google.golang.org/grpc/metadata"
        "google.golang.org/grpc/stats"
        "google.golang.org/grpc/status"
@@ -69,16 +80,14 @@ type ServerBasedTests struct {
 }
 
 func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, dbArgs map[string]string, dialOpts ...grpc.DialOption) {
-       suite.s = flight.NewServerWithMiddleware(srvMiddleware)
-       suite.s.RegisterFlightService(flightsql.NewFlightServer(srv))
-       suite.Require().NoError(suite.s.Init("localhost:0"))
-       suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
-       go func() {
-               _ = suite.s.Serve()
-       }()
+       suite.setupFlightServer(srv, srvMiddleware)
 
-       uri := "grpc+tcp://" + suite.s.Addr().String()
+       suite.setupDatabase(dbArgs, dialOpts...)
+}
+
+func (suite *ServerBasedTests) setupDatabase(dbArgs map[string]string, dialOpts ...grpc.DialOption) {
        var err error
+       uri := "grpc+tcp://" + suite.s.Addr().String()
 
        args := map[string]string{
                "uri": uri,
@@ -88,6 +97,16 @@ func (suite *ServerBasedTests) DoSetupSuite(srv flightsql.Server, srvMiddleware
        suite.Require().NoError(err)
 }
 
+func (suite *ServerBasedTests) setupFlightServer(srv flightsql.Server, srvMiddleware []flight.ServerMiddleware, srvOpts ...grpc.ServerOption) {
+       suite.s = flight.NewServerWithMiddleware(srvMiddleware, srvOpts...)
+       suite.s.RegisterFlightService(flightsql.NewFlightServer(srv))
+       suite.Require().NoError(suite.s.Init("localhost:0"))
+       suite.s.SetShutdownOnSignals(os.Interrupt, os.Kill)
+       go func() {
+               _ = suite.s.Serve()
+       }()
+}
+
 func (suite *ServerBasedTests) SetupTest() {
        var err error
        suite.cnxn, err = suite.db.Open(context.Background())
@@ -104,6 +123,59 @@ func (suite *ServerBasedTests) TearDownSuite() {
        suite.s.Shutdown()
 }
 
+func (suite *ServerBasedTests) generateCertOption() grpc.ServerOption {
+       // Generate a self-signed certificate in-process for testing
+       privKey, err := rsa.GenerateKey(rand.Reader, 2048)
+       suite.Require().NoError(err)
+       certTemplate := x509.Certificate{
+               SerialNumber: big.NewInt(1),
+               Subject: pkix.Name{
+                       Organization: []string{"Unit Tests Incorporated"},
+               },
+               IPAddresses:           []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
+               NotBefore:             time.Now(),
+               NotAfter:              time.Now().Add(time.Hour),
+               KeyUsage:              x509.KeyUsageKeyEncipherment,
+               ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+               BasicConstraintsValid: true,
+       }
+       certDer, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &privKey.PublicKey, privKey)
+       suite.Require().NoError(err)
+       buffer := &bytes.Buffer{}
+       suite.Require().NoError(pem.Encode(buffer, &pem.Block{Type: "CERTIFICATE", Bytes: certDer}))
+       certBytes := make([]byte, buffer.Len())
+       copy(certBytes, buffer.Bytes())
+       buffer.Reset()
+       suite.Require().NoError(pem.Encode(buffer, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privKey)}))
+       keyBytes := make([]byte, buffer.Len())
+       copy(keyBytes, buffer.Bytes())
+
+       cert, err := tls.X509KeyPair(certBytes, keyBytes)
+       suite.Require().NoError(err)
+
+       // Serve TLS with the self-signed certificate; the OAuth tests connect with OptionSSLSkipVerify
+       tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
+       tlsCreds := credentials.NewTLS(tlsConfig)
+
+       return grpc.Creds(tlsCreds)
+}
+
+func (suite *ServerBasedTests) openAndExecuteQuery(query string) {
+       var err error
+       suite.cnxn, err = suite.db.Open(context.Background())
+       suite.Require().NoError(err)
+       defer suite.cnxn.Close()
+
+       stmt, err := suite.cnxn.NewStatement()
+       suite.Require().NoError(err)
+       defer stmt.Close()
+
+       suite.Require().NoError(stmt.SetSqlQuery(query))
+       reader, _, err := stmt.ExecuteQuery(context.Background())
+       suite.NoError(err)
+       defer reader.Release()
+}
+
 // ---- Tests --------------------
 
 func TestAuthn(t *testing.T) {
@@ -150,6 +222,10 @@ func TestGetObjects(t *testing.T) {
        suite.Run(t, &GetObjectsTests{})
 }
 
+func TestOauth(t *testing.T) {
+       suite.Run(t, &OAuthTests{})
+}
+
 // ---- AuthN Tests --------------------
 
 type AuthnTestServer struct {
@@ -230,23 +306,288 @@ type AuthnTests struct {
 }
 
 func (suite *AuthnTests) SetupSuite() {
-       suite.DoSetupSuite(&AuthnTestServer{}, []flight.ServerMiddleware{
+       suite.setupFlightServer(&AuthnTestServer{}, []flight.ServerMiddleware{
                {Stream: authnTestStream, Unary: authnTestUnary},
-       }, map[string]string{
-               driver.OptionAuthorizationHeader: "Bearer initial",
        })
 }
 
+func (suite *AuthnTests) SetupTest() {
+       suite.setupDatabase(map[string]string{
+               "uri": "grpc+tcp://" + suite.s.Addr().String(),
+       })
+}
+
+func (suite *AuthnTests) TearDownTest() {
+       suite.NoError(suite.db.Close())
+       suite.db = nil
+}
+
+func (suite *AuthnTests) TearDownSuite() {
+       suite.s.Shutdown()
+}
+
 func (suite *AuthnTests) TestBearerTokenUpdated() {
+       err := suite.db.SetOptions(map[string]string{
+               driver.OptionAuthorizationHeader: "Bearer initial",
+       })
+       suite.Require().NoError(err)
+
        // apache/arrow-adbc#584: when setting the auth header directly, the client should use any updated token value from the server if given
-       stmt, err := suite.cnxn.NewStatement()
+
+       suite.openAndExecuteQuery("a-query")
+}
+
+type OAuthTests struct {
+       ServerBasedTests
+
+       oauthServer     *httptest.Server
+       mockOAuthServer *MockOAuthServer
+}
+
+// MockOAuthServer simulates an OAuth 2.0 server for testing
+type MockOAuthServer struct {
+       // Track calls to validate server behavior
+       clientCredentialsCalls int
+       tokenExchangeCalls     int
+}
+
+func (m *MockOAuthServer) handleTokenRequest(w http.ResponseWriter, r *http.Request) {
+       // Parse the form to get the request parameters
+       if err := r.ParseForm(); err != nil {
+               http.Error(w, "Invalid request", http.StatusBadRequest)
+               return
+       }
+
+       grantType := r.FormValue("grant_type")
+
+       switch grantType {
+       case "client_credentials":
+               m.clientCredentialsCalls++
+               // Validate client credentials
+               clientID := r.FormValue("client_id")
+               clientSecret := r.FormValue("client_secret")
+
+               if clientID == "test-client" && clientSecret == "test-secret" {
+                       // Return a valid token response
+                       w.Header().Set("Content-Type", "application/json")
+                       _, _ = w.Write([]byte(`{
+                               "access_token": "test-client-token",
+                               "token_type": "bearer",
+                               "expires_in": 3600
+                       }`))
+
+                       return
+               }
+
+       case "urn:ietf:params:oauth:grant-type:token-exchange":
+               m.tokenExchangeCalls++
+               // Validate token exchange parameters
+               subjectToken := r.FormValue("subject_token")
+               subjectTokenType := r.FormValue("subject_token_type")
+
+               if subjectToken == "test-subject-token" &&
+                       subjectTokenType == "urn:ietf:params:oauth:token-type:jwt" {
+                       // Return a valid token response
+                       w.Header().Set("Content-Type", "application/json")
+                       _, _ = w.Write([]byte(`{
+                               "access_token": "test-exchanged-token",
+                               "token_type": "bearer",
+                               "expires_in": 3600
+                       }`))
+                       return
+               }
+       }
+
+       // Default: return error for invalid request
+       http.Error(w, "Invalid request", http.StatusBadRequest)
+}
+
+func oauthTestUnary(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
+       md, ok := metadata.FromIncomingContext(ctx)
+       if !ok {
+               return nil, status.Error(codes.InvalidArgument, "Could not get metadata")
+       }
+       auth := md.Get("authorization")
+       if len(auth) == 0 {
+               return nil, status.Error(codes.Unauthenticated, "No token")
+       } else if auth[0] != "Bearer test-exchanged-token" && auth[0] != "Bearer test-client-token" {
+               return nil, status.Error(codes.Unauthenticated, "Invalid token for unary call: "+auth[0])
+       }
+
+       md.Set("authorization", "Bearer final")
+       ctx = metadata.NewOutgoingContext(ctx, md)
+       return handler(ctx, req)
+}
+
+func (suite *OAuthTests) SetupSuite() {
+       suite.mockOAuthServer = &MockOAuthServer{}
+       suite.oauthServer = httptest.NewServer(http.HandlerFunc(suite.mockOAuthServer.handleTokenRequest))
+
+       suite.setupFlightServer(&AuthnTestServer{}, []flight.ServerMiddleware{
+               {Unary: oauthTestUnary},
+       }, suite.generateCertOption())
+}
+
+func (suite *OAuthTests) TearDownSuite() {
+       suite.oauthServer.Close()
+       suite.s.Shutdown()
+}
+
+func (suite *OAuthTests) SetupTest() {
+       suite.setupDatabase(map[string]string{
+               "uri": "grpc+tls://" + suite.s.Addr().String(),
+       })
+}
+
+func (suite *OAuthTests) TearDownTest() {
+       suite.NoError(suite.db.Close())
+       suite.db = nil
+}
+
+func (suite *OAuthTests) TestTokenExchangeFlow() {
+       err := suite.db.SetOptions(map[string]string{
+               driver.OptionKeyOauthFlow:        driver.TokenExchange,
+               driver.OptionKeySubjectToken:     "test-subject-token",
+               driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
+               driver.OptionKeyTokenURI:         suite.oauthServer.URL,
+               driver.OptionSSLSkipVerify:       adbc.OptionValueEnabled,
+       })
        suite.Require().NoError(err)
-       defer stmt.Close()
 
-       suite.Require().NoError(stmt.SetSqlQuery("timeout"))
-       reader, _, err := stmt.ExecuteQuery(context.Background())
-       suite.NoError(err)
-       defer reader.Release()
+       suite.openAndExecuteQuery("a-query")
+       suite.Equal(1, suite.mockOAuthServer.tokenExchangeCalls, "Token exchange flow should be called once")
+}
+
+func (suite *OAuthTests) TestClientCredentialsFlow() {
+       err := suite.db.SetOptions(map[string]string{
+               driver.OptionKeyOauthFlow:    driver.ClientCredentials,
+               driver.OptionKeyClientId:     "test-client",
+               driver.OptionKeyClientSecret: "test-secret",
+               driver.OptionKeyTokenURI:     suite.oauthServer.URL,
+               driver.OptionSSLSkipVerify:   adbc.OptionValueEnabled,
+       })
+       suite.Require().NoError(err)
+
+       suite.cnxn, err = suite.db.Open(context.Background())
+       suite.Require().NoError(err)
+       defer suite.cnxn.Close()
+
+       suite.openAndExecuteQuery("a-query")
+       // golang/oauth2 first sends the client credentials in the Authorization header;
+       // if that fails, it retries with the client credentials in the request body.
+       // See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
+       suite.Equal(2, suite.mockOAuthServer.clientCredentialsCalls, "Client credentials flow should make two token requests (header, then body)")
+}
+
+func (suite *OAuthTests) TestFailOauthWithTokenSet() {
+       err := suite.db.SetOptions(map[string]string{
+               driver.OptionAuthorizationHeader: "Bearer test-client-token",
+               driver.OptionKeyOauthFlow:        driver.ClientCredentials,
+               driver.OptionKeyClientId:         "test-client",
+               driver.OptionKeyClientSecret:     "test-secret",
+               driver.OptionKeyTokenURI:         suite.oauthServer.URL,
+       })
+       suite.Require().Error(err, "Expected error for conflicting authentication options")
+       suite.Contains(err.Error(), "Authentication conflict: Use either Authorization header OR username/password parameter")
+}
+
+func (suite *OAuthTests) TestMissingRequiredParamsTokenExchange() {
+       testCases := []struct {
+               name             string
+               options          map[string]string
+               expectedErrorMsg string
+       }{
+               {
+                       name: "Missing token",
+                       options: map[string]string{
+                               driver.OptionKeyOauthFlow:        driver.TokenExchange,
+                               driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
+                               driver.OptionKeyTokenURI:         suite.oauthServer.URL,
+                       },
+                       expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.exchange.subject_token",
+               },
+               {
+                       name: "Missing subject token type",
+                       options: map[string]string{
+                               driver.OptionKeyOauthFlow:    driver.TokenExchange,
+                               driver.OptionKeySubjectToken: "test-subject-token",
+                               driver.OptionKeyTokenURI:     suite.oauthServer.URL,
+                       },
+                       expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.exchange.subject_token_type",
+               },
+               {
+                       name: "Missing token URI",
+                       options: map[string]string{
+                               driver.OptionKeyOauthFlow:        driver.TokenExchange,
+                               driver.OptionKeySubjectToken:     "test-subject-token",
+                               driver.OptionKeySubjectTokenType: "urn:ietf:params:oauth:token-type:jwt",
+                       },
+                       expectedErrorMsg: "token exchange grant requires adbc.flight.sql.oauth.token_uri",
+               },
+       }
+
+       for _, tc := range testCases {
+               suite.Run(tc.name, func() {
+                       // OAuth options are validated when they are passed to SetOptions
+                       err := suite.db.SetOptions(tc.options)
+                       suite.Require().Error(err, "Expected error for missing parameters")
+                       suite.Contains(err.Error(), tc.expectedErrorMsg)
+               })
+       }
+}
+func (suite *OAuthTests) TestMissingRequiredParamsClientCredentials() {
+       testCases := []struct {
+               name             string
+               options          map[string]string
+               expectedErrorMsg string
+       }{
+               {
+                       name: "Missing client ID",
+                       options: map[string]string{
+                               driver.OptionKeyOauthFlow:    driver.ClientCredentials,
+                               driver.OptionKeyClientSecret: "test-secret",
+                               driver.OptionKeyTokenURI:     suite.oauthServer.URL,
+                       },
+                       expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.client_id",
+               },
+               {
+                       name: "Missing client secret",
+                       options: map[string]string{
+                               driver.OptionKeyOauthFlow: driver.ClientCredentials,
+                               driver.OptionKeyClientId:  "test-client",
+                               driver.OptionKeyTokenURI:  suite.oauthServer.URL,
+                       },
+                       expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.client_secret",
+               },
+               {
+                       name: "Missing token URI",
+                       options: map[string]string{
+                               driver.OptionKeyOauthFlow:    driver.ClientCredentials,
+                               driver.OptionKeyClientId:     "test-client",
+                               driver.OptionKeyClientSecret: "test-secret",
+                       },
+                       expectedErrorMsg: "client credentials grant requires adbc.flight.sql.oauth.token_uri",
+               },
+       }
+
+       for _, tc := range testCases {
+               suite.Run(tc.name, func() {
+                       // OAuth options are validated when they are passed to SetOptions
+                       err := suite.db.SetOptions(tc.options)
+                       suite.Require().Error(err, "Expected error for missing parameters")
+                       suite.Contains(err.Error(), tc.expectedErrorMsg)
+               })
+       }
+}
+
+func (suite *OAuthTests) TestInvalidOAuthFlow() {
+       err := suite.db.SetOptions(map[string]string{
+               driver.OptionKeyOauthFlow:    "invalid-flow",
+               driver.OptionKeySubjectToken: "test-token",
+       })
+
+       suite.Require().Error(err, "Expected error for invalid OAuth flow")
+       suite.Contains(err.Error(), "Not Implemented: oauth flow not implemented: invalid-flow")
 }
 
 // ---- Grpc Dialer Options Tests --------------
diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go
index bbbcbbf06..e45eb4d5d 100644
--- a/go/adbc/driver/flightsql/flightsql_database.go
+++ b/go/adbc/driver/flightsql/flightsql_database.go
@@ -68,6 +68,7 @@ type databaseImpl struct {
        enableCookies bool
        options       map[string]string
        userDialOpts  []grpc.DialOption
+       oauthToken    credentials.PerRPCCredentials
 }
 
 func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
@@ -146,10 +147,12 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
                delete(cnOptions, OptionAuthorizationHeader)
        }
 
+       const authConflictError = "Authentication conflict: Use either Authorization header OR username/password parameter"
+
        if u, ok := cnOptions[adbc.OptionKeyUsername]; ok {
                if d.hdrs.Len() > 0 {
                        return adbc.Error{
-                               Msg:  "Authorization header already provided, do not provide user/pass also",
+                               Msg:  authConflictError,
                                Code: adbc.StatusInvalidArgument,
                        }
                }
@@ -160,7 +163,7 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
        if p, ok := cnOptions[adbc.OptionKeyPassword]; ok {
                if d.hdrs.Len() > 0 {
                        return adbc.Error{
-                               Msg:  "Authorization header already provided, do not provide user/pass also",
+                               Msg:  authConflictError,
                                Code: adbc.StatusInvalidArgument,
                        }
                }
@@ -168,6 +171,33 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
                delete(cnOptions, adbc.OptionKeyPassword)
        }
 
+       if flow, ok := cnOptions[OptionKeyOauthFlow]; ok {
+               if d.hdrs.Len() > 0 {
+                       return adbc.Error{
+                               Msg:  authConflictError,
+                               Code: adbc.StatusInvalidArgument,
+                       }
+               }
+
+               var err error
+               switch flow {
+               case ClientCredentials:
+                       d.oauthToken, err = newClientCredentials(cnOptions)
+               case TokenExchange:
+                       d.oauthToken, err = newTokenExchangeFlow(cnOptions)
+               default:
+                       return adbc.Error{
+                               Msg:  fmt.Sprintf("oauth flow not implemented: %s", flow),
+                               Code: adbc.StatusNotImplemented,
+                       }
+               }
+
+               if err != nil {
+                       return err
+               }
+               delete(cnOptions, OptionKeyOauthFlow)
+       }
+
        var err error
        if tv, ok := cnOptions[OptionTimeoutFetch]; ok {
                if err = d.timeout.setTimeoutString(OptionTimeoutFetch, tv); err != nil {
@@ -374,6 +404,10 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
        dialOpts := append(d.dialOpts.opts, grpc.WithConnectParams(d.timeout.connectParams()), grpc.WithTransportCredentials(creds), grpc.WithUserAgent("ADBC Flight SQL Driver "+driverVersion))
        dialOpts = append(dialOpts, d.userDialOpts...)
 
+       if d.oauthToken != nil {
+               dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(d.oauthToken))
+       }
+
        d.Logger.DebugContext(ctx, "new client", "location", loc)
        cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...)
        if err != nil {
@@ -384,24 +418,30 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl
        }
 
        cl.Alloc = d.Alloc
+       // Authorization header is already set; reuse it and skip basic authentication
        if len(authMiddle.hdrs.Get("authorization")) > 0 {
                d.Logger.DebugContext(ctx, "reusing auth token", "location", loc)
-       } else {
-               if d.user != "" || d.pass != "" {
-                       var header, trailer metadata.MD
-                       ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
-                       if err != nil {
-                               return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken")
-                       }
+               return cl, nil
+       }
 
-                       if md, ok := metadata.FromOutgoingContext(ctx); ok {
-                               authMiddle.mutex.Lock()
-                               defer authMiddle.mutex.Unlock()
-                               authMiddle.hdrs.Set("authorization", md.Get("Authorization")[0])
-                       }
+       var authValue string
+
+       if d.user != "" || d.pass != "" {
+               var header, trailer metadata.MD
+               ctx, err = cl.Client.AuthenticateBasicToken(ctx, d.user, d.pass, grpc.Header(&header), grpc.Trailer(&trailer), d.timeout)
+               if err != nil {
+                       return nil, adbcFromFlightStatusWithDetails(err, header, trailer, "AuthenticateBasicToken")
+               }
+
+               if md, ok := metadata.FromOutgoingContext(ctx); ok && len(md.Get("authorization")) > 0 {
+                       authValue = md.Get("authorization")[0]
                }
        }
 
+       if authValue != "" {
+               authMiddle.SetHeader(authValue)
+       }
+
        return cl, nil
 }
 
@@ -526,3 +566,9 @@ func (b *bearerAuthMiddleware) HeadersReceived(ctx context.Context, md metadata.
                b.hdrs.Set("authorization", headers...)
        }
 }
+
+func (b *bearerAuthMiddleware) SetHeader(authValue string) {
+       b.mutex.Lock()
+       defer b.mutex.Unlock()
+       b.hdrs.Set("authorization", authValue)
+}
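The reworked getFlightClient returns early when an authorization header is already cached, falls back to basic authentication only when a user or password is configured, and stores the result through the new mutex-guarded SetHeader. A minimal stdlib-only sketch of that control flow follows; headerStore, ensureAuth, errAuth, and the authenticate callback are hypothetical stand-ins for bearerAuthMiddleware, getFlightClient, and AuthenticateBasicToken, and a plain map replaces gRPC's metadata.MD.

```go
package main

import (
	"errors"
	"sync"
)

// headerStore is a hypothetical stand-in for bearerAuthMiddleware: a header
// map with lowercase keys (as in gRPC metadata) guarded by a mutex.
type headerStore struct {
	mu   sync.Mutex
	hdrs map[string][]string
}

func newHeaderStore() *headerStore {
	return &headerStore{hdrs: map[string][]string{}}
}

// SetHeader replaces the authorization value while holding the lock,
// mirroring bearerAuthMiddleware.SetHeader.
func (h *headerStore) SetHeader(authValue string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.hdrs["authorization"] = []string{authValue}
}

// authorization returns the cached value, or "" if none is set.
func (h *headerStore) authorization() string {
	h.mu.Lock()
	defer h.mu.Unlock()
	if v := h.hdrs["authorization"]; len(v) > 0 {
		return v[0]
	}
	return ""
}

// ensureAuth sketches the reworked flow: reuse a cached header, otherwise
// authenticate only when credentials are configured, then store the result.
func ensureAuth(h *headerStore, user, pass string, authenticate func(user, pass string) (string, error)) error {
	if h.authorization() != "" {
		return nil // header already set, reuse it
	}
	var authValue string
	if user != "" || pass != "" {
		v, err := authenticate(user, pass)
		if err != nil {
			return err
		}
		authValue = v
	}
	if authValue != "" {
		h.SetHeader(authValue)
	}
	return nil
}

// errAuth is a sample error that a failing authenticate callback might return.
var errAuth = errors.New("authentication failed")
```

The diff's early-return check reads authMiddle.hdrs without taking the mutex; the sketch locks on reads as well, which is a choice made for the sketch only.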
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 9e517c716..ff1e74bb5 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -66,6 +66,23 @@ const (
        OptionStringListSessionOptionPrefix = "adbc.flight.sql.session.optionstringlist."
        OptionLastFlightInfo                = "adbc.flight.sql.statement.exec.last_flight_info"
        infoDriverName                      = "ADBC Flight SQL Driver - Go"
+
+       // OAuth2 options
+       OptionKeyOauthFlow        = "adbc.flight.sql.oauth.flow"
+       OptionKeyAuthURI          = "adbc.flight.sql.oauth.auth_uri"
+       OptionKeyTokenURI         = "adbc.flight.sql.oauth.token_uri"
+       OptionKeyRedirectURI      = "adbc.flight.sql.oauth.redirect_uri"
+       OptionKeyScope            = "adbc.flight.sql.oauth.scope"
+       OptionKeyClientId         = "adbc.flight.sql.oauth.client_id"
+       OptionKeyClientSecret     = "adbc.flight.sql.oauth.client_secret"
+       OptionKeySubjectToken     = "adbc.flight.sql.oauth.exchange.subject_token"
+       OptionKeySubjectTokenType = "adbc.flight.sql.oauth.exchange.subject_token_type"
+       OptionKeyActorToken       = "adbc.flight.sql.oauth.exchange.actor_token"
+       OptionKeyActorTokenType   = "adbc.flight.sql.oauth.exchange.actor_token_type"
+       OptionKeyReqTokenType     = "adbc.flight.sql.oauth.exchange.requested_token_type"
+       OptionKeyExchangeScope    = "adbc.flight.sql.oauth.exchange.scope"
+       OptionKeyExchangeAud      = "adbc.flight.sql.oauth.exchange.aud"
+       OptionKeyExchangeResource = "adbc.flight.sql.oauth.exchange.resource"
 )
 
 var errNoTransactionSupport = adbc.Error{
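A third-party application selects and configures a flow purely through these string keys. The sketch below assembles option maps for both flows; clientCredentialsOptions and tokenExchangeOptions are hypothetical helpers, the key strings are copied from the constants above (redeclared so the sketch compiles without the driver package), and https://auth.example.com/token is a placeholder endpoint.

```go
package main

// Option key strings copied from the OAuth2 constants in flightsql_driver.go.
const (
	optOAuthFlow        = "adbc.flight.sql.oauth.flow"
	optTokenURI         = "adbc.flight.sql.oauth.token_uri"
	optScope            = "adbc.flight.sql.oauth.scope"
	optClientID         = "adbc.flight.sql.oauth.client_id"
	optClientSecret     = "adbc.flight.sql.oauth.client_secret"
	optSubjectToken     = "adbc.flight.sql.oauth.exchange.subject_token"
	optSubjectTokenType = "adbc.flight.sql.oauth.exchange.subject_token_type"
)

// clientCredentialsOptions (hypothetical) assembles options for the
// client_credentials flow; scope is optional and omitted when empty.
func clientCredentialsOptions(clientID, clientSecret, tokenURI, scope string) map[string]string {
	opts := map[string]string{
		optOAuthFlow:    "client_credentials",
		optClientID:     clientID,
		optClientSecret: clientSecret,
		optTokenURI:     tokenURI,
	}
	if scope != "" {
		opts[optScope] = scope
	}
	return opts
}

// tokenExchangeOptions (hypothetical) assembles the minimum options for the
// token_exchange flow: the subject token, its type, and the token URI.
func tokenExchangeOptions(subjectToken, subjectTokenType, tokenURI string) map[string]string {
	return map[string]string{
		optOAuthFlow:        "token_exchange",
		optSubjectToken:     subjectToken,
		optSubjectTokenType: subjectTokenType,
		optTokenURI:         tokenURI,
	}
}
```

Either map would then be merged into the database options handed to the driver (for example via NewDatabase).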
diff --git a/go/adbc/driver/flightsql/flightsql_oauth.go b/go/adbc/driver/flightsql/flightsql_oauth.go
new file mode 100644
index 000000000..707590a0d
--- /dev/null
+++ b/go/adbc/driver/flightsql/flightsql_oauth.go
@@ -0,0 +1,151 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package flightsql
+
+import (
+       "context"
+       "fmt"
+
+       "golang.org/x/oauth2"
+       "google.golang.org/grpc/credentials"
+       "google.golang.org/grpc/credentials/oauth"
+)
+
+const (
+       ClientCredentials = "client_credentials"
+       TokenExchange     = "token_exchange"
+)
+
+type oAuthOption struct {
+       isRequired bool
+       oAuthKey   string
+}
+
+var (
+       clientCredentialsParams = map[string]oAuthOption{
+               OptionKeyClientId:     {true, "client_id"},
+               OptionKeyClientSecret: {true, "client_secret"},
+               OptionKeyTokenURI:     {true, "token_uri"},
+               OptionKeyScope:        {false, "scope"},
+       }
+
+       tokenExchangParams = map[string]oAuthOption{
+               OptionKeySubjectToken:     {true, "subject_token"},
+               OptionKeySubjectTokenType: {true, "subject_token_type"},
+               OptionKeyReqTokenType:     {false, "requested_token_type"},
+               OptionKeyExchangeAud:      {false, "audience"},
+               OptionKeyExchangeResource: {false, "resource"},
+               OptionKeyExchangeScope:    {false, "scope"},
+       }
+)
+
+func parseOAuthOptions(options map[string]string, paramMap map[string]oAuthOption, flowName string) (map[string]string, error) {
+       params := map[string]string{}
+
+       for key, param := range paramMap {
+               if value, ok := options[key]; ok {
+                       params[key] = value
+                       delete(options, key)
+               } else if param.isRequired {
+                       return nil, fmt.Errorf("%s grant requires %s", flowName, key)
+               }
+       }
+
+       return params, nil
+}
+
+func exchangeToken(conf *oauth2.Config, codeOptions []oauth2.AuthCodeOption) (credentials.PerRPCCredentials, error) {
+       ctx := context.Background()
+       tok, err := conf.Exchange(ctx, "", codeOptions...)
+       if err != nil {
+               return nil, err
+       }
+       return &oauth.TokenSource{TokenSource: conf.TokenSource(ctx, tok)}, nil
+}
+
+func newClientCredentials(options map[string]string) (credentials.PerRPCCredentials, error) {
+       codeOptions := []oauth2.AuthCodeOption{
+               // Required value for client credentials requests as specified in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2
+               oauth2.SetAuthURLParam("grant_type", "client_credentials"),
+       }
+
+       params, err := parseOAuthOptions(options, clientCredentialsParams, "client credentials")
+       if err != nil {
+               return nil, err
+       }
+
+       conf := &oauth2.Config{
+               ClientID:     params[OptionKeyClientId],
+               ClientSecret: params[OptionKeyClientSecret],
+               Endpoint: oauth2.Endpoint{
+                       TokenURL: params[OptionKeyTokenURI],
+               },
+       }
+
+       if scopes, ok := params[OptionKeyScope]; ok {
+               codeOptions = append(codeOptions, oauth2.SetAuthURLParam("scope", scopes))
+       }
+
+       return exchangeToken(conf, codeOptions)
+}
+
+func newTokenExchangeFlow(options map[string]string) (credentials.PerRPCCredentials, error) {
+       tokenURI, ok := options[OptionKeyTokenURI]
+       if !ok {
+               return nil, fmt.Errorf("token exchange grant requires %s", OptionKeyTokenURI)
+       }
+       delete(options, OptionKeyTokenURI)
+
+       conf := &oauth2.Config{
+               Endpoint: oauth2.Endpoint{
+                       TokenURL: tokenURI,
+               },
+       }
+
+       codeOptions := []oauth2.AuthCodeOption{
+               // Required value for token exchange requests as specified in https://datatracker.ietf.org/doc/html/rfc8693#name-request
+               oauth2.SetAuthURLParam("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange"),
+       }
+
+       params, err := parseOAuthOptions(options, tokenExchangParams, "token exchange")
+       if err != nil {
+               return nil, err
+       }
+
+       for key, param := range tokenExchangParams {
+               if value, ok := params[key]; ok {
+                       codeOptions = append(codeOptions, oauth2.SetAuthURLParam(param.oAuthKey, value))
+               }
+       }
+
+       // actor token and actor token type are optional
+       // but if one is present, the other must be present
+       if actor, ok := options[OptionKeyActorToken]; ok {
+               codeOptions = append(codeOptions, oauth2.SetAuthURLParam("actor_token", actor))
+               delete(options, OptionKeyActorToken)
+               if actorTokenType, ok := options[OptionKeyActorTokenType]; ok {
+                       codeOptions = append(codeOptions, oauth2.SetAuthURLParam("actor_token_type", actorTokenType))
+                       delete(options, OptionKeyActorTokenType)
+               } else {
+                       return nil, fmt.Errorf("token exchange grant requires %s when %s is provided",
+                               OptionKeyActorTokenType, OptionKeyActorToken)
+               }
+       }
+
+       return exchangeToken(conf, codeOptions)
+}
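newTokenExchangeFlow hands its parameters to oauth2.Config.Exchange, overriding grant_type through oauth2.SetAuthURLParam, so the token endpoint receives an RFC 8693 form body. The stdlib-only sketch below makes that parameter set explicit under the same rules (subject_token and subject_token_type required, actor_token_type required whenever actor_token is given); exchangeRequest and tokenExchangeForm are hypothetical and not part of the driver.

```go
package main

import (
	"errors"
	"net/url"
)

// exchangeRequest is a hypothetical holder for RFC 8693 request fields.
type exchangeRequest struct {
	SubjectToken, SubjectTokenType string
	ActorToken, ActorTokenType     string
	RequestedTokenType             string
	Audience, Resource, Scope      string
}

// tokenExchangeForm builds the form body of an RFC 8693 token exchange
// request, validating required and paired parameters first.
func tokenExchangeForm(r exchangeRequest) (url.Values, error) {
	if r.SubjectToken == "" || r.SubjectTokenType == "" {
		return nil, errors.New("token exchange requires subject_token and subject_token_type")
	}
	v := url.Values{}
	v.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
	v.Set("subject_token", r.SubjectToken)
	v.Set("subject_token_type", r.SubjectTokenType)

	// actor_token_type must accompany actor_token.
	if r.ActorToken != "" {
		if r.ActorTokenType == "" {
			return nil, errors.New("actor_token_type is required when actor_token is provided")
		}
		v.Set("actor_token", r.ActorToken)
		v.Set("actor_token_type", r.ActorTokenType)
	}

	// Optional parameters are only sent when set.
	optional := map[string]string{
		"requested_token_type": r.RequestedTokenType,
		"audience":             r.Audience,
		"resource":             r.Resource,
		"scope":                r.Scope,
	}
	for key, val := range optional {
		if val != "" {
			v.Set(key, val)
		}
	}
	return v, nil
}
```

Unlike parseOAuthOptions, the sketch does not consume entries from an options map; it only shows which parameters end up on the wire.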
