This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push:
new bb72d993 feat(catalog/rest): Adding AuthManager support (#579)
bb72d993 is described below
commit bb72d993509f425e4f11319636190416e21679eb
Author: Alex Stephen <[email protected]>
AuthorDate: Mon Jan 26 11:30:05 2026 -0800
feat(catalog/rest): Adding AuthManager support (#579)
The Java and Python implementations introduced an AuthManager interface
that lets users plug in their own custom authentication solutions. An
AuthManager is responsible for creating the authentication header.
This change adds an AuthManager interface to the Go REST catalog and moves
the existing OAuth logic into its own AuthManager implementation. No APIs
were broken in this transition.
---
catalog/rest/auth.go | 146 ++++++++++++++++++++++++++++++++++++++++++++++
catalog/rest/auth_test.go | 113 +++++++++++++++++++++++++++++++++++
catalog/rest/options.go | 9 ++-
catalog/rest/rest.go | 129 +++++++++-------------------------------
4 files changed, 296 insertions(+), 101 deletions(-)
diff --git a/catalog/rest/auth.go b/catalog/rest/auth.go
new file mode 100644
index 00000000..82dceb39
--- /dev/null
+++ b/catalog/rest/auth.go
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package rest
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+// AuthManager supplies the authentication header attached to REST catalog
+// requests, letting callers plug in custom authentication schemes.
+type AuthManager interface {
+	// AuthHeader returns the header name and value to send, or a non-nil
+	// error if the header cannot be produced.
+	AuthHeader() (key, value string, err error)
+}
+
+type (
+	// oauthTokenResponse is the JSON body of a successful OAuth2 token
+	// response.
+	oauthTokenResponse struct {
+		AccessToken  string `json:"access_token"`
+		TokenType    string `json:"token_type"`
+		ExpiresIn    int    `json:"expires_in"`
+		Scope        string `json:"scope"`
+		RefreshToken string `json:"refresh_token"`
+	}
+
+	// oauthErrorResponse is the JSON body of an OAuth2 error response. It
+	// implements error and unwraps to ErrOAuthError.
+	oauthErrorResponse struct {
+		Err     string `json:"error"`
+		ErrDesc string `json:"error_description"`
+		ErrURI  string `json:"error_uri"`
+	}
+)
+
+func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
+
+// Error renders the response as "error[: description][ (uri)]", omitting
+// the optional parts when they are empty.
+func (o oauthErrorResponse) Error() string {
+	var b strings.Builder
+	b.WriteString(o.Err)
+	if o.ErrDesc != "" {
+		b.WriteString(": ")
+		b.WriteString(o.ErrDesc)
+	}
+	if o.ErrURI != "" {
+		b.WriteString(" (")
+		b.WriteString(o.ErrURI)
+		b.WriteString(")")
+	}
+
+	return b.String()
+}
+
+// Oauth2AuthManager is an implementation of the AuthManager interface that
+// sends a bearer token. If a static Token is set it is used as-is. Otherwise,
+// if a Credential is set, a token is fetched once via the OAuth2
+// client_credentials grant and cached in Token for later calls; the cached
+// token is not refreshed when it expires.
+type Oauth2AuthManager struct {
+	// Token is a static bearer token. When non-empty, no token is fetched.
+	Token string
+	// Credential is either "client_id:client_secret" or a bare client
+	// secret, exchanged for a token when Token is empty.
+	Credential string
+
+	// AuthURI is the token endpoint used when fetching a token.
+	AuthURI *url.URL
+	// Scope is the requested OAuth2 scope; "catalog" is used when empty.
+	Scope string
+	// Client is the HTTP client used to call AuthURI. It is required when a
+	// token must be fetched.
+	Client *http.Client
+}
+
+// AuthHeader returns the authorization header with the bearer token.
+func (o *Oauth2AuthManager) AuthHeader() (string, string, error) {
+ if o.Token == "" && o.Credential != "" {
+ if o.Client == nil {
+ return "", "", fmt.Errorf("%w: cannot fetch token
without http client", ErrRESTError)
+ }
+
+ tok, err := o.fetchAccessToken()
+ if err != nil {
+ return "", "", err
+ }
+ o.Token = tok
+ }
+
+ return "Authorization", "Bearer " + o.Token, nil
+}
+
+// fetchAccessToken exchanges o.Credential for an access token using the
+// OAuth2 client_credentials grant against o.AuthURI.
+//
+// The credential is split on the first ':' into client_id and client_secret.
+// Without a ':', the whole credential is sent as client_secret with an empty
+// client_id. The scope defaults to "catalog" when o.Scope is empty.
+//
+// A 400 or 401 response is decoded into an oauthErrorResponse, which unwraps
+// to ErrOAuthError. Any other non-200 status is delegated to handleNon200. A
+// 200 response without an access_token is reported as an ErrOAuthError.
+func (o *Oauth2AuthManager) fetchAccessToken() (string, error) {
+	// Validate configuration before building the request.
+	if o.AuthURI == nil {
+		return "", fmt.Errorf("%w: missing auth uri for fetching token", ErrRESTError)
+	}
+
+	clientID, clientSecret, hasID := strings.Cut(o.Credential, ":")
+	if !hasID {
+		clientID, clientSecret = "", o.Credential
+	}
+
+	scope := "catalog"
+	if o.Scope != "" {
+		scope = o.Scope
+	}
+	data := url.Values{
+		"grant_type":    {"client_credentials"},
+		"client_id":     {clientID},
+		"client_secret": {clientSecret},
+		"scope":         {scope},
+	}
+
+	rsp, err := o.Client.PostForm(o.AuthURI.String(), data)
+	if err != nil {
+		return "", err
+	}
+
+	switch rsp.StatusCode {
+	case http.StatusOK:
+		// Drain before closing so the connection can be reused.
+		defer func() {
+			_, _ = io.Copy(io.Discard, rsp.Body)
+			_ = rsp.Body.Close()
+		}()
+		var tok oauthTokenResponse
+		if err := json.NewDecoder(rsp.Body).Decode(&tok); err != nil {
+			return "", fmt.Errorf("failed to decode oauth token response: %w", err)
+		}
+		// access_token is required by the spec. Accepting an empty one would
+		// produce a bare "Bearer " header and a refetch on every call.
+		if tok.AccessToken == "" {
+			return "", fmt.Errorf("%w: token response did not include an access_token", ErrOAuthError)
+		}
+
+		return tok.AccessToken, nil
+	case http.StatusUnauthorized, http.StatusBadRequest:
+		defer func() {
+			_, _ = io.Copy(io.Discard, rsp.Body)
+			_ = rsp.Body.Close()
+		}()
+		var oauthErr oauthErrorResponse
+		if err := json.NewDecoder(rsp.Body).Decode(&oauthErr); err != nil {
+			return "", fmt.Errorf("failed to decode oauth error: %w", err)
+		}
+
+		return "", oauthErr
+	default:
+		return "", handleNon200(rsp, nil)
+	}
+}
diff --git a/catalog/rest/auth_test.go b/catalog/rest/auth_test.go
new file mode 100644
index 00000000..84dac1ed
--- /dev/null
+++ b/catalog/rest/auth_test.go
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package rest
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOauth2AuthManager_AuthHeader_StaticToken(t *testing.T) {
+	// No Client is configured, so any token fetch would fail. These cases
+	// therefore also prove a static token short-circuits fetching.
+	tests := []struct {
+		name    string
+		manager *Oauth2AuthManager
+	}{
+		{
+			name:    "token only",
+			manager: &Oauth2AuthManager{Token: "static_token"},
+		},
+		{
+			name: "token takes precedence over credential",
+			manager: &Oauth2AuthManager{
+				Token:      "static_token",
+				Credential: "client:secret",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			key, value, err := tt.manager.AuthHeader()
+			require.NoError(t, err)
+			assert.Equal(t, "Authorization", key)
+			assert.Equal(t, "Bearer static_token", value)
+		})
+	}
+}
+
+func TestOauth2AuthManager_AuthHeader_MissingClient(t *testing.T) {
+ manager := &Oauth2AuthManager{
+ Credential: "client:secret",
+ }
+
+ _, _, err := manager.AuthHeader()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "cannot fetch token without http
client")
+}
+
+func TestOauth2AuthManager_AuthHeader_FetchToken_Success(t *testing.T) {
+	mux := http.NewServeMux()
+	server := httptest.NewServer(mux)
+	defer server.Close()
+
+	// The handler runs on a server goroutine, so it uses assert, not require.
+	mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, http.MethodPost, r.Method)
+		assert.Equal(t, "client_credentials", r.FormValue("grant_type"))
+		assert.Equal(t, "client", r.FormValue("client_id"))
+		assert.Equal(t, "secret", r.FormValue("client_secret"))
+		assert.Equal(t, "catalog", r.FormValue("scope"))
+
+		w.WriteHeader(http.StatusOK)
+		assert.NoError(t, json.NewEncoder(w).Encode(oauthTokenResponse{
+			AccessToken: "fetched_token",
+			TokenType:   "Bearer",
+			ExpiresIn:   3600,
+		}))
+	})
+
+	authURL, err := url.Parse(server.URL + "/oauth/token")
+	require.NoError(t, err)
+
+	manager := &Oauth2AuthManager{
+		Credential: "client:secret",
+		AuthURI:    authURL,
+		Client:     server.Client(),
+	}
+
+	key, value, err := manager.AuthHeader()
+	require.NoError(t, err)
+	assert.Equal(t, "Authorization", key)
+	assert.Equal(t, "Bearer fetched_token", value)
+	assert.Equal(t, "fetched_token", manager.Token)
+
+	// The fetched token is cached: a second call must succeed even though
+	// the token endpoint is no longer reachable.
+	server.Close()
+	_, value, err = manager.AuthHeader()
+	require.NoError(t, err)
+	assert.Equal(t, "Bearer fetched_token", value)
+}
+
+func TestOauth2AuthManager_AuthHeader_FetchToken_ErrorResponse(t *testing.T) {
+ mux := http.NewServeMux()
+ server := httptest.NewServer(mux)
+ defer server.Close()
+
+ mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r
*http.Request) {
+ w.WriteHeader(http.StatusBadRequest)
+ json.NewEncoder(w).Encode(oauthErrorResponse{
+ Err: "invalid_client",
+ ErrDesc: "Invalid client credentials",
+ })
+ })
+
+ authURL, err := url.Parse(server.URL + "/oauth/token")
+ require.NoError(t, err)
+
+ manager := &Oauth2AuthManager{
+ Credential: "client:secret",
+ AuthURI: authURL,
+ Client: server.Client(),
+ }
+
+ _, _, err = manager.AuthHeader()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid_client: Invalid client
credentials")
+}
diff --git a/catalog/rest/options.go b/catalog/rest/options.go
index 3308c510..7f5add71 100644
--- a/catalog/rest/options.go
+++ b/catalog/rest/options.go
@@ -46,6 +46,12 @@ func WithHeaders(headers map[string]string) Option {
}
}
+// WithAuthManager sets a custom AuthManager for the catalog session. Its
+// AuthHeader method is called once when the session is created, and the
+// returned header is added to the session's default headers. When set, it
+// replaces the built-in OAuth2 handling driven by the token, credential,
+// scope and auth URI options. When nil, an Oauth2AuthManager is built from
+// those options instead.
+func WithAuthManager(authManager AuthManager) Option {
+	return func(o *options) {
+		o.authManager = authManager
+	}
+}
+
func WithTLSConfig(config *tls.Config) Option {
return func(o *options) {
o.tlsConfig = config
@@ -127,8 +133,9 @@ type options struct {
awsConfig aws.Config
awsConfigSet bool
tlsConfig *tls.Config
- credential string
oauthToken string
+ credential string
+ authManager AuthManager
warehouseLocation string
metadataLocation string
enableSigv4 bool
diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go
index e0d1c94b..3541d72c 100644
--- a/catalog/rest/rest.go
+++ b/catalog/rest/rest.go
@@ -161,34 +161,6 @@ type createTableRequest struct {
Props iceberg.Properties `json:"properties,omitempty"`
}
-type oauthTokenResponse struct {
- AccessToken string `json:"access_token"`
- TokenType string `json:"token_type"`
- ExpiresIn int `json:"expires_in"`
- Scope string `json:"scope"`
- RefreshToken string `json:"refresh_token"`
-}
-
-type oauthErrorResponse struct {
- Err string `json:"error"`
- ErrDesc string `json:"error_description"`
- ErrURI string `json:"error_uri"`
-}
-
-func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError }
-func (o oauthErrorResponse) Error() string {
- msg := o.Err
- if o.ErrDesc != "" {
- msg += ": " + o.ErrDesc
- }
-
- if o.ErrURI != "" {
- msg += " (" + o.ErrURI + ")"
- }
-
- return msg
-}
-
type configResponse struct {
Defaults iceberg.Properties `json:"defaults"`
Overrides iceberg.Properties `json:"overrides"`
@@ -409,8 +381,6 @@ func handleNon200(rsp *http.Response, override
map[int]error) error {
func fromProps(props iceberg.Properties, o *options) {
for k, v := range props {
switch k {
- case keyOauthToken:
- o.oauthToken = v
case keyWarehouseLocation:
o.warehouseLocation = v
case keyMetadataLocation:
@@ -465,7 +435,6 @@ func toProps(o *options) iceberg.Properties {
}
setIf(keyOauthCredential, o.credential)
- setIf(keyOauthToken, o.oauthToken)
setIf(keyWarehouseLocation, o.warehouseLocation)
setIf(keyMetadataLocation, o.metadataLocation)
if o.enableSigv4 {
@@ -520,6 +489,23 @@ func NewCatalog(ctx context.Context, name, uri string,
opts ...Option) (*Catalog
return r, nil
}
+// setupOAuthManager builds the default Oauth2AuthManager from the catalog
+// options. It is used when the caller did not supply an AuthManager, and lets
+// users authenticate with a static token, a client credential, or neither.
+//
+// When no auth URI was configured, the token endpoint defaults to
+// "oauth/tokens" joined onto the catalog's base URI. cl is used for any
+// token request the manager makes.
+func setupOAuthManager(r *Catalog, cl *http.Client, opts *options) *Oauth2AuthManager {
+	authURI := opts.authUri
+	if authURI == nil {
+		authURI = r.baseURI.JoinPath("oauth/tokens")
+	}
+
+	return &Oauth2AuthManager{
+		Token:      opts.oauthToken,
+		Credential: opts.credential,
+		AuthURI:    authURI,
+		Scope:      opts.scope,
+		Client:     cl,
+	}
+}
+
func (r *Catalog) init(ctx context.Context, ops *options, uri string) error {
baseuri, err := url.Parse(uri)
if err != nil {
@@ -539,62 +525,6 @@ func (r *Catalog) init(ctx context.Context, ops *options,
uri string) error {
return nil
}
-func (r *Catalog) fetchAccessToken(cl *http.Client, creds string, opts
*options) (string, error) {
- clientID, clientSecret, hasID := strings.Cut(creds, ":")
- if !hasID {
- clientID, clientSecret = "", clientID
- }
-
- scope := "catalog"
- if opts.scope != "" {
- scope = opts.scope
- }
- data := url.Values{
- "grant_type": {"client_credentials"},
- "client_id": {clientID},
- "client_secret": {clientSecret},
- "scope": {scope},
- }
-
- uri := opts.authUri
- if uri == nil {
- uri = r.baseURI.JoinPath("oauth/tokens")
- }
-
- rsp, err := cl.PostForm(uri.String(), data)
- if err != nil {
- return "", err
- }
-
- if rsp.StatusCode == http.StatusOK {
- defer rsp.Body.Close()
- dec := json.NewDecoder(rsp.Body)
- var tok oauthTokenResponse
- if err := dec.Decode(&tok); err != nil {
- return "", fmt.Errorf("failed to decode oauth token
response: %w", err)
- }
-
- return tok.AccessToken, nil
- }
-
- switch rsp.StatusCode {
- case http.StatusUnauthorized, http.StatusBadRequest:
- defer func() {
- _, _ = io.Copy(io.Discard, rsp.Body)
- _ = rsp.Body.Close()
- }()
- dec := json.NewDecoder(rsp.Body)
- var oauthErr oauthErrorResponse
- if err := dec.Decode(&oauthErr); err != nil {
- return "", fmt.Errorf("failed to decode oauth error:
%w", err)
- }
-
- return "", oauthErr
- default:
- return "", handleNon200(rsp, nil)
- }
-}
-
func (r *Catalog) createSession(ctx context.Context, opts *options)
(*http.Client, error) {
session := &sessionTransport{
defaultHeaders: http.Header{},
@@ -606,27 +536,26 @@ func (r *Catalog) createSession(ctx context.Context, opts
*options) (*http.Clien
}
cl := &http.Client{Transport: session}
- for k, v := range opts.headers {
- session.defaultHeaders.Set(k, v)
+ // If the user does not set an AuthManager, we can construct an
OAuth2AuthManager based off their options.
+ if opts.authManager == nil {
+ opts.authManager = setupOAuthManager(r, cl, opts)
}
session.defaultHeaders.Set("X-Client-Version", icebergRestSpecVersion)
session.defaultHeaders.Set("Content-Type", "application/json")
session.defaultHeaders.Set("User-Agent", "GoIceberg/"+iceberg.Version())
- if session.defaultHeaders.Get("X-Iceberg-Access-Delegation") == "" {
- session.defaultHeaders.Set("X-Iceberg-Access-Delegation",
"vended-credentials")
- }
+ session.defaultHeaders.Set("X-Iceberg-Access-Delegation",
"vended-credentials")
- token := opts.oauthToken
- if token == "" && opts.credential != "" {
- var err error
- if token, err = r.fetchAccessToken(cl, opts.credential, opts);
err != nil {
- return nil, fmt.Errorf("auth error: %w", err)
- }
+ for k, v := range opts.headers {
+ session.defaultHeaders.Set(k, v)
}
- if token != "" {
- session.defaultHeaders.Set(authorizationHeader, bearerPrefix+"
"+token)
+ if opts.authManager != nil {
+ k, v, err := opts.authManager.AuthHeader()
+ if err != nil {
+ return nil, err
+ }
+ session.defaultHeaders.Set(k, v)
}
if opts.enableSigv4 {