This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch go-http-streaming in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit f3f34887802f47cb27204f77a69d85141ad570c4 Author: Yang Xia <[email protected]> AuthorDate: Mon Jan 19 16:03:51 2026 -0800 removed old httpProtocol/httpTransport code, removed authInfo as it's replaced by interceptors, and added basic & sigv4 auth ref --- gremlin-go/driver/auth.go | 85 +++++++++++ gremlin-go/driver/authInfo.go | 95 ------------ gremlin-go/driver/auth_test.go | 114 ++++++++++++++ gremlin-go/driver/client.go | 3 - gremlin-go/driver/client_test.go | 8 - gremlin-go/driver/connection.go | 1 - gremlin-go/driver/connection_test.go | 61 +++----- gremlin-go/driver/driverRemoteConnection.go | 3 - gremlin-go/driver/driverRemoteConnection_test.go | 22 +-- gremlin-go/driver/httpConnection.go | 62 +++----- gremlin-go/driver/httpConnection_test.go | 67 ++------ gremlin-go/driver/httpProtocol.go | 134 ---------------- gremlin-go/driver/httpTransporter.go | 186 ----------------------- gremlin-go/driver/strategies_test.go | 57 ++++--- gremlin-go/driver/traversal_test.go | 4 - gremlin-go/go.mod | 14 ++ gremlin-go/go.sum | 28 ++++ 17 files changed, 332 insertions(+), 612 deletions(-) diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go new file mode 100644 index 0000000000..ab4a88cd89 --- /dev/null +++ b/gremlin-go/driver/auth.go @@ -0,0 +1,85 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "context" + "encoding/base64" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" +) + +// BasicAuth returns a RequestInterceptor that adds Basic authentication header. +func BasicAuth(username, password string) RequestInterceptor { + encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + return func(req *HttpRequest) error { + req.Headers.Set(HeaderAuthorization, "Basic "+encoded) + return nil + } +} + +// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4. +// It uses the default AWS credential chain (env vars, shared config, IAM role, etc.) +func Sigv4Auth(region, service string) RequestInterceptor { + return Sigv4AuthWithCredentials(region, service, nil) +} + +// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4 +// with the provided credentials provider. If provider is nil, uses default credential chain. +func Sigv4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor { + return func(req *HttpRequest) error { + ctx := context.Background() + + creds, err := resolveCredentials(ctx, region, credentialsProvider) + if err != nil { + return err + } + + signer := v4.NewSigner() + stdReq := req.ToStdRequest() + stdReq.Body = nil // Body is handled separately via payload hash + + if err := signer.SignHTTP(ctx, creds, stdReq, req.PayloadHash(), service, region, time.Now()); err != nil { + return err + } + + // Copy signed headers back to HttpRequest + for k, v := range stdReq.Header { + req.Headers[k] = v + } + + return nil + } +} + +func resolveCredentials(ctx context.Context, region string, provider aws.CredentialsProvider) (aws.Credentials, error) { + if provider != nil { + return provider.Retrieve(ctx) + } + + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return aws.Credentials{}, err + } + return cfg.Credentials.Retrieve(ctx) +} diff --git a/gremlin-go/driver/authInfo.go b/gremlin-go/driver/authInfo.go deleted file mode 100644 index 0671136470..0000000000 --- a/gremlin-go/driver/authInfo.go +++ /dev/null @@ -1,95 +0,0 @@ -/* -Licensed to the Apache Software Foundation (ASF) under one -or more contributor license agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. -*/ - -package gremlingo - -import "net/http" - -// AuthInfoProvider is an interface that allows authentication information to be specified. -type AuthInfoProvider interface { - GetHeader() http.Header - GetBasicAuth() (ok bool, username, password string) -} - -// AuthInfo is an option struct that allows authentication information to be specified statically. -// Authentication can be provided via http.Header directly. -// Basic authentication can also be used via the BasicAuthInfo function. -type AuthInfo struct { - Header http.Header - Username string - Password string -} - -var _ AuthInfoProvider = (*AuthInfo)(nil) - -// GetHeader provides a safe way to get a header from the AuthInfo even if it is nil. -// This way we don't need any additional logic in the transport layer. -func (authInfo *AuthInfo) GetHeader() http.Header { - if authInfo == nil { - return nil - } else { - return authInfo.Header - } -} - -// GetBasicAuth provides a safe way to check if basic auth info is available from the AuthInfo even if it is nil. -// This way we don't need any additional logic in the transport layer. -func (authInfo *AuthInfo) GetBasicAuth() (bool, string, string) { - if authInfo == nil || (authInfo.Username == "" && authInfo.Password == "") { - return false, "", "" - } - return true, authInfo.Username, authInfo.Password -} - -// BasicAuthInfo provides a way to generate AuthInfo. Enter username and password and get the AuthInfo back. -func BasicAuthInfo(username string, password string) *AuthInfo { - return &AuthInfo{Username: username, Password: password} -} - -// HeaderAuthInfo provides a way to generate AuthInfo with only Header information. -func HeaderAuthInfo(header http.Header) *AuthInfo { - return &AuthInfo{Header: header} -} - -// DynamicAuth is an AuthInfoProvider that allows dynamic credential generation. -type DynamicAuth struct { - fn func() AuthInfoProvider -} - -var ( - _ AuthInfoProvider = (*DynamicAuth)(nil) - - // NoopAuthInfo is a no-op AuthInfoProvider that can be used to disable authentication. - NoopAuthInfo = NewDynamicAuth(func() AuthInfoProvider { return &AuthInfo{} }) -) - -// NewDynamicAuth provides a way to generate dynamic credentials with the specified generator function. -func NewDynamicAuth(f func() AuthInfoProvider) *DynamicAuth { - return &DynamicAuth{fn: f} -} - -// GetHeader calls the stored function to get the header dynamically. -func (d *DynamicAuth) GetHeader() http.Header { - return d.fn().GetHeader() -} - -// GetBasicAuth calls the stored function to get basic authentication dynamically. -func (d *DynamicAuth) GetBasicAuth() (bool, string, string) { - return d.fn().GetBasicAuth() -} diff --git a/gremlin-go/driver/auth_test.go b/gremlin-go/driver/auth_test.go new file mode 100644 index 0000000000..c0cce3ac9a --- /dev/null +++ b/gremlin-go/driver/auth_test.go @@ -0,0 +1,114 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "context" + "encoding/base64" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/assert" +) + +func createMockRequest() *HttpRequest { + req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin") + req.Headers.Set("Content-Type", graphBinaryMimeType) + req.Headers.Set("Accept", graphBinaryMimeType) + req.Body = []byte(`{"gremlin":"g.V()"}`) + return req +} + +func TestBasicAuth(t *testing.T) { + t.Run("adds authorization header", func(t *testing.T) { + req := createMockRequest() + assert.Empty(t, req.Headers.Get(HeaderAuthorization)) + + interceptor := BasicAuth("username", "password") + err := interceptor(req) + + assert.NoError(t, err) + authHeader := req.Headers.Get(HeaderAuthorization) + assert.True(t, strings.HasPrefix(authHeader, "Basic ")) + + // Verify encoding + encoded := strings.TrimPrefix(authHeader, "Basic ") + decoded, err := base64.StdEncoding.DecodeString(encoded) + assert.NoError(t, err) + assert.Equal(t, "username:password", string(decoded)) + }) +} + +// mockCredentialsProvider implements aws.CredentialsProvider for testing +type mockCredentialsProvider struct { + accessKey string + secretKey string + sessionToken string +} + +func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: m.accessKey, + SecretAccessKey: m.secretKey, + SessionToken: m.sessionToken, + }, nil +} + +func TestSigv4Auth(t *testing.T) { + t.Run("adds signed headers", func(t *testing.T) { + req := createMockRequest() + assert.Empty(t, req.Headers.Get("Authorization")) + assert.Empty(t, req.Headers.Get("X-Amz-Date")) + + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ACCESS_KEY", + secretKey: "MOCK_SECRET_KEY", + } + interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider) + err := interceptor(req) + + assert.NoError(t, err) + assert.NotEmpty(t, req.Headers.Get("X-Amz-Date")) + authHeader := req.Headers.Get("Authorization") + assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY")) + assert.Contains(t, authHeader, "us-west-2/neptune-db/aws4_request") + assert.Contains(t, authHeader, "Signature=") + }) + + t.Run("adds session token when provided", func(t *testing.T) { + req := createMockRequest() + assert.Empty(t, req.Headers.Get("X-Amz-Security-Token")) + + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ACCESS_KEY", + secretKey: "MOCK_SECRET_KEY", + sessionToken: "MOCK_SESSION_TOKEN", + } + interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider) + err := interceptor(req) + + assert.NoError(t, err) + assert.Equal(t, "MOCK_SESSION_TOKEN", req.Headers.Get("X-Amz-Security-Token")) + authHeader := req.Headers.Get("Authorization") + assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=")) + assert.Contains(t, authHeader, "Signature=") + }) +} diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index 68277fdf7f..995a0259fc 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -36,7 +36,6 @@ type ClientSettings struct { LogVerbosity LogVerbosity Logger Logger Language language.Tag - AuthInfo AuthInfoProvider TlsConfig *tls.Config ConnectionTimeout time.Duration EnableCompression bool @@ -72,7 +71,6 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C LogVerbosity: Info, Logger: &defaultLogger{}, Language: language.English, - AuthInfo: &AuthInfo{}, TlsConfig: &tls.Config{}, ConnectionTimeout: connectionTimeoutDefault, EnableCompression: false, @@ -85,7 +83,6 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C } connSettings := &connectionSettings{ - authInfo: settings.AuthInfo, tlsConfig: settings.TlsConfig, connectionTimeout: settings.ConnectionTimeout, enableCompression: settings.EnableCompression, diff --git a/gremlin-go/driver/client_test.go b/gremlin-go/driver/client_test.go index 9e223dd193..fdfd1204da 100644 --- a/gremlin-go/driver/client_test.go +++ b/gremlin-go/driver/client_test.go @@ -30,7 +30,6 @@ func TestClient(t *testing.T) { // Integration test variables. testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) - testNoAuthAuthInfo := &AuthInfo{} testNoAuthTlsConfig := &tls.Config{} t.Run("Test client.SubmitWithOptions()", func(t *testing.T) { @@ -38,7 +37,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo }) assert.NoError(t, err) assert.NotNil(t, client) @@ -58,7 +56,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo }) assert.NoError(t, err) assert.NotNil(t, client) @@ -74,7 +71,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.NoError(t, err) @@ -97,7 +93,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.NoError(t, err) @@ -122,7 +117,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.NoError(t, err) @@ -147,7 +141,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) @@ -170,7 +163,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerCrewGraphAlias }) diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index 8960ff43c4..f04997eeb0 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -25,7 +25,6 @@ import ( ) type connectionSettings struct { - authInfo AuthInfoProvider tlsConfig *tls.Config connectionTimeout time.Duration enableCompression bool diff --git a/gremlin-go/driver/connection_test.go b/gremlin-go/driver/connection_test.go index dcd56f8e40..8dee3d6567 100644 --- a/gremlin-go/driver/connection_test.go +++ b/gremlin-go/driver/connection_test.go @@ -56,7 +56,6 @@ var testNames = []string{"Lyndon", "Yang", "Simon", "Rithin", "Alexey", "Valenty func newDefaultConnectionSettings() *connectionSettings { return &connectionSettings{ - authInfo: &AuthInfo{}, tlsConfig: &tls.Config{}, connectionTimeout: connectionTimeoutDefault, enableCompression: false, @@ -87,11 +86,10 @@ func addTestData(t *testing.T, g *GraphTraversalSource) { assert.Nil(t, <-promise) } -func getTestGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Config) *GraphTraversalSource { +func getTestGraph(t *testing.T, url string, tls *tls.Config) *GraphTraversalSource { remote, err := NewDriverRemoteConnection(url, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = tls - settings.AuthInfo = auth settings.TraversalSource = testServerGraphAlias }) assert.Nil(t, err) @@ -101,8 +99,8 @@ func getTestGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Conf return g } -func initializeGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Config) *GraphTraversalSource { - g := getTestGraph(t, url, auth, tls) +func initializeGraph(t *testing.T, url string, tls *tls.Config) *GraphTraversalSource { + g := getTestGraph(t, url, tls) // Drop the graph and check that it is empty. dropGraph(t, g) @@ -241,11 +239,6 @@ func getEnvOrDefaultBool(key string, defaultValue bool) bool { return defaultValue } -func getBasicAuthInfo() *AuthInfo { - return BasicAuthInfo(getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_USERNAME", "stephen"), - getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_PASSWORD", "password")) -} - func skipTestsIfNotEnabled(t *testing.T, testSuiteName string, testSuiteEnabled bool) { if !testSuiteEnabled { t.Skipf("Skipping %s because %s tests are not enabled.", t.Name(), testSuiteName) @@ -256,19 +249,18 @@ func TestConnection(t *testing.T) { // Integration test variables. testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) - testNoAuthAuthInfo := &AuthInfo{} testNoAuthTlsConfig := &tls.Config{} // No authentication integration test with graphs loaded and alias configured server testNoAuthWithAliasUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) testNoAuthWithAliasEnable := getEnvOrDefaultBool("RUN_INTEGRATION_WITH_ALIAS_TESTS", true) - testNoAuthWithAliasAuthInfo := &AuthInfo{} testNoAuthWithAliasTlsConfig := &tls.Config{} // Basic authentication integration test variables. testBasicAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_BASIC_AUTH_URL", basicAuthWithSsl) testBasicAuthEnable := getEnvOrDefaultBool("RUN_BASIC_AUTH_INTEGRATION_TESTS", false) - testBasicAuthAuthInfo := getBasicAuthInfo() + testBasicAuthUsername := getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_USERNAME", "stephen") + testBasicAuthPassword := getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_PASSWORD", "password") testBasicAuthTlsConfig := &tls.Config{InsecureSkipVerify: true} // this test is used to test the ws->http POC changes via manual execution with a local TP 4.0 gremlin server running on 8182 @@ -283,7 +275,6 @@ func TestConnection(t *testing.T) { //client, err := NewClient(noAuthSslUrl, func(settings *ClientSettings) { settings.TlsConfig = &tlsConf - settings.AuthInfo = testNoAuthAuthInfo settings.EnableCompression = true settings.TraversalSource = testServerModernGraphAlias }) @@ -314,7 +305,6 @@ func TestConnection(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.EnableCompression = true settings.TraversalSource = testServerModernGraphAlias }) @@ -343,7 +333,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Read test data out of the graph and check that it is correct. @@ -358,7 +348,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() readWithNextAndHasNext(t, g) @@ -369,7 +359,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() resultSet, err := g.V().HasLabel(personLabel).Properties(nameKey).GetResultSet() @@ -393,7 +383,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := getTestGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := getTestGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Drop the graph. @@ -438,7 +428,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Read test data out of the graph and check that it is correct. @@ -454,7 +444,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Run traversal and test Next/HasNext calls @@ -480,7 +470,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() readUsingAnonymousTraversal(t, g) @@ -509,7 +499,9 @@ func TestConnection(t *testing.T) { remote, err := NewDriverRemoteConnection(testBasicAuthUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testBasicAuthTlsConfig - settings.AuthInfo = testBasicAuthAuthInfo + settings.RequestInterceptors = []RequestInterceptor{ + BasicAuth(testBasicAuthUsername, testBasicAuthPassword), + } }) assert.Nil(t, err) assert.NotNil(t, remote) @@ -535,7 +527,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() r, err := g.WithSack(1).V().Has("name", "Lyndon").Values("foo").Sack(Operator.Sum).Sack().ToList() @@ -554,7 +546,7 @@ func TestConnection(t *testing.T) { // skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // // // Initialize graph - // g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + // g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) // defer g.remoteConnection.Close() // // r, err := g.V().Has("name", "Lyndon").Values("foo").Profile().ToList() @@ -572,7 +564,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() prop := &BigDecimal{11, big.NewInt(int64(22))} @@ -592,7 +584,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() prop := &ByteBuffer{[]byte{byte(127), byte(255)}} @@ -612,7 +604,6 @@ func TestConnection(t *testing.T) { remote, err := NewDriverRemoteConnection(testNoAuthWithAliasUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testNoAuthWithAliasTlsConfig - settings.AuthInfo = testNoAuthWithAliasAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.Nil(t, err) @@ -632,7 +623,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) // Drop the graph. dropGraph(t, g) @@ -640,8 +631,7 @@ func TestConnection(t *testing.T) { // Add vertices and edges to graph. rs, err := g.AddV("person").Property("id", T__.Unfold().Property().AddV()).ToList() assert.Nil(t, rs) - fmt.Println(err.Error()) - assert.True(t, isSameErrorCode(newError(err0502ResponseHandlerReadLoopError), err)) + assert.True(t, isSameErrorCode(newError(err0502ResponseHandlerError), err)) rs, err = g.V().Count().ToList() assert.NotNil(t, rs) @@ -654,7 +644,7 @@ func TestConnection(t *testing.T) { t.Run("Get all properties when materializeProperties is all", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() // vertex contains 2 properties, name and age @@ -667,7 +657,7 @@ func TestConnection(t *testing.T) { t.Run("Skip properties when materializeProperties is tokens", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() // vertex contains 2 properties, name and age @@ -680,7 +670,7 @@ func TestConnection(t *testing.T) { t.Run("Get all properties when no materializeProperties", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() r, err := g.V().Has("person", "name", "marko").Next() @@ -692,7 +682,7 @@ func TestConnection(t *testing.T) { t.Run("Test DriverRemoteConnection Traversal With materializeProperties in Modern Graph", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() vertices, err := g.With("materializeProperties", MaterializeProperties.Tokens).V().ToList() @@ -800,7 +790,6 @@ func TestStreamingResultDelivery(t *testing.T) { remote, err := NewDriverRemoteConnection(getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl), func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = &tls.Config{} - settings.AuthInfo = &AuthInfo{} settings.TraversalSource = "ggrateful" }) assert.Nil(t, err) diff --git a/gremlin-go/driver/driverRemoteConnection.go b/gremlin-go/driver/driverRemoteConnection.go index ccf726ff29..0cf18bd729 100644 --- a/gremlin-go/driver/driverRemoteConnection.go +++ b/gremlin-go/driver/driverRemoteConnection.go @@ -33,7 +33,6 @@ type DriverRemoteConnectionSettings struct { LogVerbosity LogVerbosity Logger Logger Language language.Tag - AuthInfo AuthInfoProvider TlsConfig *tls.Config ConnectionTimeout time.Duration EnableCompression bool @@ -65,7 +64,6 @@ func NewDriverRemoteConnection( LogVerbosity: Info, Logger: &defaultLogger{}, Language: language.English, - AuthInfo: &AuthInfo{}, TlsConfig: &tls.Config{}, ConnectionTimeout: connectionTimeoutDefault, EnableCompression: false, @@ -78,7 +76,6 @@ func NewDriverRemoteConnection( } connSettings := &connectionSettings{ - authInfo: settings.AuthInfo, tlsConfig: settings.TlsConfig, connectionTimeout: settings.ConnectionTimeout, enableCompression: settings.EnableCompression, diff --git a/gremlin-go/driver/driverRemoteConnection_test.go b/gremlin-go/driver/driverRemoteConnection_test.go index d702d1809e..bdae0c0aa8 100644 --- a/gremlin-go/driver/driverRemoteConnection_test.go +++ b/gremlin-go/driver/driverRemoteConnection_test.go @@ -20,7 +20,6 @@ under the License. package gremlingo import ( - "net/http" "testing" "github.com/stretchr/testify/assert" @@ -28,20 +27,11 @@ import ( func TestAuthentication(t *testing.T) { - t.Run("Test BasicAuthInfo.", func(t *testing.T) { - header := BasicAuthInfo("Lyndon", "Bauto") - assert.Nil(t, header.GetHeader()) - b, _, _ := header.GetBasicAuth() - assert.True(t, b) - }) - - t.Run("Test GetHeader.", func(t *testing.T) { - header := &AuthInfo{} - assert.Nil(t, header.GetHeader()) - header = nil - assert.Nil(t, header.GetHeader()) - httpHeader := http.Header{} - header = &AuthInfo{Header: httpHeader} - assert.Equal(t, httpHeader, header.GetHeader()) + t.Run("Test BasicAuth interceptor", func(t *testing.T) { + interceptor := BasicAuth("user", "pass") + req, _ := NewHttpRequest("POST", "http://localhost:8182/gremlin") + err := interceptor(req) + assert.Nil(t, err) + assert.Contains(t, req.Headers.Get(HeaderAuthorization), "Basic ") }) } diff --git a/gremlin-go/driver/httpConnection.go b/gremlin-go/driver/httpConnection.go index 5f4d74166b..b96a6211ef 100644 --- a/gremlin-go/driver/httpConnection.go +++ b/gremlin-go/driver/httpConnection.go @@ -22,7 +22,8 @@ package gremlingo import ( "bytes" "compress/zlib" - "encoding/base64" + "crypto/sha256" + "encoding/hex" "io" "net" "net/http" @@ -60,6 +61,22 @@ func NewHttpRequest(method, rawURL string) (*HttpRequest, error) { }, nil } +// ToStdRequest converts HttpRequest to a standard http.Request for signing. +func (r *HttpRequest) ToStdRequest() *http.Request { + req, _ := http.NewRequest(r.Method, r.URL.String(), bytes.NewReader(r.Body)) + req.Header = r.Headers + return req +} + +// PayloadHash returns the SHA256 hash of the request body for SigV4 signing. +func (r *HttpRequest) PayloadHash() string { + if len(r.Body) == 0 { + return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + } + h := sha256.Sum256(r.Body) + return hex.EncodeToString(h[:]) +} + // RequestInterceptor is a function that modifies an HTTP request before it is sent. type RequestInterceptor func(*HttpRequest) error @@ -193,49 +210,6 @@ func (c *httpConnection) setHttpRequestHeaders(req *HttpRequest) { if c.connSettings.enableCompression { req.Headers.Set(HeaderAcceptEncoding, "deflate") } - if c.connSettings.authInfo != nil { - if headers := c.connSettings.authInfo.GetHeader(); headers != nil { - for k, vals := range headers { - for _, v := range vals { - req.Headers.Add(k, v) - } - } - } - if ok, user, pass := c.connSettings.authInfo.GetBasicAuth(); ok { - req.Headers.Set(HeaderAuthorization, "Basic "+basicAuth(user, pass)) - } - } -} - -// basicAuth encodes username and password for Basic auth header -func basicAuth(username, password string) string { - auth := username + ":" + password - return base64.StdEncoding.EncodeToString([]byte(auth)) -} - -// setHeaders sets headers on http.Request (legacy, kept for compatibility) -func (c *httpConnection) setHeaders(req *http.Request) { - req.Header.Set("Content-Type", graphBinaryMimeType) - req.Header.Set("Accept", graphBinaryMimeType) - - if c.connSettings.enableUserAgentOnConnect { - req.Header.Set(userAgentHeader, userAgent) - } - if c.connSettings.enableCompression { - req.Header.Set("Accept-Encoding", "deflate") - } - if c.connSettings.authInfo != nil { - if headers := c.connSettings.authInfo.GetHeader(); headers != nil { - for k, vals := range headers { - for _, v := range vals { - req.Header.Add(k, v) - } - } - } - if ok, user, pass := c.connSettings.authInfo.GetBasicAuth(); ok { - req.SetBasicAuth(user, pass) - } - } } func (c *httpConnection) getReader(resp *http.Response) (io.Reader, io.Closer, error) { diff --git a/gremlin-go/driver/httpConnection_test.go b/gremlin-go/driver/httpConnection_test.go index f106cac66e..76d9146936 100644 --- a/gremlin-go/driver/httpConnection_test.go +++ b/gremlin-go/driver/httpConnection_test.go @@ -38,19 +38,11 @@ func newTestLogHandler() *logHandler { } func TestNewHttpConnection(t *testing.T) { - t.Run("applies default timeout when not set", func(t *testing.T) { + t.Run("creates connection with default settings", func(t *testing.T) { conn := newHttpConnection(newTestLogHandler(), "http://localhost:8182/gremlin", &connectionSettings{}) - assert.Equal(t, defaultConnectionTimeout, conn.httpClient.Timeout) - }) - - t.Run("uses provided timeout", func(t *testing.T) { - customTimeout := 30 * time.Second - conn := newHttpConnection(newTestLogHandler(), "http://localhost:8182/gremlin", &connectionSettings{ - connectionTimeout: customTimeout, - }) - - assert.Equal(t, customTimeout, conn.httpClient.Timeout) + assert.NotNil(t, conn.httpClient) + assert.NotNil(t, conn.httpClient.Transport) }) t.Run("applies TLS config", func(t *testing.T) { @@ -64,68 +56,37 @@ func TestNewHttpConnection(t *testing.T) { }) } -func TestSetHeaders(t *testing.T) { +func TestSetHttpRequestHeaders(t *testing.T) { t.Run("sets content type and accept headers", func(t *testing.T) { conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{}) - req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil) - require.NoError(t, err) + req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") - conn.setHeaders(req) + conn.setHttpRequestHeaders(req) - assert.Equal(t, graphBinaryMimeType, req.Header.Get("Content-Type")) - assert.Equal(t, graphBinaryMimeType, req.Header.Get("Accept")) + assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Content-Type")) + assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Accept")) }) t.Run("sets user agent when enabled", func(t *testing.T) { conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{ enableUserAgentOnConnect: true, }) - req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil) - require.NoError(t, err) + req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") - conn.setHeaders(req) + conn.setHttpRequestHeaders(req) - assert.NotEmpty(t, req.Header.Get(userAgentHeader)) + assert.NotEmpty(t, req.Headers.Get(HeaderUserAgent)) }) t.Run("sets compression header when enabled", func(t *testing.T) { conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{ enableCompression: true, }) - req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil) - require.NoError(t, err) - - conn.setHeaders(req) - - assert.Equal(t, "deflate", req.Header.Get("Accept-Encoding")) - }) - - t.Run("sets basic auth when provided", func(t *testing.T) { - conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{ - authInfo: BasicAuthInfo("user", "pass"), - }) - req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil) - require.NoError(t, err) - - conn.setHeaders(req) - - user, pass, ok := req.BasicAuth() - assert.True(t, ok) - assert.Equal(t, "user", user) - assert.Equal(t, "pass", pass) - }) - - t.Run("handles nil authInfo", func(t *testing.T) { - conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{ - authInfo: nil, - }) - req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil) - require.NoError(t, err) + req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") - conn.setHeaders(req) + conn.setHttpRequestHeaders(req) - _, _, ok := req.BasicAuth() - assert.False(t, ok) + assert.Equal(t, "deflate", req.Headers.Get("Accept-Encoding")) }) } diff --git a/gremlin-go/driver/httpProtocol.go b/gremlin-go/driver/httpProtocol.go deleted file mode 100644 index 18bf762923..0000000000 --- a/gremlin-go/driver/httpProtocol.go +++ /dev/null @@ -1,134 +0,0 @@ -/* -Licensed to the Apache Software Foundation (ASF) under one -or more contributor license agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. -*/ - -package gremlingo - -import ( - "errors" - "net/http" -) - -// responsible for serializing and sending requests and then receiving and deserializing responses -type httpProtocol struct { - serializer *GraphBinarySerializer - logHandler *logHandler - url string - connSettings *connectionSettings - httpClient *http.Client -} - -func newHttpProtocol(handler *logHandler, url string, connSettings *connectionSettings) *httpProtocol { - transport := &http.Transport{ - TLSClientConfig: connSettings.tlsConfig, - MaxConnsPerHost: 0, // TODO - IdleConnTimeout: 0, // TODO - DisableCompression: !connSettings.enableCompression, - } - - httpClient := http.Client{ - Transport: transport, - Timeout: connSettings.connectionTimeout, - } - - httpProt := &httpProtocol{ - serializer: newGraphBinarySerializer(handler), - logHandler: handler, - url: url, - connSettings: connSettings, - httpClient: &httpClient, - } - return httpProt -} - -// sends a query request and returns a ResultSet that can be used to obtain query results -func (protocol *httpProtocol) send(request *request) (ResultSet, error) { - rs := newChannelResultSet() - bytes, err := protocol.serializer.SerializeMessage(request) - if err != nil { - rs.setError(err) - rs.Close() - return rs, err - } - - // one transport per request - transport := newHttpTransporter(protocol.url, protocol.connSettings, protocol.httpClient, protocol.logHandler) - - // async send request and receive response - go func() { - err := transport.Write(bytes) - if err != nil { - transport.Close() - rs.setError(err) - rs.Close() - return - } - - err = protocol.receiveChunkedResponse(rs, transport) - if err != nil { - rs.setError(err) - } - transport.Close() - }() - - return rs, nil -} - -// receiveChunkedResponse processes individual chunk responses -func (protocol *httpProtocol) receiveChunkedResponse(rs ResultSet, transport *httpTransporter) error { - for { - resp, err := transport.Read() - if err != nil { - if errors.Is(err, ErrResponseStreamClosed) { - rs.Close() - return nil - } - rs.Close() - return err - } - - endOfStream := false - if data, ok := resp.ResponseResult.Data.([]interface{}); ok { - for _, obj := range data { - if marker, ok := obj.(Marker); ok && marker == EndOfStream() { - endOfStream = true - break - } - - rs.Channel() <- &Result{obj} - } - } - - // Check status code (error status comes after EndOfStream) - if resp.ResponseStatus.code != 0 && resp.ResponseStatus.code != 200 { - rs.Close() - err := newError(err0502ResponseHandlerReadLoopError, resp.ResponseStatus, resp.ResponseStatus.code) - rs.setError(err) - return err - } - - if endOfStream { - rs.Close() - return nil - } - } -} - -func (protocol *httpProtocol) close() { - protocol.httpClient.CloseIdleConnections() -} diff --git a/gremlin-go/driver/httpTransporter.go b/gremlin-go/driver/httpTransporter.go deleted file mode 100644 index 9a9c5bb99b..0000000000 --- a/gremlin-go/driver/httpTransporter.go +++ /dev/null @@ -1,186 +0,0 @@ -/* -Licensed to the Apache Software Foundation (ASF) under one -or more contributor license agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. -*/ - -package gremlingo - -import ( - "bytes" - "compress/zlib" - "errors" - "io" - "net/http" - "sync" - "time" -) - -// httpTransporter responsible for sending and receiving bytes to/from the server -type httpTransporter struct { - url string - isClosed bool - connSettings *connectionSettings - responseChannel chan Response // receives response messages - httpClient *http.Client - wg *sync.WaitGroup - logHandler *logHandler - closeOnce sync.Once -} - -func newHttpTransporter(url string, connSettings *connectionSettings, httpClient *http.Client, logHandler *logHandler) *httpTransporter { - wg := &sync.WaitGroup{} - - return &httpTransporter{ - url: url, - connSettings: connSettings, - responseChannel: make(chan Response, 10), - httpClient: httpClient, - wg: wg, - logHandler: logHandler, - closeOnce: sync.Once{}, - } -} - -// Write sends bytes to the server and starts streaming response processing -func (transporter *httpTransporter) Write(data []byte) error { - req, err := http.NewRequest("POST", transporter.url, bytes.NewBuffer(data)) - if err != nil { - transporter.logHandler.logf(Error, failedToSendRequest, err.Error()) - return err - } - req.Header.Set("content-type", graphBinaryMimeType) - req.Header.Set("accept", graphBinaryMimeType) - if transporter.connSettings.enableUserAgentOnConnect { - req.Header.Set(userAgentHeader, userAgent) - } - if transporter.connSettings.enableCompression { - req.Header.Set("accept-encoding", "deflate") - } - if transporter.connSettings.authInfo != nil { - // Add custom headers - if headers := transporter.connSettings.authInfo.GetHeader(); headers != nil { - for key, values := range headers { - for _, value := range values { - req.Header.Add(key, value) - } - } - } - - // Add basic auth - if ok, username, password := transporter.connSettings.authInfo.GetBasicAuth(); ok { - req.SetBasicAuth(username, password) - } - } - - resp, err := transporter.httpClient.Do(req) - if err != nil { - transporter.logHandler.logf(Error, failedToSendRequest, err.Error()) - return err - } - - reader := resp.Body - if resp.Header.Get("content-encoding") == "deflate" { - reader, err = zlib.NewReader(resp.Body) - if err != nil { - transporter.logHandler.logf(Error, failedToReceiveResponse, err.Error()) - err := resp.Body.Close() - if err != nil { - return err - } - return err - } - } - - // Start streaming processing in background - go transporter.streamResponse(reader, resp.Body) - return nil -} - -// streamResponse processes HTTP chunks independently -func (transporter *httpTransporter) streamResponse(reader io.Reader, body io.Closer) { - defer func(body io.Closer) { - err := body.Close() - if err != nil { - } - }(body) - defer transporter.closeResponseChannel() - - serializer := newGraphBinarySerializer(transporter.logHandler) - isFirstChunk := true - - chunk := make([]byte, transporter.connSettings.readBufferSize) - timer := time.NewTimer(5 * time.Second) - defer timer.Stop() - - for { - n, err := reader.Read(chunk) - if n > 0 { - msg, procErr := serializer.readChunk(chunk[:n], isFirstChunk) - if procErr != nil { - transporter.logHandler.logf(Error, failedToReceiveResponse, procErr.Error()) - return - } - isFirstChunk = false - - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - timer.Reset(5 * time.Second) - - select { - case transporter.responseChannel <- *msg: - case <-timer.C: - transporter.logHandler.logf(Error, failedToReceiveResponse, "timeout") - return - } - } - - if err == io.EOF { - break - } - if err != nil { - transporter.logHandler.logf(Error, failedToReceiveResponse, err.Error()) - return - } - } -} - -func (transporter *httpTransporter) closeResponseChannel() { - transporter.closeOnce.Do(func() { - close(transporter.responseChannel) - }) -} - -// Read reads response messages from the stream -func (transporter *httpTransporter) Read() (Response, error) { - resp, ok := <-transporter.responseChannel - if !ok { - return Response{}, errors.New("response stream closed") - } - return resp, nil -} - -// Close closes the transporter and its corresponding responseChannel -func (transporter *httpTransporter) Close() { - if !transporter.isClosed { - transporter.closeResponseChannel() - transporter.isClosed = true - } -} diff --git a/gremlin-go/driver/strategies_test.go b/gremlin-go/driver/strategies_test.go index ee1435890b..2a855be2d5 100644 --- a/gremlin-go/driver/strategies_test.go +++ b/gremlin-go/driver/strategies_test.go @@ -27,11 +27,10 @@ import ( "github.com/stretchr/testify/assert" ) -func getModernGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Config) *GraphTraversalSource { +func getModernGraph(t *testing.T, url string, tls *tls.Config) *GraphTraversalSource { remote, err := NewDriverRemoteConnection(url, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = tls - settings.AuthInfo = auth settings.TraversalSource = testServerModernGraphAlias }) assert.Nil(t, err) @@ -45,7 +44,7 @@ func TestStrategy(t *testing.T) { testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) t.Run("Test read with ConnectiveStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(ConnectiveStrategy()).V().Count().ToList() @@ -58,7 +57,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with OptionsStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(OptionsStrategy(map[string]interface{}{"a": "b"})).V().Count().ToList() @@ -71,7 +70,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with PartitionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := PartitionStrategyConfig{ @@ -90,7 +89,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with SeedStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := SeedStrategyConfig{Seed: 1} @@ -104,7 +103,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with SubgraphStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := SubgraphStrategyConfig{ Vertices: T__.HasLabel(testLabel), @@ -130,7 +129,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with AdjacentToIncidentStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(AdjacentToIncidentStrategy()).V().Count().ToList() @@ -143,7 +142,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with ByModulatorOptimizationStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(ByModulatorOptimizationStrategy()).V().Count().ToList() @@ -156,7 +155,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with CountStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(CountStrategy()).V().Count().ToList() @@ -169,7 +168,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with EarlyLimitStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(EarlyLimitStrategy()).V().Count().ToList() @@ -182,7 +181,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with FilterRankingStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(FilterRankingStrategy()).V().Count().ToList() @@ -195,7 +194,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with IdentityRemovalStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(IdentityRemovalStrategy()).V().Count().ToList() @@ -208,7 +207,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with IncidentToAdjacentStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(IncidentToAdjacentStrategy()).V().Count().ToList() @@ -221,7 +220,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with InlineFilterStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(InlineFilterStrategy()).V().Count().ToList() @@ -234,7 +233,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with LazyBarrierStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(LazyBarrierStrategy()).V().Count().ToList() @@ -247,7 +246,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with MatchPredicateStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(MatchPredicateStrategy()).V().Count().ToList() @@ -260,7 +259,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with OrderLimitStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(OrderLimitStrategy()).V().Count().ToList() @@ -273,7 +272,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with PathProcessorStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(PathProcessorStrategy()).V().Count().ToList() @@ -286,7 +285,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with PathRetractionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(PathRetractionStrategy()).V().Count().ToList() @@ -299,7 +298,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with ProductiveByStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := ProductiveByStrategyConfig{ProductiveKeys: []string{"a", "b"}} @@ -313,7 +312,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with RepeatUnrollStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(RepeatUnrollStrategy()).V().Count().ToList() @@ -326,7 +325,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with EdgeLabelVerificationStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := EdgeLabelVerificationStrategyConfig{ @@ -343,7 +342,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with LambdaRestrictionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(LambdaRestrictionStrategy()).V().Count().ToList() @@ -356,7 +355,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with TestReadOnlyStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(ReadOnlyStrategy()).V().Count().ToList() @@ -369,7 +368,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test write with TestReadOnlyStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() promise := g.WithStrategies(ReadOnlyStrategy()).AddV("person").Property("name", "foo").Iterate() @@ -377,7 +376,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with ReservedKeysVerificationStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := ReservedKeysVerificationStrategyConfig{ @@ -396,7 +395,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with RepeatUnrollStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(RepeatUnrollStrategy()).V().Count().ToList() @@ -409,7 +408,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test without strategies MessagePassingReductionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithoutStrategies(MessagePassingReductionStrategy()).V().Count().ToList() diff --git a/gremlin-go/driver/traversal_test.go b/gremlin-go/driver/traversal_test.go index 1fdf1a52d9..0efded3ac9 100644 --- a/gremlin-go/driver/traversal_test.go +++ b/gremlin-go/driver/traversal_test.go @@ -547,13 +547,11 @@ func TestTraversal(t *testing.T) { func newWithOptionsConnection(t *testing.T) *GraphTraversalSource { // No authentication integration test with graphs loaded and alias configured server testNoAuthWithAliasUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) - testNoAuthWithAliasAuthInfo := &AuthInfo{} testNoAuthWithAliasTlsConfig := &tls.Config{} remote, err := NewDriverRemoteConnection(testNoAuthWithAliasUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testNoAuthWithAliasTlsConfig - settings.AuthInfo = testNoAuthWithAliasAuthInfo settings.TraversalSource = "gmodern" }) assert.Nil(t, err) @@ -563,13 +561,11 @@ func newWithOptionsConnection(t *testing.T) *GraphTraversalSource { func newConnection(t *testing.T) *DriverRemoteConnection { testNoAuthWithAliasUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) - testNoAuthWithAliasAuthInfo := &AuthInfo{} testNoAuthWithAliasTlsConfig := &tls.Config{} remote, err := NewDriverRemoteConnection(testNoAuthWithAliasUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testNoAuthWithAliasTlsConfig - settings.AuthInfo = testNoAuthWithAliasAuthInfo settings.TraversalSource = "gtx" }) assert.Nil(t, err) diff --git a/gremlin-go/go.mod b/gremlin-go/go.mod index c36ea3718d..169beee051 100644 --- a/gremlin-go/go.mod +++ b/gremlin-go/go.mod @@ -28,6 +28,20 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2 v1.41.1 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect + github.com/aws/smithy-go v1.24.0 // indirect github.com/cucumber/gherkin/go/v26 v26.2.0 // indirect github.com/cucumber/messages/go/v21 v21.0.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/gremlin-go/go.sum b/gremlin-go/go.sum index 59675a75c5..cc7f1d6a7d 100644 --- a/gremlin-go/go.sum +++ b/gremlin-go/go.sum @@ -1,5 +1,33 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU= +github.com/aws/aws-sdk-go-v2 v1.41.1/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.7 h1:vxUyWGUwmkQ2g19n7JY/9YL8MfAIl7bTesIUykECXmY= +github.com/aws/aws-sdk-go-v2/config v1.32.7/go.mod h1:2/Qm5vKUU/r7Y+zUk/Ptt2MDAEKAfUtKc1+3U1Mo3oY= +github.com/aws/aws-sdk-go-v2/credentials v1.19.7 h1:tHK47VqqtJxOymRrNtUXN5SP/zUTvZKeLx4tH6PGQc8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.7/go.mod h1:qOZk8sPDrxhf+4Wf4oT2urYJrYt3RejHSzgAquYeppw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 h1:I0GyV8wiYrP8XpA70g1HBcQO1JlQxCMTW9npl5UbDHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17/go.mod h1:tyw7BOl5bBe/oqvoIeECFJjMdzXoa/dfVz3QQ5lgHGA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 h1:xOLELNKGp2vsiteLsvLPwxC+mYmO6OZ8PYgiuPJzF8U= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17/go.mod h1:5M5CI3D12dNOtH3/mk6minaRwI2/37ifCURZISxA/IQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 h1:WWLqlh79iO48yLkj1v3ISRNiv+3KdQoZ6JWyfcsyQik= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17/go.mod h1:EhG22vHRrvF8oXSTYStZhJc1aUgKtnJe+aOiFEV90cM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 h1:RuNSMoozM8oXlgLG/n6WLaFGoea7/CddrCfIiSA+xdY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17/go.mod h1:F2xxQ9TZz5gDWsclCtPQscGpP0VUOc8RqgFM3vDENmU= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 h1:VrhDvQib/i0lxvr3zqlUwLwJP4fpmpyD9wYG1vfSu+Y= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.5/go.mod h1:k029+U8SY30/3/ras4G/Fnv/b88N4mAfliNn08Dem4M= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 h1:v6EiMvhEYBoHABfbGB4alOYmCIrcgyPPiBE1wZAEbqk= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.9/go.mod h1:yifAsgBxgJWn3ggx70A3urX2AN49Y5sJTD1UQFlfqBw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 h1:gd84Omyu9JLriJVCbGApcLzVR3XtmC4ZDPcAI6Ftvds= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13/go.mod h1:sTGThjphYE4Ohw8vJiRStAcu3rbjtXRsdNB0TvZ5wwo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 h1:5fFjR/ToSOzB2OQ/XqWpZBmNvmP/pJ1jOWYlFDJTjRQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.6/go.mod h1:qgFDZQSD/Kys7nJnVqYlWKnh0SSdMjAi0uSwON4wgYQ= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cucumber/gherkin/go/v26 v26.2.0 h1:EgIjePLWiPeslwIWmNQ3XHcypPsWAHoMCz/YEBKP4GI= github.com/cucumber/gherkin/go/v26 v26.2.0/go.mod h1:t2GAPnB8maCT4lkHL99BDCVNzCh1d7dBhCLt150Nr/0=
