This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-go.git


The following commit(s) were added to refs/heads/main by this push:
     new ccdfcbf  fix(catalog/rest): Fix concurrency bug in REST catalog 
request signing (#384)
ccdfcbf is described below

commit ccdfcbf1d524de9fdaad2ac17b9ef8b52cad3254
Author: Joshua Humphries <2035234+jh...@users.noreply.github.com>
AuthorDate: Sat Apr 19 11:53:48 2025 -0400

    fix(catalog/rest): Fix concurrency bug in REST catalog request signing 
(#384)
    
    A hasher is not thread-safe, yet the same hasher was being used for all
    requests. If applications made concurrent calls to the same REST catalog
    implementation, they could end up writing to the same hasher, corrupting
    the signatures for both concurrent requests.
    
    This makes things thread-safe by creating a hasher for each signing
    operation. They could be safely re-used using a `sync.Pool`. But the
    hasher is only 116 bytes, and initialization just has to write 8 bytes
    (other than the zero'ing done by the allocator), so it doesn't seem
    worth trying to re-use them.
---
 catalog/rest/options.go            |   2 +
 catalog/rest/rest.go               |  53 ++++++++++---------
 catalog/rest/rest_internal_test.go | 102 +++++++++++++++++++++++++++++++++++++
 3 files changed, 130 insertions(+), 27 deletions(-)

diff --git a/catalog/rest/options.go b/catalog/rest/options.go
index ae1782c..b65854a 100644
--- a/catalog/rest/options.go
+++ b/catalog/rest/options.go
@@ -92,6 +92,7 @@ func WithPrefix(prefix string) Option {
 func WithAwsConfig(cfg aws.Config) Option {
        return func(o *options) {
                o.awsConfig = cfg
+               o.awsConfigSet = true
        }
 }
 
@@ -109,6 +110,7 @@ func WithAdditionalProps(props iceberg.Properties) Option {
 
 type options struct {
        awsConfig         aws.Config
+       awsConfigSet      bool
        tlsConfig         *tls.Config
        credential        string
        oauthToken        string
diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go
index f6e66e7..e14e8b0 100644
--- a/catalog/rest/rest.go
+++ b/catalog/rest/rest.go
@@ -22,6 +22,7 @@ import (
        "context"
        "crypto/sha256"
        "crypto/tls"
+       "encoding/hex"
        "encoding/json"
        "errors"
        "fmt"
@@ -198,7 +199,7 @@ type sessionTransport struct {
        signer         v4.HTTPSigner
        cfg            aws.Config
        service        string
-       h              hash.Hash
+       newHash        func() hash.Hash
 }
 
 // from 
https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/signer/v4#Signer.SignHTTP
@@ -221,12 +222,12 @@ func (s *sessionTransport) RoundTrip(r *http.Request) 
(*http.Response, error) {
                                return nil, err
                        }
 
-                       if _, err = io.Copy(s.h, rdr); err != nil {
+                       h := s.newHash()
+                       if _, err = io.Copy(h, rdr); err != nil {
                                return nil, err
                        }
 
-                       payloadHash = string(s.h.Sum(nil))
-                       s.h.Reset()
+                       payloadHash = hex.EncodeToString(h.Sum(nil))
                }
 
                creds, err := s.cfg.Credentials.Retrieve(r.Context())
@@ -375,10 +376,7 @@ func handleNon200(rsp *http.Response, override 
map[int]error) error {
        return e
 }
 
-func fromProps(props iceberg.Properties) *options {
-       o := &options{
-               additionalProps: iceberg.Properties{},
-       }
+func fromProps(props iceberg.Properties, o *options) {
        for k, v := range props {
                switch k {
                case keyOauthToken:
@@ -415,21 +413,18 @@ func fromProps(props iceberg.Properties) *options {
                case "uri", "type":
                default:
                        if v != "" {
+                               if o.additionalProps == nil {
+                                       o.additionalProps = iceberg.Properties{}
+                               }
                                o.additionalProps[k] = v
                        }
                }
        }
-
-       return o
 }
 
 func toProps(o *options) iceberg.Properties {
-       var props iceberg.Properties
-       if o.additionalProps != nil {
-               props = o.additionalProps
-       } else {
-               props = iceberg.Properties{}
-       }
+       props := iceberg.Properties{}
+       maps.Copy(props, o.additionalProps)
 
        setIf := func(key, v string) {
                if v != "" {
@@ -464,10 +459,11 @@ type Catalog struct {
 }
 
 func newCatalogFromProps(ctx context.Context, name string, uri string, p 
iceberg.Properties) (*Catalog, error) {
-       ops := fromProps(p)
+       var ops options
+       fromProps(p, &ops)
 
        r := &Catalog{name: name}
-       if err := r.init(ctx, ops, uri); err != nil {
+       if err := r.init(ctx, &ops, uri); err != nil {
                return nil, err
        }
 
@@ -585,17 +581,21 @@ func (r *Catalog) createSession(ctx context.Context, opts 
*options) (*http.Clien
        session.defaultHeaders.Set("X-Iceberg-Access-Delegation", 
"vended-credentials")
 
        if opts.enableSigv4 {
-               cfg, err := config.LoadDefaultConfig(ctx)
-               if err != nil {
-                       return nil, err
+               cfg := opts.awsConfig
+               if !opts.awsConfigSet {
+                       // If no config provided, load defaults from 
environment.
+                       var err error
+                       cfg, err = config.LoadDefaultConfig(ctx)
+                       if err != nil {
+                               return nil, err
+                       }
                }
-
                if opts.sigv4Region != "" {
                        cfg.Region = opts.sigv4Region
                }
 
                session.cfg, session.service = cfg, opts.sigv4Service
-               session.signer, session.h = v4.NewSigner(), sha256.New()
+               session.signer, session.newHash = v4.NewSigner(), sha256.New
        }
 
        return cl, nil
@@ -627,9 +627,8 @@ func (r *Catalog) fetchConfig(ctx context.Context, opts 
*options) (*http.Client,
        maps.Copy(cfg, toProps(opts))
        maps.Copy(cfg, rsp.Overrides)
 
-       o := fromProps(cfg)
-       o.awsConfig = opts.awsConfig
-       o.tlsConfig = opts.tlsConfig
+       o := *opts
+       fromProps(cfg, &o)
 
        if uri, ok := cfg["uri"]; ok {
                r.baseURI, err = url.Parse(uri)
@@ -639,7 +638,7 @@ func (r *Catalog) fetchConfig(ctx context.Context, opts 
*options) (*http.Client,
                r.baseURI = r.baseURI.JoinPath("v1")
        }
 
-       return sess, o, nil
+       return sess, &o, nil
 }
 
 func (r *Catalog) Name() string              { return r.name }
diff --git a/catalog/rest/rest_internal_test.go 
b/catalog/rest/rest_internal_test.go
index af3517e..4595c57 100644
--- a/catalog/rest/rest_internal_test.go
+++ b/catalog/rest/rest_internal_test.go
@@ -18,18 +18,32 @@
 package rest
 
 import (
+       "bytes"
        "context"
+       "crypto/rand"
+       "crypto/sha256"
+       "crypto/tls"
+       "crypto/x509"
+       "encoding/hex"
        "encoding/json"
+       "io"
        "net/http"
        "net/http/httptest"
        "net/url"
+       "sync/atomic"
        "testing"
+       "time"
 
+       "github.com/aws/aws-sdk-go-v2/aws"
+       "github.com/aws/aws-sdk-go-v2/config"
+       "github.com/aws/aws-sdk-go-v2/credentials"
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
+       "golang.org/x/sync/errgroup"
 )
 
 func TestAuthHeader(t *testing.T) {
+       t.Parallel()
        mux := http.NewServeMux()
        srv := httptest.NewServer(mux)
 
@@ -77,6 +91,7 @@ func TestAuthHeader(t *testing.T) {
 }
 
 func TestAuthUriHeader(t *testing.T) {
+       t.Parallel()
        mux := http.NewServeMux()
        srv := httptest.NewServer(mux)
 
@@ -124,3 +139,90 @@ func TestAuthUriHeader(t *testing.T) {
                "X-Iceberg-Access-Delegation": {"vended-credentials"},
        }, cat.cl.Transport.(*sessionTransport).defaultHeaders)
 }
+
+func TestSigv4EmptyStringHash(t *testing.T) {
+       t.Parallel()
+       hash := sha256.New()
+       payloadHash := hex.EncodeToString(hash.Sum(nil))
+       // Sanity check the constant.
+       require.Equal(t, payloadHash, emptyStringHash)
+}
+
+func TestSigv4ConcurrentSigners(t *testing.T) {
+       t.Parallel()
+       mux := http.NewServeMux()
+       srv := httptest.NewUnstartedServer(mux)
+       // If we use HTTP 1.1, this test can try to make too many connections
+       // and exhaust ephemeral ports.
+       srv.EnableHTTP2 = true
+       srv.StartTLS() // Using TLS to easily support HTTP/2
+       rootCAs := x509.NewCertPool()
+       rootCAs.AddCert(srv.Certificate())
+
+       mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r 
*http.Request) {
+               json.NewEncoder(w).Encode(map[string]any{
+                       "defaults": map[string]any{}, "overrides": 
map[string]any{},
+               })
+       })
+
+       cfg, err := config.LoadDefaultConfig(context.Background(), func(opts 
*config.LoadOptions) error {
+               opts.Credentials = credentials.StaticCredentialsProvider{
+                       Value: aws.Credentials{
+                               AccessKeyID:     "abcdefghjklmnop",
+                               SecretAccessKey: 
"01234567abcdefgh01234567abcdefgh01234567abcdefgh01234567abcdefgh",
+                       },
+               }
+
+               return nil
+       })
+       require.NoError(t, err)
+
+       cat, err := NewCatalog(context.Background(), "rest", srv.URL,
+               WithSigV4(),
+               WithSigV4RegionSvc("abc", "def"),
+               WithAwsConfig(cfg),
+               WithTLSConfig(&tls.Config{
+                       RootCAs: rootCAs,
+               }))
+       require.NoError(t, err)
+       assert.NotNil(t, cat)
+
+       // We aren't recreating the signature logic to verify on the server. 
We're
+       // just running many concurrent requests to make sure the race detector
+       // doesn't find any data races with how the session transport and signer
+       // are used from concurrent goroutines.
+       ctx, cancel := context.WithCancel(context.Background())
+       grp, ctx := errgroup.WithContext(ctx)
+       var count atomic.Uint64
+       for range 10 {
+               grp.Go(func() error {
+                       for {
+                               if err := ctx.Err(); err != nil {
+                                       return nil
+                               }
+                               body := make([]byte, 1024)
+                               if _, err := rand.Read(body); err != nil {
+                                       return err
+                               }
+                               // Intentionally using context.Background 
instead of ctx so that we
+                               // don't get interrupted when context is 
cancelled.
+                               req, err := 
http.NewRequestWithContext(context.Background(), http.MethodPost, srv.URL, 
bytes.NewReader(body))
+                               if err != nil {
+                                       return err
+                               }
+                               resp, err := cat.cl.Do(req)
+                               if err != nil {
+                                       return err
+                               }
+                               // We don't actually care about the response, 
only that it actually made it to the server.
+                               _, _ = io.Copy(io.Discard, resp.Body)
+                               _ = resp.Body.Close()
+                               count.Add(1)
+                       }
+               })
+       }
+       time.Sleep(5 * time.Second)
+       cancel()
+       require.NoError(t, grp.Wait())
+       t.Logf("issued %d requests", count.Load())
+}

Reply via email to