This is an automated email from the ASF dual-hosted git repository. zeroshade pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/iceberg-go.git
The following commit(s) were added to refs/heads/main by this push: new ccdfcbf fix(catalog/rest): Fix concurrency bug in REST catalog request signing (#384) ccdfcbf is described below commit ccdfcbf1d524de9fdaad2ac17b9ef8b52cad3254 Author: Joshua Humphries <2035234+jh...@users.noreply.github.com> AuthorDate: Sat Apr 19 11:53:48 2025 -0400 fix(catalog/rest): Fix concurrency bug in REST catalog request signing (#384) A hasher is not thread-safe, yet the same hasher was being used for all requests. If applications made concurrent calls to the same REST catalog implementation, they could end up writing to the same hasher, corrupting the signatures for both concurrent requests. This makes things thread-safe by creating a hasher for each signing operation. They could be safely re-used using a `sync.Pool`. But the hasher is only 116 bytes, and initialization just has to write 8 bytes (other than the zeroing done by the allocator), so it doesn't seem worth trying to re-use them. --- catalog/rest/options.go | 2 + catalog/rest/rest.go | 53 ++++++++++--------- catalog/rest/rest_internal_test.go | 102 +++++++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+), 27 deletions(-) diff --git a/catalog/rest/options.go b/catalog/rest/options.go index ae1782c..b65854a 100644 --- a/catalog/rest/options.go +++ b/catalog/rest/options.go @@ -92,6 +92,7 @@ func WithPrefix(prefix string) Option { func WithAwsConfig(cfg aws.Config) Option { return func(o *options) { o.awsConfig = cfg + o.awsConfigSet = true } } @@ -109,6 +110,7 @@ func WithAdditionalProps(props iceberg.Properties) Option { type options struct { awsConfig aws.Config + awsConfigSet bool tlsConfig *tls.Config credential string oauthToken string diff --git a/catalog/rest/rest.go b/catalog/rest/rest.go index f6e66e7..e14e8b0 100644 --- a/catalog/rest/rest.go +++ b/catalog/rest/rest.go @@ -22,6 +22,7 @@ import ( "context" "crypto/sha256" "crypto/tls" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -198,7 +199,7 
@@ type sessionTransport struct { signer v4.HTTPSigner cfg aws.Config service string - h hash.Hash + newHash func() hash.Hash } // from https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/signer/v4#Signer.SignHTTP @@ -221,12 +222,12 @@ func (s *sessionTransport) RoundTrip(r *http.Request) (*http.Response, error) { return nil, err } - if _, err = io.Copy(s.h, rdr); err != nil { + h := s.newHash() + if _, err = io.Copy(h, rdr); err != nil { return nil, err } - payloadHash = string(s.h.Sum(nil)) - s.h.Reset() + payloadHash = hex.EncodeToString(h.Sum(nil)) } creds, err := s.cfg.Credentials.Retrieve(r.Context()) @@ -375,10 +376,7 @@ func handleNon200(rsp *http.Response, override map[int]error) error { return e } -func fromProps(props iceberg.Properties) *options { - o := &options{ - additionalProps: iceberg.Properties{}, - } +func fromProps(props iceberg.Properties, o *options) { for k, v := range props { switch k { case keyOauthToken: @@ -415,21 +413,18 @@ func fromProps(props iceberg.Properties) *options { case "uri", "type": default: if v != "" { + if o.additionalProps == nil { + o.additionalProps = iceberg.Properties{} + } o.additionalProps[k] = v } } } - - return o } func toProps(o *options) iceberg.Properties { - var props iceberg.Properties - if o.additionalProps != nil { - props = o.additionalProps - } else { - props = iceberg.Properties{} - } + props := iceberg.Properties{} + maps.Copy(props, o.additionalProps) setIf := func(key, v string) { if v != "" { @@ -464,10 +459,11 @@ type Catalog struct { } func newCatalogFromProps(ctx context.Context, name string, uri string, p iceberg.Properties) (*Catalog, error) { - ops := fromProps(p) + var ops options + fromProps(p, &ops) r := &Catalog{name: name} - if err := r.init(ctx, ops, uri); err != nil { + if err := r.init(ctx, &ops, uri); err != nil { return nil, err } @@ -585,17 +581,21 @@ func (r *Catalog) createSession(ctx context.Context, opts *options) (*http.Clien 
session.defaultHeaders.Set("X-Iceberg-Access-Delegation", "vended-credentials") if opts.enableSigv4 { - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - return nil, err + cfg := opts.awsConfig + if !opts.awsConfigSet { + // If no config provided, load defaults from environment. + var err error + cfg, err = config.LoadDefaultConfig(ctx) + if err != nil { + return nil, err + } } - if opts.sigv4Region != "" { cfg.Region = opts.sigv4Region } session.cfg, session.service = cfg, opts.sigv4Service - session.signer, session.h = v4.NewSigner(), sha256.New() + session.signer, session.newHash = v4.NewSigner(), sha256.New } return cl, nil @@ -627,9 +627,8 @@ func (r *Catalog) fetchConfig(ctx context.Context, opts *options) (*http.Client, maps.Copy(cfg, toProps(opts)) maps.Copy(cfg, rsp.Overrides) - o := fromProps(cfg) - o.awsConfig = opts.awsConfig - o.tlsConfig = opts.tlsConfig + o := *opts + fromProps(cfg, &o) if uri, ok := cfg["uri"]; ok { r.baseURI, err = url.Parse(uri) @@ -639,7 +638,7 @@ func (r *Catalog) fetchConfig(ctx context.Context, opts *options) (*http.Client, r.baseURI = r.baseURI.JoinPath("v1") } - return sess, o, nil + return sess, &o, nil } func (r *Catalog) Name() string { return r.name } diff --git a/catalog/rest/rest_internal_test.go b/catalog/rest/rest_internal_test.go index af3517e..4595c57 100644 --- a/catalog/rest/rest_internal_test.go +++ b/catalog/rest/rest_internal_test.go @@ -18,18 +18,32 @@ package rest import ( + "bytes" "context" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" "encoding/json" + "io" "net/http" "net/http/httptest" "net/url" + "sync/atomic" "testing" + "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) func TestAuthHeader(t *testing.T) { + t.Parallel() mux := http.NewServeMux() srv := 
httptest.NewServer(mux) @@ -77,6 +91,7 @@ func TestAuthHeader(t *testing.T) { } func TestAuthUriHeader(t *testing.T) { + t.Parallel() mux := http.NewServeMux() srv := httptest.NewServer(mux) @@ -124,3 +139,90 @@ func TestAuthUriHeader(t *testing.T) { "X-Iceberg-Access-Delegation": {"vended-credentials"}, }, cat.cl.Transport.(*sessionTransport).defaultHeaders) } + +func TestSigv4EmptyStringHash(t *testing.T) { + t.Parallel() + hash := sha256.New() + payloadHash := hex.EncodeToString(hash.Sum(nil)) + // Sanity check the constant. + require.Equal(t, payloadHash, emptyStringHash) +} + +func TestSigv4ConcurrentSigners(t *testing.T) { + t.Parallel() + mux := http.NewServeMux() + srv := httptest.NewUnstartedServer(mux) + // If we use HTTP 1.1, this test can try to make too many connections + // and exhaust ephemeral ports. + srv.EnableHTTP2 = true + srv.StartTLS() // Using TLS to easily support HTTP/2 + rootCAs := x509.NewCertPool() + rootCAs.AddCert(srv.Certificate()) + + mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "defaults": map[string]any{}, "overrides": map[string]any{}, + }) + }) + + cfg, err := config.LoadDefaultConfig(context.Background(), func(opts *config.LoadOptions) error { + opts.Credentials = credentials.StaticCredentialsProvider{ + Value: aws.Credentials{ + AccessKeyID: "abcdefghjklmnop", + SecretAccessKey: "01234567abcdefgh01234567abcdefgh01234567abcdefgh01234567abcdefgh", + }, + } + + return nil + }) + require.NoError(t, err) + + cat, err := NewCatalog(context.Background(), "rest", srv.URL, + WithSigV4(), + WithSigV4RegionSvc("abc", "def"), + WithAwsConfig(cfg), + WithTLSConfig(&tls.Config{ + RootCAs: rootCAs, + })) + require.NoError(t, err) + assert.NotNil(t, cat) + + // We aren't recreating the signature logic to verify on the server. 
We're + // just running many concurrent requests to make sure the race detector + // doesn't find any data races with how the session transport and signer + // are used from concurrent goroutines. + ctx, cancel := context.WithCancel(context.Background()) + grp, ctx := errgroup.WithContext(ctx) + var count atomic.Uint64 + for range 10 { + grp.Go(func() error { + for { + if err := ctx.Err(); err != nil { + return nil + } + body := make([]byte, 1024) + if _, err := rand.Read(body); err != nil { + return err + } + // Intentionally using context.Background instead of ctx so that we + // don't get interrupted when context is cancelled. + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, srv.URL, bytes.NewReader(body)) + if err != nil { + return err + } + resp, err := cat.cl.Do(req) + if err != nil { + return err + } + // We don't actually care about the response, only that it actually made it to the server. + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + count.Add(1) + } + }) + } + time.Sleep(5 * time.Second) + cancel() + require.NoError(t, grp.Wait()) + t.Logf("issued %d requests", count.Load()) +}