This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch go-http-fix in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 7da0e97fa2faa5000e983f53e3753840e488eb12 Author: Yang Xia <[email protected]> AuthorDate: Tue Mar 24 22:45:32 2026 -0700 Update order of interceptor and serialization of request --- gremlin-go/driver/auth.go | 17 +- gremlin-go/driver/auth_test.go | 26 +-- gremlin-go/driver/connection.go | 121 +++++------ gremlin-go/driver/connection_test.go | 33 +++ gremlin-go/driver/interceptor.go | 113 ++++++++++ gremlin-go/driver/interceptor_test.go | 380 ++++++++++++++++++++++++++++++++++ 6 files changed, 602 insertions(+), 88 deletions(-) diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go index 74ca43a444..2f6f8b9a4e 100644 --- a/gremlin-go/driver/auth.go +++ b/gremlin-go/driver/auth.go @@ -22,6 +22,7 @@ package gremlingo import ( "context" "encoding/base64" + "fmt" "sync" "time" @@ -39,17 +40,17 @@ func BasicAuth(username, password string) RequestInterceptor { } } -// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4. +// SigV4Auth returns a RequestInterceptor that signs requests using AWS SigV4. // It uses the default AWS credential chain (env vars, shared config, IAM role, etc.) -func Sigv4Auth(region, service string) RequestInterceptor { - return Sigv4AuthWithCredentials(region, service, nil) +func SigV4Auth(region, service string) RequestInterceptor { + return SigV4AuthWithCredentials(region, service, nil) } -// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4 +// SigV4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4 // with the provided credentials provider. If provider is nil, uses default credential chain. // // Caches the signer and credentials provider for efficiency. -func Sigv4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor { +func SigV4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor { // Create signer once - it's stateless and safe to reuse signer := v4.NewSigner() @@ -59,6 +60,12 @@ func Sigv4AuthWithCredentials(region, service string, credentialsProvider aws.Cr var providerErr error return func(req *HttpRequest) error { + // SigV4 requires serialized body bytes to compute the payload hash. + if _, ok := req.Body.([]byte); !ok { + return fmt.Errorf("SigV4 signing requires serialized body bytes ([]byte); got %T. "+ + "Place SigV4Auth after serialization in the interceptor chain", req.Body) + } + ctx := context.Background() // Resolve credentials provider once if not provided diff --git a/gremlin-go/driver/auth_test.go b/gremlin-go/driver/auth_test.go index ba60f6e6c5..7ec4079b6e 100644 --- a/gremlin-go/driver/auth_test.go +++ b/gremlin-go/driver/auth_test.go @@ -30,7 +30,7 @@ import ( ) func createMockRequest() *HttpRequest { - req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin") + req, _ := NewHttpRequest("POST", "https://test_url:8182/gremlin") req.Headers.Set("Content-Type", graphBinaryMimeType) req.Headers.Set("Accept", graphBinaryMimeType) req.Body = []byte(`{"gremlin":"g.V()"}`) @@ -72,24 +72,24 @@ func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials }, nil } -func TestSigv4Auth(t *testing.T) { +func TestSigV4Auth(t *testing.T) { t.Run("adds signed headers", func(t *testing.T) { req := createMockRequest() assert.Empty(t, req.Headers.Get("Authorization")) assert.Empty(t, req.Headers.Get("X-Amz-Date")) provider := &mockCredentialsProvider{ - accessKey: "MOCK_ACCESS_KEY", - secretKey: "MOCK_SECRET_KEY", + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", } - interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider) + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) err := interceptor(req) assert.NoError(t, err) assert.NotEmpty(t, req.Headers.Get("X-Amz-Date")) authHeader := req.Headers.Get("Authorization") - assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY")) - assert.Contains(t, authHeader, "us-west-2/neptune-db/aws4_request") + assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ID")) + assert.Contains(t, authHeader, "gremlin-east-1/tinkerpop-sigv4/aws4_request") assert.Contains(t, authHeader, "Signature=") }) @@ -98,17 +98,17 @@ func TestSigv4Auth(t *testing.T) { assert.Empty(t, req.Headers.Get("X-Amz-Security-Token")) provider := &mockCredentialsProvider{ - accessKey: "MOCK_ACCESS_KEY", - secretKey: "MOCK_SECRET_KEY", - sessionToken: "MOCK_SESSION_TOKEN", + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + sessionToken: "MOCK_TOKEN", } - interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider) + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) err := interceptor(req) assert.NoError(t, err) - assert.Equal(t, "MOCK_SESSION_TOKEN", req.Headers.Get("X-Amz-Security-Token")) + assert.Equal(t, "MOCK_TOKEN", req.Headers.Get("X-Amz-Security-Token")) authHeader := req.Headers.Get("Authorization") assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=")) - assert.Contains(t, authHeader, "Signature=") + assert.Contains(t, authHeader, "gremlin-east-1/tinkerpop-sigv4/aws4_request") }) } diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index 5086965a8d..562bad80a3 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -32,59 +32,6 @@ import ( "time" ) -// Common HTTP header keys -const ( - HeaderContentType = "Content-Type" - HeaderAccept = "Accept" - HeaderUserAgent = "User-Agent" - HeaderAcceptEncoding = "Accept-Encoding" - HeaderAuthorization = "Authorization" -) - -// HttpRequest represents an HTTP request that can be modified by interceptors. -type HttpRequest struct { - Method string - URL *url.URL - Headers http.Header - Body []byte -} - -// NewHttpRequest creates a new HttpRequest with the given method and URL. -func NewHttpRequest(method, rawURL string) (*HttpRequest, error) { - u, err := url.Parse(rawURL) - if err != nil { - return nil, err - } - return &HttpRequest{ - Method: method, - URL: u, - Headers: make(http.Header), - }, nil -} - -// ToStdRequest converts HttpRequest to a standard http.Request for signing. -// Returns nil if the request cannot be created (invalid method or URL). -func (r *HttpRequest) ToStdRequest() (*http.Request, error) { - req, err := http.NewRequest(r.Method, r.URL.String(), bytes.NewReader(r.Body)) - if err != nil { - return nil, err - } - req.Header = r.Headers - return req, nil -} - -// PayloadHash returns the SHA256 hash of the request body for SigV4 signing. -func (r *HttpRequest) PayloadHash() string { - if len(r.Body) == 0 { - return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string - } - h := sha256.Sum256(r.Body) - return hex.EncodeToString(h[:]) -} - -// RequestInterceptor is a function that modifies an HTTP request before it is sent. -type RequestInterceptor func(*HttpRequest) error - // connectionSettings holds configuration for the connection. type connectionSettings struct { tlsConfig *tls.Config @@ -174,18 +121,12 @@ func (c *connection) AddInterceptor(interceptor RequestInterceptor) { func (c *connection) submit(req *request) (ResultSet, error) { rs := newChannelResultSet() - data, err := c.serializer.SerializeMessage(req) - if err != nil { - rs.Close() - return rs, err - } - - go c.executeAndStream(data, rs) + go c.executeAndStream(req, rs) return rs, nil } -func (c *connection) executeAndStream(data []byte, rs ResultSet) { +func (c *connection) executeAndStream(req *request, rs ResultSet) { defer rs.Close() // Create HttpRequest for interceptors @@ -195,12 +136,15 @@ func (c *connection) executeAndStream(data []byte, rs ResultSet) { rs.setError(err) return } - httpReq.Body = data // Set default headers before interceptors c.setHttpRequestHeaders(httpReq) - // Apply interceptors + // Set Body to the raw *request so interceptors can inspect/modify it + httpReq.Body = req + + // Apply interceptors — they see *request in Body (pre-serialization). + // Interceptors may replace Body with []byte, io.Reader, or *http.Request. for _, interceptor := range c.interceptors { if err := interceptor(httpReq); err != nil { c.logHandler.logf(Error, failedToSendRequest, err.Error()) @@ -209,16 +153,53 @@ func (c *connection) executeAndStream(data []byte, rs ResultSet) { } } - // Create actual http.Request from HttpRequest - req, err := http.NewRequest(httpReq.Method, httpReq.URL.String(), bytes.NewReader(httpReq.Body)) - if err != nil { - c.logHandler.logf(Error, failedToSendRequest, err.Error()) - rs.setError(err) + // After interceptors, serialize if Body is still *request + if r, ok := httpReq.Body.(*request); ok { + if c.serializer != nil { + data, err := c.serializer.SerializeMessage(r) + if err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + httpReq.Body = data + } else { + errMsg := "request body was not serialized; either provide a serializer or add an interceptor that serializes the request" + c.logHandler.logf(Error, failedToSendRequest, errMsg) + rs.setError(fmt.Errorf("%s", errMsg)) + return + } + } + + // Create actual http.Request from HttpRequest based on Body type + var httpGoReq *http.Request + switch body := httpReq.Body.(type) { + case []byte: + httpGoReq, err = http.NewRequest(httpReq.Method, httpReq.URL.String(), bytes.NewReader(body)) + if err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + httpGoReq.Header = httpReq.Headers + case io.Reader: + httpGoReq, err = http.NewRequest(httpReq.Method, httpReq.URL.String(), body) + if err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + httpGoReq.Header = httpReq.Headers + case *http.Request: + httpGoReq = body + default: + errMsg := fmt.Sprintf("unsupported body type after interceptors: %T", body) + c.logHandler.logf(Error, failedToSendRequest, errMsg) + rs.setError(fmt.Errorf("%s", errMsg)) return } - req.Header = httpReq.Headers - resp, err := c.httpClient.Do(req) + resp, err := c.httpClient.Do(httpGoReq) if err != nil { c.logHandler.logf(Error, failedToSendRequest, err.Error()) rs.setError(err) diff --git a/gremlin-go/driver/connection_test.go b/gremlin-go/driver/connection_test.go index a0e8414223..4bb2de5cce 100644 --- a/gremlin-go/driver/connection_test.go +++ b/gremlin-go/driver/connection_test.go @@ -1261,3 +1261,36 @@ func TestDriverRemoteConnectionSettingsWiring(t *testing.T) { assert.Equal(t, 180*time.Second, transport.IdleConnTimeout) }) } + +// TestConnectionWithMockServer_BasicAuth verifies that BasicAuth interceptor sets the correct +// Authorization header and the body is still valid serialized bytes. +func TestConnectionWithMockServer_BasicAuth(t *testing.T) { + var capturedAuthHeader string + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuthHeader = r.Header.Get("Authorization") + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + conn.AddInterceptor(BasicAuth("testuser", "testpass")) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain + + // BasicAuth should set Authorization header with base64("testuser:testpass") = "dGVzdHVzZXI6dGVzdHBhc3M=" + assert.Equal(t, "Basic dGVzdHVzZXI6dGVzdHBhc3M=", capturedAuthHeader, + "Authorization header should be Basic base64(testuser:testpass)") + + // Body should still be valid serialized bytes + assert.NotEmpty(t, capturedBody, "serialized body should be non-empty with BasicAuth") + assert.Equal(t, byte(0x81), capturedBody[0], + "body should start with GraphBinary version byte 0x81") +} diff --git a/gremlin-go/driver/interceptor.go b/gremlin-go/driver/interceptor.go new file mode 100644 index 0000000000..a5d63a31be --- /dev/null +++ b/gremlin-go/driver/interceptor.go @@ -0,0 +1,113 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "io" + "net/http" + "net/url" +) + +// Common HTTP header keys +const ( + HeaderContentType = "Content-Type" + HeaderAccept = "Accept" + HeaderUserAgent = "User-Agent" + HeaderAcceptEncoding = "Accept-Encoding" + HeaderAuthorization = "Authorization" +) + +// HttpRequest represents an HTTP request that can be modified by interceptors. +type HttpRequest struct { + Method string + URL *url.URL + Headers http.Header + Body any +} + +// NewHttpRequest creates a new HttpRequest with the given method and URL. +func NewHttpRequest(method, rawURL string) (*HttpRequest, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + return &HttpRequest{ + Method: method, + URL: u, + Headers: make(http.Header), + }, nil +} + +// ToStdRequest converts HttpRequest to a standard http.Request for signing. +// Returns nil if the request cannot be created (invalid method or URL). +func (r *HttpRequest) ToStdRequest() (*http.Request, error) { + var body io.Reader + switch b := r.Body.(type) { + case []byte: + body = bytes.NewReader(b) + default: + body = http.NoBody + } + req, err := http.NewRequest(r.Method, r.URL.String(), body) + if err != nil { + return nil, err + } + req.Header = r.Headers + return req, nil +} + +// PayloadHash returns the SHA256 hash of the request body for SigV4 signing. +func (r *HttpRequest) PayloadHash() string { + switch b := r.Body.(type) { + case []byte: + if len(b) == 0 { + return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + } + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) + default: + return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + } +} + +// RequestInterceptor is a function that modifies an HTTP request before it is sent. +type RequestInterceptor func(*HttpRequest) error + +// SerializeRequest returns a RequestInterceptor that serializes the raw *request body +// to GraphBinary []byte. Place this before auth interceptors (e.g., SigV4Auth) that +// need the serialized body bytes. +func SerializeRequest() RequestInterceptor { + serializer := newGraphBinarySerializer(nil) + return func(req *HttpRequest) error { + r, ok := req.Body.(*request) + if !ok { + return nil // already serialized or not a *request + } + data, err := serializer.SerializeMessage(r) + if err != nil { + return err + } + req.Body = data + return nil + } +} diff --git a/gremlin-go/driver/interceptor_test.go b/gremlin-go/driver/interceptor_test.go new file mode 100644 index 0000000000..78e36a0b95 --- /dev/null +++ b/gremlin-go/driver/interceptor_test.go @@ -0,0 +1,380 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestInterceptorReceivesRawRequest verifies that interceptors receive the raw *request +// object in HttpRequest.Body, not serialized []byte. +func TestInterceptorReceivesRawRequest(t *testing.T) { + // Mock server that accepts the request (we don't care about the response for this test) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create connection with non-nil serializer (default behavior of newConnection) + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + var capturedBodyType reflect.Type + var capturedBody interface{} + + conn.AddInterceptor(func(req *HttpRequest) error { + capturedBody = req.Body + capturedBodyType = reflect.TypeOf(req.Body) + return nil + }) + + // Submit a request with a known gremlin query + rs, err := conn.submit(&request{gremlin: "g.V().count()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain result set + + assert.Equal(t, reflect.TypeOf((*request)(nil)), capturedBodyType, + "interceptor should receive *request in Body, got %v", capturedBodyType) + + r, typeAssertOk := capturedBody.(*request) + assert.True(t, typeAssertOk, "interceptor should be able to type-assert Body to *request") + if typeAssertOk { + assert.Equal(t, "g.V().count()", r.gremlin, + "interceptor should be able to read the gremlin field from the raw request") + } +} + +// TestSigV4AuthWithSerializeInterceptor verifies that SerializeRequest() + SigV4Auth +// works in a chain. SerializeRequest converts *request to []byte, then SigV4Auth +// can sign the serialized body. +func TestSigV4AuthWithSerializeInterceptor(t *testing.T) { + var capturedHeaders http.Header + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + mockProvider := &mockCredentialsProvider{ + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + } + + conn.AddInterceptor(SerializeRequest()) + conn.AddInterceptor(SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", mockProvider)) + + rs, err := conn.submit(&request{gremlin: "g.V().count()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain + + // SigV4 should have added Authorization and X-Amz-Date headers + assert.NotEmpty(t, capturedHeaders.Get("Authorization"), + "SigV4Auth should set Authorization header after SerializeRequest") + assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"), + "SigV4Auth should set X-Amz-Date header") + assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256", + "Authorization header should use AWS4-HMAC-SHA256 signing algorithm") + + // Body should be valid serialized bytes + assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") + assert.Equal(t, byte(0x81), capturedBody[0], + "body should start with GraphBinary version byte 0x81") +} + +// TestMultipleInterceptors_SerializeThenAuth verifies that a custom interceptor can +// modify the raw request, then SerializeRequest serializes it, then BasicAuth adds headers. +func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { + var capturedAuthHeader string + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuthHeader = r.Header.Get("Authorization") + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + // Custom interceptor that modifies the raw request fields + conn.AddInterceptor(func(req *HttpRequest) error { + r, ok := req.Body.(*request) + if !ok { + return fmt.Errorf("expected *request, got %T", req.Body) + } + // Add a custom field to the request + r.fields["customField"] = "customValue" + return nil + }) + + // SerializeRequest converts the modified *request to []byte + conn.AddInterceptor(SerializeRequest()) + + // BasicAuth adds the Authorization header (works on any body type) + conn.AddInterceptor(BasicAuth("admin", "secret")) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain + + // BasicAuth should have set the Authorization header + assert.Equal(t, "Basic YWRtaW46c2VjcmV0", capturedAuthHeader, + "Authorization header should be Basic base64(admin:secret)") + + // Body should be valid serialized bytes (from SerializeRequest) + assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") + assert.Equal(t, byte(0x81), capturedBody[0], + "body should start with GraphBinary version byte 0x81") +} + +// TestInterceptor_IoReaderBody verifies that an interceptor can set Body to an io.Reader +// and the request is sent correctly. +func TestInterceptor_IoReaderBody(t *testing.T) { + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + customPayload := []byte("custom binary payload") + + // Interceptor replaces Body with an io.Reader + conn.AddInterceptor(func(req *HttpRequest) error { + req.Body = bytes.NewReader(customPayload) + return nil + }) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain + + // The server should receive the custom payload from the io.Reader + assert.Equal(t, customPayload, capturedBody, + "server should receive the custom payload set via io.Reader") +} + +// TestInterceptor_NilSerializerNoSerialization verifies that when serializer is nil +// and no interceptor serializes, the correct error message is produced. +func TestInterceptor_NilSerializerNoSerialization(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + conn.serializer = nil // explicitly nil serializer + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + + _, _ = rs.All() // drain — this triggers the async executeAndStream + rsErr := rs.GetError() + require.Error(t, rsErr, "should get an error when serializer is nil and no interceptor serializes") + assert.Contains(t, rsErr.Error(), "request body was not serialized", + "error message should indicate the body was not serialized") +} + +// TestInterceptor_HttpRequestBody verifies that an interceptor can set Body to *http.Request +// and the driver sends it directly, using the *http.Request's headers and body instead of +// HttpRequest.Headers. +func TestInterceptor_HttpRequestBody(t *testing.T) { + var capturedHeaders http.Header + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header.Clone() + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + customBody := []byte("custom-http-request-body") + + // Interceptor builds a complete *http.Request and sets it as Body + conn.AddInterceptor(func(req *HttpRequest) error { + httpGoReq, err := http.NewRequest(http.MethodPost, req.URL.String(), bytes.NewReader(customBody)) + if err != nil { + return err + } + httpGoReq.Header.Set("X-Custom-Header", "custom-value") + httpGoReq.Header.Set("Content-Type", "application/octet-stream") + req.Body = httpGoReq + return nil + }) + + // Also set a header on HttpRequest.Headers that should NOT appear, + // because *http.Request body bypasses HttpRequest.Headers + conn.AddInterceptor(func(req *HttpRequest) error { + req.Headers.Set("X-Should-Not-Appear", "ignored") + return nil + }) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain + + // The server should receive headers from the *http.Request, not from HttpRequest.Headers + assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header"), + "server should receive custom header from *http.Request") + assert.Equal(t, "application/octet-stream", capturedHeaders.Get("Content-Type"), + "server should receive Content-Type from *http.Request") + assert.Empty(t, capturedHeaders.Get("X-Should-Not-Appear"), + "headers set on HttpRequest.Headers should not appear when Body is *http.Request") + + // The server should receive the body from the *http.Request + assert.Equal(t, customBody, capturedBody, + "server should receive body from the *http.Request") +} + +// TestInterceptor_ErrorPropagation verifies that when an interceptor returns an error, +// it is propagated to the ResultSet. +func TestInterceptor_ErrorPropagation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + conn.AddInterceptor(func(req *HttpRequest) error { + return fmt.Errorf("interceptor failed") + }) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + + _, _ = rs.All() // drain — triggers async executeAndStream + rsErr := rs.GetError() + require.Error(t, rsErr, "interceptor error should propagate to ResultSet") + assert.Contains(t, rsErr.Error(), "interceptor failed", + "ResultSet error should contain the interceptor's error message") +} + +// TestInterceptor_UnsupportedBodyType verifies that setting Body to an unsupported type +// (e.g., an int) produces the "unsupported body type" error. +func TestInterceptor_UnsupportedBodyType(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + // Interceptor sets Body to an unsupported type + conn.AddInterceptor(func(req *HttpRequest) error { + req.Body = 42 + return nil + }) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + + _, _ = rs.All() // drain + rsErr := rs.GetError() + require.Error(t, rsErr, "unsupported body type should produce an error") + assert.Contains(t, rsErr.Error(), "unsupported body type", + "error message should indicate unsupported body type") +} + +// TestInterceptor_ChainOrder verifies that interceptors run in the order they are added. +func TestInterceptor_ChainOrder(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + var order []int + + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 1) + return nil + }) + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 2) + return nil + }) + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 3) + return nil + }) + + rs, err := conn.submit(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() // drain + + assert.Equal(t, []int{1, 2, 3}, order, + "interceptors should run in the order they were added") +} + +// TestSigV4Auth_RejectsNonByteBody verifies that SigV4Auth returns an error when Body +// is not []byte (e.g., an unserialized *request). +func TestSigV4Auth_RejectsNonByteBody(t *testing.T) { + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + } + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) + + req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") + require.NoError(t, err) + req.Headers.Set("Content-Type", graphBinaryMimeType) + req.Headers.Set("Accept", graphBinaryMimeType) + + // Set Body to *request (not []byte) — SigV4Auth should reject this + req.Body = &request{gremlin: "g.V()", fields: map[string]interface{}{}} + + err = interceptor(req) + require.Error(t, err, "SigV4Auth should reject non-[]byte body") + assert.Contains(t, err.Error(), "SigV4 signing requires serialized body bytes", + "error message should indicate SigV4 requires serialized body bytes") +}
