This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch go-http-streaming in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit ab8713bf61c087473fa260808f8c6360158e9d70 Author: Yang Xia <[email protected]> AuthorDate: Mon Jan 27 18:28:00 2025 -0800 Added streaming deserializer, combined httpProtocol/httpTransport, added interceptors which replaces AuthInfo, added basic & sigv4 auth interceptors, updated connection settings --- gremlin-go/driver/auth.go | 104 ++++ gremlin-go/driver/authInfo.go | 95 ---- gremlin-go/driver/auth_test.go | 114 ++++ gremlin-go/driver/client.go | 56 +- gremlin-go/driver/client_test.go | 8 - gremlin-go/driver/connection.go | 5 - gremlin-go/driver/connection_test.go | 160 ++++-- gremlin-go/driver/cucumber/cucumberSteps_test.go | 2 +- gremlin-go/driver/cucumber/cucumberWorld.go | 4 +- gremlin-go/driver/driverRemoteConnection.go | 32 +- gremlin-go/driver/driverRemoteConnection_test.go | 22 +- gremlin-go/driver/graphBinary.go | 75 ++- gremlin-go/driver/gremlinlang.go | 14 - gremlin-go/driver/httpConnection.go | 277 +++++++++ gremlin-go/driver/httpConnection_test.go | 178 ++++++ gremlin-go/driver/httpProtocol.go | 157 ------ gremlin-go/driver/httpTransporter.go | 145 ----- gremlin-go/driver/logger.go | 3 +- gremlin-go/driver/performance/performanceSuite.go | 3 - gremlin-go/driver/resources/error-messages/en.json | 1 + .../driver/resources/logger-messages/en.json | 4 +- gremlin-go/driver/resultSet.go | 4 +- gremlin-go/driver/serializer.go | 112 +++- gremlin-go/driver/strategies_test.go | 57 +- gremlin-go/driver/streamingDeserializer.go | 617 +++++++++++++++++++++ gremlin-go/driver/streamingDeserializer_test.go | 402 ++++++++++++++ gremlin-go/driver/traversal.go | 30 - gremlin-go/driver/traversal_test.go | 4 - gremlin-go/go.mod | 14 + gremlin-go/go.sum | 28 + 30 files changed, 2086 insertions(+), 641 deletions(-) diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go new file mode 100644 index 0000000000..70cff17a6c --- /dev/null +++ b/gremlin-go/driver/auth.go @@ -0,0 +1,104 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more 
contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "context" + "encoding/base64" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" +) + +// BasicAuth returns a RequestInterceptor that adds Basic authentication header. +func BasicAuth(username, password string) RequestInterceptor { + encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + return func(req *HttpRequest) error { + req.Headers.Set(HeaderAuthorization, "Basic "+encoded) + return nil + } +} + +// Sigv4Auth returns a RequestInterceptor that signs requests using AWS SigV4. +// It uses the default AWS credential chain (env vars, shared config, IAM role, etc.) +func Sigv4Auth(region, service string) RequestInterceptor { + return Sigv4AuthWithCredentials(region, service, nil) +} + +// Sigv4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4 +// with the provided credentials provider. If provider is nil, uses default credential chain. +// +// Caches the signer and credentials provider for efficiency. 
+func Sigv4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor { + // Create signer once - it's stateless and safe to reuse + signer := v4.NewSigner() + + // Cache for resolved credentials provider (lazy initialization) + var cachedProvider aws.CredentialsProvider + var providerOnce sync.Once + var providerErr error + + return func(req *HttpRequest) error { + ctx := context.Background() + + // Resolve credentials provider once if not provided + provider := credentialsProvider + if provider == nil { + providerOnce.Do(func() { + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + providerErr = err + return + } + cachedProvider = cfg.Credentials + }) + if providerErr != nil { + return providerErr + } + provider = cachedProvider + } + + // Retrieve credentials (the provider handles caching internally) + creds, err := provider.Retrieve(ctx) + if err != nil { + return err + } + + stdReq, err := req.ToStdRequest() + if err != nil { + return err + } + stdReq.Body = nil // Body is handled separately via payload hash + + if err := signer.SignHTTP(ctx, creds, stdReq, req.PayloadHash(), service, region, time.Now()); err != nil { + return err + } + + // Copy signed headers back to HttpRequest + for k, v := range stdReq.Header { + req.Headers[k] = v + } + + return nil + } +} diff --git a/gremlin-go/driver/authInfo.go b/gremlin-go/driver/authInfo.go deleted file mode 100644 index 0671136470..0000000000 --- a/gremlin-go/driver/authInfo.go +++ /dev/null @@ -1,95 +0,0 @@ -/* -Licensed to the Apache Software Foundation (ASF) under one -or more contributor license agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. 
You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. -*/ - -package gremlingo - -import "net/http" - -// AuthInfoProvider is an interface that allows authentication information to be specified. -type AuthInfoProvider interface { - GetHeader() http.Header - GetBasicAuth() (ok bool, username, password string) -} - -// AuthInfo is an option struct that allows authentication information to be specified statically. -// Authentication can be provided via http.Header directly. -// Basic authentication can also be used via the BasicAuthInfo function. -type AuthInfo struct { - Header http.Header - Username string - Password string -} - -var _ AuthInfoProvider = (*AuthInfo)(nil) - -// GetHeader provides a safe way to get a header from the AuthInfo even if it is nil. -// This way we don't need any additional logic in the transport layer. -func (authInfo *AuthInfo) GetHeader() http.Header { - if authInfo == nil { - return nil - } else { - return authInfo.Header - } -} - -// GetBasicAuth provides a safe way to check if basic auth info is available from the AuthInfo even if it is nil. -// This way we don't need any additional logic in the transport layer. -func (authInfo *AuthInfo) GetBasicAuth() (bool, string, string) { - if authInfo == nil || (authInfo.Username == "" && authInfo.Password == "") { - return false, "", "" - } - return true, authInfo.Username, authInfo.Password -} - -// BasicAuthInfo provides a way to generate AuthInfo. Enter username and password and get the AuthInfo back. 
-func BasicAuthInfo(username string, password string) *AuthInfo { - return &AuthInfo{Username: username, Password: password} -} - -// HeaderAuthInfo provides a way to generate AuthInfo with only Header information. -func HeaderAuthInfo(header http.Header) *AuthInfo { - return &AuthInfo{Header: header} -} - -// DynamicAuth is an AuthInfoProvider that allows dynamic credential generation. -type DynamicAuth struct { - fn func() AuthInfoProvider -} - -var ( - _ AuthInfoProvider = (*DynamicAuth)(nil) - - // NoopAuthInfo is a no-op AuthInfoProvider that can be used to disable authentication. - NoopAuthInfo = NewDynamicAuth(func() AuthInfoProvider { return &AuthInfo{} }) -) - -// NewDynamicAuth provides a way to generate dynamic credentials with the specified generator function. -func NewDynamicAuth(f func() AuthInfoProvider) *DynamicAuth { - return &DynamicAuth{fn: f} -} - -// GetHeader calls the stored function to get the header dynamically. -func (d *DynamicAuth) GetHeader() http.Header { - return d.fn().GetHeader() -} - -// GetBasicAuth calls the stored function to get basic authentication dynamically. -func (d *DynamicAuth) GetBasicAuth() (bool, string, string) { - return d.fn().GetBasicAuth() -} diff --git a/gremlin-go/driver/auth_test.go b/gremlin-go/driver/auth_test.go new file mode 100644 index 0000000000..c0cce3ac9a --- /dev/null +++ b/gremlin-go/driver/auth_test.go @@ -0,0 +1,114 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "context" + "encoding/base64" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/assert" +) + +func createMockRequest() *HttpRequest { + req, _ := NewHttpRequest("POST", "https://localhost:8182/gremlin") + req.Headers.Set("Content-Type", graphBinaryMimeType) + req.Headers.Set("Accept", graphBinaryMimeType) + req.Body = []byte(`{"gremlin":"g.V()"}`) + return req +} + +func TestBasicAuth(t *testing.T) { + t.Run("adds authorization header", func(t *testing.T) { + req := createMockRequest() + assert.Empty(t, req.Headers.Get(HeaderAuthorization)) + + interceptor := BasicAuth("username", "password") + err := interceptor(req) + + assert.NoError(t, err) + authHeader := req.Headers.Get(HeaderAuthorization) + assert.True(t, strings.HasPrefix(authHeader, "Basic ")) + + // Verify encoding + encoded := strings.TrimPrefix(authHeader, "Basic ") + decoded, err := base64.StdEncoding.DecodeString(encoded) + assert.NoError(t, err) + assert.Equal(t, "username:password", string(decoded)) + }) +} + +// mockCredentialsProvider implements aws.CredentialsProvider for testing +type mockCredentialsProvider struct { + accessKey string + secretKey string + sessionToken string +} + +func (m *mockCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{ + AccessKeyID: m.accessKey, + SecretAccessKey: m.secretKey, + SessionToken: m.sessionToken, + }, nil +} + +func TestSigv4Auth(t *testing.T) { + t.Run("adds signed headers", func(t *testing.T) { + req := createMockRequest() + 
assert.Empty(t, req.Headers.Get("Authorization")) + assert.Empty(t, req.Headers.Get("X-Amz-Date")) + + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ACCESS_KEY", + secretKey: "MOCK_SECRET_KEY", + } + interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider) + err := interceptor(req) + + assert.NoError(t, err) + assert.NotEmpty(t, req.Headers.Get("X-Amz-Date")) + authHeader := req.Headers.Get("Authorization") + assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY")) + assert.Contains(t, authHeader, "us-west-2/neptune-db/aws4_request") + assert.Contains(t, authHeader, "Signature=") + }) + + t.Run("adds session token when provided", func(t *testing.T) { + req := createMockRequest() + assert.Empty(t, req.Headers.Get("X-Amz-Security-Token")) + + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ACCESS_KEY", + secretKey: "MOCK_SECRET_KEY", + sessionToken: "MOCK_SESSION_TOKEN", + } + interceptor := Sigv4AuthWithCredentials("us-west-2", "neptune-db", provider) + err := interceptor(req) + + assert.NoError(t, err) + assert.Equal(t, "MOCK_SESSION_TOKEN", req.Headers.Get("X-Amz-Security-Token")) + authHeader := req.Headers.Get("Authorization") + assert.True(t, strings.HasPrefix(authHeader, "AWS4-HMAC-SHA256 Credential=")) + assert.Contains(t, authHeader, "Signature=") + }) +} diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index ecacb54c5d..e98f3321e1 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -28,34 +28,30 @@ import ( "golang.org/x/text/language" ) -const keepAliveIntervalDefault = 5 * time.Second -const writeDeadlineDefault = 3 * time.Second const connectionTimeoutDefault = 5 * time.Second -// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. The default is 1MB. -// If a buffer size is set zero, then the transporter default size is used. 
The I/O buffer -// sizes do not limit the size of the messages that can be sent or received. -const readBufferSizeDefault = 1048576 -const writeBufferSizeDefault = 1048576 - // ClientSettings is used to modify a Client's settings on initialization. type ClientSettings struct { TraversalSource string LogVerbosity LogVerbosity Logger Logger Language language.Tag - AuthInfo AuthInfoProvider TlsConfig *tls.Config - KeepAliveInterval time.Duration - WriteDeadline time.Duration ConnectionTimeout time.Duration EnableCompression bool - ReadBufferSize int - WriteBufferSize int // Maximum number of concurrent connections. Default: number of runtime processors MaximumConcurrentConnections int EnableUserAgentOnConnect bool + + // RequestInterceptors are functions that modify HTTP requests before sending. + RequestInterceptors []RequestInterceptor +} + +// protocol defines the interface for HTTP communication with Gremlin server +type protocol interface { + send(request *request) (ResultSet, error) + close() } // Client is used to connect and interact with a Gremlin-supported server. @@ -64,7 +60,7 @@ type Client struct { traversalSource string logHandler *logHandler connectionSettings *connectionSettings - httpProtocol *httpProtocol + protocol protocol } // NewClient creates a Client and configures it with the given parameters. 
@@ -75,15 +71,10 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C LogVerbosity: Info, Logger: &defaultLogger{}, Language: language.English, - AuthInfo: &AuthInfo{}, TlsConfig: &tls.Config{}, - KeepAliveInterval: keepAliveIntervalDefault, - WriteDeadline: writeDeadlineDefault, ConnectionTimeout: connectionTimeoutDefault, EnableCompression: false, EnableUserAgentOnConnect: true, - ReadBufferSize: readBufferSizeDefault, - WriteBufferSize: writeBufferSizeDefault, MaximumConcurrentConnections: runtime.NumCPU(), } @@ -92,27 +83,27 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C } connSettings := &connectionSettings{ - authInfo: settings.AuthInfo, tlsConfig: settings.TlsConfig, - keepAliveInterval: settings.KeepAliveInterval, - writeDeadline: settings.WriteDeadline, connectionTimeout: settings.ConnectionTimeout, enableCompression: settings.EnableCompression, - readBufferSize: settings.ReadBufferSize, - writeBufferSize: settings.WriteBufferSize, enableUserAgentOnConnect: settings.EnableUserAgentOnConnect, } logHandler := newLogHandler(settings.Logger, settings.LogVerbosity, settings.Language) - httpProt := newHttpProtocol(logHandler, url, connSettings) + conn := newHttpConnection(logHandler, url, connSettings) + + // Add user-provided interceptors + for _, interceptor := range settings.RequestInterceptors { + conn.AddInterceptor(interceptor) + } client := &Client{ url: url, traversalSource: settings.TraversalSource, logHandler: logHandler, connectionSettings: connSettings, - httpProtocol: httpProt, + protocol: conn, } return client, nil @@ -121,8 +112,10 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C // Close closes the client via connection. // This is idempotent due to the underlying close() methods being idempotent as well. 
func (client *Client) Close() { - // TODO check what needs to be closed client.logHandler.logf(Info, closeClient, client.url) + if client.protocol != nil { + client.protocol.close() + } } func (client *Client) errorCallback() { @@ -133,10 +126,7 @@ func (client *Client) errorCallback() { func (client *Client) SubmitWithOptions(traversalString string, requestOptions RequestOptions) (ResultSet, error) { client.logHandler.logf(Debug, submitStartedString, traversalString) request := MakeStringRequest(traversalString, client.traversalSource, requestOptions) - - // TODO interceptors (ie. auth) - - rs, err := client.httpProtocol.send(&request) + rs, err := client.protocol.send(&request) return rs, err } @@ -152,10 +142,8 @@ func (client *Client) Submit(traversalString string, bindings ...map[string]inte } // submitGremlinLang submits GremlinLang to the server to execute and returns a ResultSet. -// TODO test and update when connection is set up func (client *Client) submitGremlinLang(gremlinLang *GremlinLang) (ResultSet, error) { client.logHandler.logf(Debug, submitStartedString, *gremlinLang) - // TODO placeholder requestOptionsBuilder := new(RequestOptionsBuilder) if len(gremlinLang.GetParameters()) > 0 { requestOptionsBuilder.SetBindings(gremlinLang.GetParameters()) @@ -165,7 +153,7 @@ func (client *Client) submitGremlinLang(gremlinLang *GremlinLang) (ResultSet, er } request := MakeStringRequest(gremlinLang.GetGremlin(), client.traversalSource, requestOptionsBuilder.Create()) - return client.httpProtocol.send(&request) + return client.protocol.send(&request) } func applyOptionsConfig(builder *RequestOptionsBuilder, config map[string]interface{}) *RequestOptionsBuilder { diff --git a/gremlin-go/driver/client_test.go b/gremlin-go/driver/client_test.go index 9e223dd193..fdfd1204da 100644 --- a/gremlin-go/driver/client_test.go +++ b/gremlin-go/driver/client_test.go @@ -30,7 +30,6 @@ func TestClient(t *testing.T) { // Integration test variables. 
testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) - testNoAuthAuthInfo := &AuthInfo{} testNoAuthTlsConfig := &tls.Config{} t.Run("Test client.SubmitWithOptions()", func(t *testing.T) { @@ -38,7 +37,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo }) assert.NoError(t, err) assert.NotNil(t, client) @@ -58,7 +56,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo }) assert.NoError(t, err) assert.NotNil(t, client) @@ -74,7 +71,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.NoError(t, err) @@ -97,7 +93,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.NoError(t, err) @@ -122,7 +117,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.NoError(t, err) @@ -147,7 +141,6 @@ func TestClient(t *testing.T) { client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerModernGraphAlias }) @@ -170,7 +163,6 @@ func TestClient(t *testing.T) { client, err := 
NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo settings.TraversalSource = testServerCrewGraphAlias }) diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index 9dbb1ec8ea..f04997eeb0 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -25,13 +25,8 @@ import ( ) type connectionSettings struct { - authInfo AuthInfoProvider tlsConfig *tls.Config - keepAliveInterval time.Duration - writeDeadline time.Duration connectionTimeout time.Duration enableCompression bool - readBufferSize int - writeBufferSize int enableUserAgentOnConnect bool } diff --git a/gremlin-go/driver/connection_test.go b/gremlin-go/driver/connection_test.go index d27e65d84c..8dee3d6567 100644 --- a/gremlin-go/driver/connection_test.go +++ b/gremlin-go/driver/connection_test.go @@ -29,6 +29,7 @@ import ( "strconv" "sync" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -55,15 +56,10 @@ var testNames = []string{"Lyndon", "Yang", "Simon", "Rithin", "Alexey", "Valenty func newDefaultConnectionSettings() *connectionSettings { return &connectionSettings{ - authInfo: &AuthInfo{}, tlsConfig: &tls.Config{}, - keepAliveInterval: keepAliveIntervalDefault, - writeDeadline: writeDeadlineDefault, connectionTimeout: connectionTimeoutDefault, enableCompression: false, enableUserAgentOnConnect: true, - readBufferSize: readBufferSizeDefault, - writeBufferSize: writeBufferSizeDefault, } } @@ -90,11 +86,10 @@ func addTestData(t *testing.T, g *GraphTraversalSource) { assert.Nil(t, <-promise) } -func getTestGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Config) *GraphTraversalSource { +func getTestGraph(t *testing.T, url string, tls *tls.Config) *GraphTraversalSource { remote, err := NewDriverRemoteConnection(url, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = tls - settings.AuthInfo = auth settings.TraversalSource = 
testServerGraphAlias }) assert.Nil(t, err) @@ -104,8 +99,8 @@ func getTestGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Conf return g } -func initializeGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Config) *GraphTraversalSource { - g := getTestGraph(t, url, auth, tls) +func initializeGraph(t *testing.T, url string, tls *tls.Config) *GraphTraversalSource { + g := getTestGraph(t, url, tls) // Drop the graph and check that it is empty. dropGraph(t, g) @@ -244,11 +239,6 @@ func getEnvOrDefaultBool(key string, defaultValue bool) bool { return defaultValue } -func getBasicAuthInfo() *AuthInfo { - return BasicAuthInfo(getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_USERNAME", "stephen"), - getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_PASSWORD", "password")) -} - func skipTestsIfNotEnabled(t *testing.T, testSuiteName string, testSuiteEnabled bool) { if !testSuiteEnabled { t.Skipf("Skipping %s because %s tests are not enabled.", t.Name(), testSuiteName) @@ -259,29 +249,62 @@ func TestConnection(t *testing.T) { // Integration test variables. testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) - testNoAuthAuthInfo := &AuthInfo{} testNoAuthTlsConfig := &tls.Config{} // No authentication integration test with graphs loaded and alias configured server testNoAuthWithAliasUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) testNoAuthWithAliasEnable := getEnvOrDefaultBool("RUN_INTEGRATION_WITH_ALIAS_TESTS", true) - testNoAuthWithAliasAuthInfo := &AuthInfo{} testNoAuthWithAliasTlsConfig := &tls.Config{} // Basic authentication integration test variables. 
testBasicAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_BASIC_AUTH_URL", basicAuthWithSsl) testBasicAuthEnable := getEnvOrDefaultBool("RUN_BASIC_AUTH_INTEGRATION_TESTS", false) - testBasicAuthAuthInfo := getBasicAuthInfo() + testBasicAuthUsername := getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_USERNAME", "stephen") + testBasicAuthPassword := getEnvOrDefaultString("GREMLIN_GO_BASIC_AUTH_PASSWORD", "password") testBasicAuthTlsConfig := &tls.Config{InsecureSkipVerify: true} + // this test is used to test the ws->http POC changes via manual execution with a local TP 4.0 gremlin server running on 8182 + t.Run("Test client.submit()", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + tlsConf := tls.Config{ + InsecureSkipVerify: true, + } + + client, err := NewClient(testNoAuthUrl, + //client, err := NewClient(noAuthSslUrl, + func(settings *ClientSettings) { + settings.TlsConfig = &tlsConf + settings.EnableCompression = true + settings.TraversalSource = testServerModernGraphAlias + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + // synchronous + for i := 0; i < 5; i++ { + submitCount(i, client, t) + } + + // async + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + submitCount(i, client, t) + }(i) + } + wg.Wait() + }) + t.Run("Test client.submit() with concurrency", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) client, err := NewClient(testNoAuthUrl, func(settings *ClientSettings) { settings.TlsConfig = testNoAuthTlsConfig - settings.AuthInfo = testNoAuthAuthInfo - settings.WriteBufferSize = 1024 settings.EnableCompression = true settings.TraversalSource = testServerModernGraphAlias }) @@ -310,7 +333,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, 
testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Read test data out of the graph and check that it is correct. @@ -325,7 +348,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() readWithNextAndHasNext(t, g) @@ -336,7 +359,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() resultSet, err := g.V().HasLabel(personLabel).Properties(nameKey).GetResultSet() @@ -360,7 +383,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := getTestGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := getTestGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Drop the graph. @@ -405,7 +428,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Read test data out of the graph and check that it is correct. 
@@ -421,7 +444,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() // Run traversal and test Next/HasNext calls @@ -447,7 +470,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() readUsingAnonymousTraversal(t, g) @@ -476,7 +499,9 @@ func TestConnection(t *testing.T) { remote, err := NewDriverRemoteConnection(testBasicAuthUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testBasicAuthTlsConfig - settings.AuthInfo = testBasicAuthAuthInfo + settings.RequestInterceptors = []RequestInterceptor{ + BasicAuth(testBasicAuthUsername, testBasicAuthPassword), + } }) assert.Nil(t, err) assert.NotNil(t, remote) @@ -502,7 +527,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() r, err := g.WithSack(1).V().Has("name", "Lyndon").Values("foo").Sack(Operator.Sum).Sack().ToList() @@ -521,7 +546,7 @@ func TestConnection(t *testing.T) { // skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // // // Initialize graph - // g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + // g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) // defer g.remoteConnection.Close() // // r, err := g.V().Has("name", "Lyndon").Values("foo").Profile().ToList() @@ -539,7 +564,7 @@ 
func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() prop := &BigDecimal{11, big.NewInt(int64(22))} @@ -559,7 +584,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) defer g.remoteConnection.Close() prop := &ByteBuffer{[]byte{byte(127), byte(255)}} @@ -579,7 +604,6 @@ func TestConnection(t *testing.T) { remote, err := NewDriverRemoteConnection(testNoAuthWithAliasUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testNoAuthWithAliasTlsConfig - settings.AuthInfo = testNoAuthWithAliasAuthInfo settings.TraversalSource = testServerModernGraphAlias }) assert.Nil(t, err) @@ -599,7 +623,7 @@ func TestConnection(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) // Initialize graph - g := initializeGraph(t, testNoAuthUrl, testNoAuthAuthInfo, testNoAuthTlsConfig) + g := initializeGraph(t, testNoAuthUrl, testNoAuthTlsConfig) // Drop the graph. dropGraph(t, g) @@ -607,8 +631,7 @@ func TestConnection(t *testing.T) { // Add vertices and edges to graph. 
rs, err := g.AddV("person").Property("id", T__.Unfold().Property().AddV()).ToList() assert.Nil(t, rs) - fmt.Println(err.Error()) - assert.True(t, isSameErrorCode(newError(err0502ResponseHandlerReadLoopError), err)) + assert.True(t, isSameErrorCode(newError(err0502ResponseHandlerError), err)) rs, err = g.V().Count().ToList() assert.NotNil(t, rs) @@ -621,7 +644,7 @@ func TestConnection(t *testing.T) { t.Run("Get all properties when materializeProperties is all", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() // vertex contains 2 properties, name and age @@ -634,7 +657,7 @@ func TestConnection(t *testing.T) { t.Run("Skip properties when materializeProperties is tokens", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() // vertex contains 2 properties, name and age @@ -647,7 +670,7 @@ func TestConnection(t *testing.T) { t.Run("Get all properties when no materializeProperties", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() r, err := g.V().Has("person", "name", "marko").Next() @@ -659,7 +682,7 @@ func TestConnection(t *testing.T) { t.Run("Test DriverRemoteConnection Traversal With materializeProperties in Modern Graph", func(t *testing.T) { skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() vertices, err := 
g.With("materializeProperties", MaterializeProperties.Tokens).V().ToList() @@ -760,3 +783,62 @@ func submitCount(i int, client *Client, t *testing.T) { assert.Equal(t, 6+i, c) _, _ = fmt.Fprintf(os.Stdout, "Received result : %s\n", result) } + +func TestStreamingResultDelivery(t *testing.T) { + testNoAuthWithAliasEnable := getEnvOrDefaultBool("RUN_INTEGRATION_WITH_ALIAS_TESTS", true) + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthWithAliasEnable) + remote, err := NewDriverRemoteConnection(getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl), + func(settings *DriverRemoteConnectionSettings) { + settings.TlsConfig = &tls.Config{} + settings.TraversalSource = "ggrateful" + }) + assert.Nil(t, err) + assert.NotNil(t, remote) + g := Traversal_().With(remote) + defer g.remoteConnection.Close() + + t.Run("first result arrives before all results", func(t *testing.T) { + start := time.Now() + rs, err := g.V().Properties().GetResultSet() + assert.Nil(t, err) + + // First result should arrive quickly + _, ok, err := rs.One() + firstResultTime := time.Since(start) + assert.Nil(t, err) + assert.True(t, ok) + + // Drain remaining + _, err = rs.All() + assert.Nil(t, err) + totalTime := time.Since(start) + + t.Logf("First result: %v, Total: %v, Ratio: %.2f%%", + firstResultTime, totalTime, float64(firstResultTime)/float64(totalTime)*100) + }) + + t.Run("results arrive incrementally", func(t *testing.T) { + rs, err := g.V().Properties().GetResultSet() + assert.Nil(t, err) + + var timestamps []time.Time + start := time.Now() + + for { + _, ok, err := rs.One() + assert.Nil(t, err) + if !ok { + break + } + timestamps = append(timestamps, time.Now()) + } + + if len(timestamps) < 2 { + t.Skip("need more results to test incremental delivery") + } + + firstHalf := timestamps[len(timestamps)/2].Sub(start) + total := timestamps[len(timestamps)-1].Sub(start) + t.Logf("Half results at: %v, All results at: %v", firstHalf, total) + }) +} diff --git 
a/gremlin-go/driver/cucumber/cucumberSteps_test.go b/gremlin-go/driver/cucumber/cucumberSteps_test.go index f9caa7e187..532f950823 100644 --- a/gremlin-go/driver/cucumber/cucumberSteps_test.go +++ b/gremlin-go/driver/cucumber/cucumberSteps_test.go @@ -430,7 +430,7 @@ func (tg *tinkerPopGraph) nothingShouldHappenBecause(arg1 *godog.DocString) erro func (tg *tinkerPopGraph) chooseGraph(graphName string) error { tg.graphName = graphName data := tg.graphDataMap[graphName] - tg.g = gremlingo.Traversal_().With(data.connection) + tg.g = gremlingo.Traversal_().With(data.connection).With("language", "gremlin-lang") if graphName == "empty" { err := tg.cleanEmptyDataGraph(tg.g) if err != nil { diff --git a/gremlin-go/driver/cucumber/cucumberWorld.go b/gremlin-go/driver/cucumber/cucumberWorld.go index ac92ea1db8..dcbec09ef4 100644 --- a/gremlin-go/driver/cucumber/cucumberWorld.go +++ b/gremlin-go/driver/cucumber/cucumberWorld.go @@ -109,7 +109,7 @@ func (t *CucumberWorld) loadAllDataGraph() { if err != nil { panic(fmt.Sprintf("Failed to create connection '%v'", err)) } - g := gremlingo.Traversal_().With(connection) + g := gremlingo.Traversal_().With(connection).With("language", "gremlin-lang") t.graphDataMap[name] = &DataGraph{ name: name, connection: connection, @@ -130,7 +130,7 @@ func (t *CucumberWorld) loadEmptyDataGraph() { func (t *CucumberWorld) reloadEmptyData() { graphData := t.getDataGraphFromMap("empty") - g := gremlingo.Traversal_().With(graphData.connection) + g := gremlingo.Traversal_().With(graphData.connection).With("language", "gremlin-lang") graphData.vertices = getVertices(g) graphData.edges = getEdges(g) } diff --git a/gremlin-go/driver/driverRemoteConnection.go b/gremlin-go/driver/driverRemoteConnection.go index e4ec91d7bf..93d0b42683 100644 --- a/gremlin-go/driver/driverRemoteConnection.go +++ b/gremlin-go/driver/driverRemoteConnection.go @@ -33,22 +33,16 @@ type DriverRemoteConnectionSettings struct { LogVerbosity LogVerbosity Logger Logger Language 
language.Tag - AuthInfo AuthInfoProvider TlsConfig *tls.Config - KeepAliveInterval time.Duration - WriteDeadline time.Duration ConnectionTimeout time.Duration EnableCompression bool EnableUserAgentOnConnect bool - ReadBufferSize int - WriteBufferSize int - // Minimum amount of concurrent active traversals on a connection to trigger creation of a new connection - NewConnectionThreshold int // Maximum number of concurrent connections. Default: number of runtime processors MaximumConcurrentConnections int - // Initial amount of instantiated connections. Default: 1 - InitialConcurrentConnections int + + // RequestInterceptors are functions that modify HTTP requests before sending. + RequestInterceptors []RequestInterceptor } // DriverRemoteConnection is a remote connection. @@ -70,15 +64,10 @@ func NewDriverRemoteConnection( LogVerbosity: Info, Logger: &defaultLogger{}, Language: language.English, - AuthInfo: &AuthInfo{}, TlsConfig: &tls.Config{}, - KeepAliveInterval: keepAliveIntervalDefault, - WriteDeadline: writeDeadlineDefault, ConnectionTimeout: connectionTimeoutDefault, EnableCompression: false, EnableUserAgentOnConnect: true, - ReadBufferSize: readBufferSizeDefault, - WriteBufferSize: writeBufferSizeDefault, MaximumConcurrentConnections: runtime.NumCPU(), } @@ -87,27 +76,27 @@ func NewDriverRemoteConnection( } connSettings := &connectionSettings{ - authInfo: settings.AuthInfo, tlsConfig: settings.TlsConfig, - keepAliveInterval: settings.KeepAliveInterval, - writeDeadline: settings.WriteDeadline, connectionTimeout: settings.ConnectionTimeout, enableCompression: settings.EnableCompression, - readBufferSize: settings.ReadBufferSize, - writeBufferSize: settings.WriteBufferSize, enableUserAgentOnConnect: settings.EnableUserAgentOnConnect, } logHandler := newLogHandler(settings.Logger, settings.LogVerbosity, settings.Language) - httpProt := newHttpProtocol(logHandler, url, connSettings) + conn := newHttpConnection(logHandler, url, connSettings) + + // Add 
user-provided interceptors + for _, interceptor := range settings.RequestInterceptors { + conn.AddInterceptor(interceptor) + } client := &Client{ url: url, traversalSource: settings.TraversalSource, logHandler: logHandler, connectionSettings: connSettings, - httpProtocol: httpProt, + protocol: conn, } return &DriverRemoteConnection{client: client, isClosed: false, settings: settings}, nil @@ -136,7 +125,6 @@ func (driver *DriverRemoteConnection) Submit(traversalString string) (ResultSet, } // submitGremlinLang sends a GremlinLang traversal to the server. -// TODO test and update when connection is set up func (driver *DriverRemoteConnection) submitGremlinLang(gremlinLang *GremlinLang) (ResultSet, error) { if driver.isClosed { return nil, newError(err0203SubmitGremlinLangToClosedConnectionError) diff --git a/gremlin-go/driver/driverRemoteConnection_test.go b/gremlin-go/driver/driverRemoteConnection_test.go index d702d1809e..bdae0c0aa8 100644 --- a/gremlin-go/driver/driverRemoteConnection_test.go +++ b/gremlin-go/driver/driverRemoteConnection_test.go @@ -20,7 +20,6 @@ under the License. 
package gremlingo import ( - "net/http" "testing" "github.com/stretchr/testify/assert" @@ -28,20 +27,11 @@ import ( func TestAuthentication(t *testing.T) { - t.Run("Test BasicAuthInfo.", func(t *testing.T) { - header := BasicAuthInfo("Lyndon", "Bauto") - assert.Nil(t, header.GetHeader()) - b, _, _ := header.GetBasicAuth() - assert.True(t, b) - }) - - t.Run("Test GetHeader.", func(t *testing.T) { - header := &AuthInfo{} - assert.Nil(t, header.GetHeader()) - header = nil - assert.Nil(t, header.GetHeader()) - httpHeader := http.Header{} - header = &AuthInfo{Header: httpHeader} - assert.Equal(t, httpHeader, header.GetHeader()) + t.Run("Test BasicAuth interceptor", func(t *testing.T) { + interceptor := BasicAuth("user", "pass") + req, _ := NewHttpRequest("POST", "http://localhost:8182/gremlin") + err := interceptor(req) + assert.Nil(t, err) + assert.Contains(t, req.Headers.Get(HeaderAuthorization), "Basic ") }) } diff --git a/gremlin-go/driver/graphBinary.go b/gremlin-go/driver/graphBinary.go index c553b0a2e2..c28aa633dc 100644 --- a/gremlin-go/driver/graphBinary.go +++ b/gremlin-go/driver/graphBinary.go @@ -22,6 +22,7 @@ package gremlingo import ( "bytes" "encoding/binary" + "errors" "fmt" "math" "math/big" @@ -31,6 +32,8 @@ import ( "github.com/google/uuid" ) +var ErrIncompleteData = errors.New("incomplete data") + // Version 1.0 // dataType graphBinary types. 
@@ -613,12 +616,12 @@ func (serializer *graphBinaryTypeSerializer) writeValueFlagNone(buffer *bytes.Bu // readers -func readTemp(data *[]byte, i *int, len int) *[]byte { - tmp := make([]byte, len) - for j := 0; j < len; j++ { - tmp[j] = (*data)[j+*i] +func readTemp(data *[]byte, i *int, length int) *[]byte { + if *i+length > len(*data) { + panic(ErrIncompleteData) } - *i += len + tmp := (*data)[*i : *i+length] + *i += length return &tmp } @@ -655,21 +658,30 @@ func readLong(data *[]byte, i *int) (interface{}, error) { } func readBigInt(data *[]byte, i *int) (interface{}, error) { - sz := readIntSafe(data, i) - b := readTemp(data, i, int(sz)) + // Length lookahead - check if complete BigInt is available + if *i+4 > len(*data) { + panic(ErrIncompleteData) + } + sz := int(binary.BigEndian.Uint32((*data)[*i : *i+4])) + if sz > 0 && *i+4+sz > len(*data) { + panic(ErrIncompleteData) + } + *i += 4 + if sz == 0 { + return big.NewInt(0), nil + } + b := (*data)[*i : *i+sz] + *i += sz - var newBigInt = big.NewInt(0).SetBytes(*b) + var newBigInt = big.NewInt(0).SetBytes(b) var one = big.NewInt(1) - if len(*b) == 0 { - return newBigInt, nil - } // If the first bit in the first element of the byte array is a 1, we need to interpret the byte array as a two's complement representation - if (*b)[0]&0x80 == 0x00 { - newBigInt.SetBytes(*b) + if b[0]&0x80 == 0x00 { + newBigInt.SetBytes(b) return newBigInt, nil } // Undo two's complement to byte array and set negative boolean to true - length := uint((len(*b)*8)/8+1) * 8 + length := uint((len(b)*8)/8+1) * 8 b2 := new(big.Int).Sub(newBigInt, new(big.Int).Lsh(one, length)).Bytes() // Strip the resulting 0xff byte at the start of array @@ -709,12 +721,21 @@ func readDouble(data *[]byte, i *int) (interface{}, error) { } func readString(data *[]byte, i *int) (interface{}, error) { - sz := int(readUint32Safe(data, i)) + // Length lookahead - check if complete string is available + if *i+4 > len(*data) { + panic(ErrIncompleteData) + } + sz := 
int(binary.BigEndian.Uint32((*data)[*i : *i+4])) + if sz > 0 && *i+4+sz > len(*data) { + panic(ErrIncompleteData) // Don't advance index - wait for more data + } + *i += 4 // Now safe to advance past length if sz == 0 { return "", nil } + result := string((*data)[*i : *i+sz]) *i += sz - return string((*data)[*i-sz : *i]), nil + return result, nil } func readDataType(data *[]byte, i *int) dataType { @@ -776,12 +797,19 @@ func readList(data *[]byte, i *int, flag byte) (interface{}, error) { } func readByteBuffer(data *[]byte, i *int) (interface{}, error) { + // Length lookahead - check if complete ByteBuffer is available + if *i+4 > len(*data) { + panic(ErrIncompleteData) + } + sz := int(binary.BigEndian.Uint32((*data)[*i : *i+4])) + if sz > 0 && *i+4+sz > len(*data) { + panic(ErrIncompleteData) + } + *i += 4 r := &ByteBuffer{} - sz := readIntSafe(data, i) r.Data = make([]byte, sz) - for j := int32(0); j < sz; j++ { - r.Data[j] = readByteSafe(data, i) - } + copy(r.Data, (*data)[*i:*i+sz]) + *i += sz return r, nil } @@ -846,7 +874,12 @@ func readSet(data *[]byte, i *int, flag byte) (interface{}, error) { } func readUuid(data *[]byte, i *int) (interface{}, error) { - id, _ := uuid.FromBytes(*readTemp(data, i, 16)) + // Bounds check - UUID is fixed 16 bytes + if *i+16 > len(*data) { + panic(ErrIncompleteData) + } + id, _ := uuid.FromBytes((*data)[*i : *i+16]) + *i += 16 return id, nil } diff --git a/gremlin-go/driver/gremlinlang.go b/gremlin-go/driver/gremlinlang.go index 498cda8206..cd2b724487 100644 --- a/gremlin-go/driver/gremlinlang.go +++ b/gremlin-go/driver/gremlinlang.go @@ -567,17 +567,3 @@ func (gl *GremlinLang) convertArgument(arg interface{}) (interface{}, error) { } } } - -// TODO revisit and remove if necessary -//var withOptionsMap map[any]string = map[any]string{ -// WithOptions.Tokens: "WithOptions.tokens", -// WithOptions.None: "WithOptions.none", -// WithOptions.Ids: "WithOptions.ids", -// WithOptions.Labels: "WithOptions.labels", -// WithOptions.Keys: 
"WithOptions.keys", -// WithOptions.Values: "WithOptions.values", -// WithOptions.All: "WithOptions.all", -// WithOptions.Indexer: "WithOptions.indexer", -// WithOptions.List: "WithOptions.list", -// WithOptions.Map: "WithOptions.map", -//} diff --git a/gremlin-go/driver/httpConnection.go b/gremlin-go/driver/httpConnection.go new file mode 100644 index 0000000000..4f034e0e65 --- /dev/null +++ b/gremlin-go/driver/httpConnection.go @@ -0,0 +1,277 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "bytes" + "compress/zlib" + "crypto/sha256" + "encoding/hex" + "io" + "net" + "net/http" + "net/url" + "time" +) + +// Common HTTP header keys +const ( + HeaderContentType = "Content-Type" + HeaderAccept = "Accept" + HeaderUserAgent = "User-Agent" + HeaderAcceptEncoding = "Accept-Encoding" + HeaderAuthorization = "Authorization" +) + +// HttpRequest represents an HTTP request that can be modified by interceptors. +type HttpRequest struct { + Method string + URL *url.URL + Headers http.Header + Body []byte +} + +// NewHttpRequest creates a new HttpRequest with the given method and URL. 
+func NewHttpRequest(method, rawURL string) (*HttpRequest, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + return &HttpRequest{ + Method: method, + URL: u, + Headers: make(http.Header), + }, nil +} + +// ToStdRequest converts HttpRequest to a standard http.Request for signing. +// Returns nil if the request cannot be created (invalid method or URL). +func (r *HttpRequest) ToStdRequest() (*http.Request, error) { + req, err := http.NewRequest(r.Method, r.URL.String(), bytes.NewReader(r.Body)) + if err != nil { + return nil, err + } + req.Header = r.Headers + return req, nil +} + +// PayloadHash returns the SHA256 hash of the request body for SigV4 signing. +func (r *HttpRequest) PayloadHash() string { + if len(r.Body) == 0 { + return "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" // SHA256 of empty string + } + h := sha256.Sum256(r.Body) + return hex.EncodeToString(h[:]) +} + +// RequestInterceptor is a function that modifies an HTTP request before it is sent. 
+type RequestInterceptor func(*HttpRequest) error + +// httpConnection handles HTTP request/response for Gremlin queries +type httpConnection struct { + url string + httpClient *http.Client + connSettings *connectionSettings + logHandler *logHandler + serializer *GraphBinarySerializer + interceptors []RequestInterceptor +} + +// Connection pool defaults aligned with Java driver +const ( + defaultMaxConnsPerHost = 128 // Java: ConnectionPool.MAX_POOL_SIZE + defaultMaxIdleConnsPerHost = 8 // Keep some connections warm + defaultIdleConnTimeout = 180 * time.Second // Java: CONNECTION_IDLE_TIMEOUT_MILLIS + defaultConnectionTimeout = 15 * time.Second // Java: CONNECTION_SETUP_TIMEOUT_MILLIS +) + +func newHttpConnection(handler *logHandler, url string, connSettings *connectionSettings) *httpConnection { + connectionTimeout := connSettings.connectionTimeout + if connectionTimeout == 0 { + connectionTimeout = defaultConnectionTimeout + } + + transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: connectionTimeout, // Connection setup timeout only + KeepAlive: 30 * time.Second, + }).DialContext, + TLSClientConfig: connSettings.tlsConfig, + MaxConnsPerHost: defaultMaxConnsPerHost, + MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost, + IdleConnTimeout: defaultIdleConnTimeout, + DisableCompression: !connSettings.enableCompression, + } + + return &httpConnection{ + url: url, + httpClient: &http.Client{Transport: transport}, // No Timeout - allows streaming + connSettings: connSettings, + logHandler: handler, + serializer: newGraphBinarySerializer(handler), + } +} + +// AddInterceptor adds a request interceptor to the chain. 
+func (c *httpConnection) AddInterceptor(interceptor RequestInterceptor) { + c.interceptors = append(c.interceptors, interceptor) +} + +// send sends request and streams results directly to ResultSet +func (c *httpConnection) send(req *request) (ResultSet, error) { + rs := newChannelResultSet() + + data, err := c.serializer.SerializeMessage(req) + if err != nil { + rs.Close() + return rs, err + } + + go c.executeAndStream(data, rs) + + return rs, nil +} + +func (c *httpConnection) executeAndStream(data []byte, rs ResultSet) { + defer rs.Close() + + // Create HttpRequest for interceptors + httpReq, err := NewHttpRequest(http.MethodPost, c.url) + if err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + httpReq.Body = data + + // Set default headers before interceptors + c.setHttpRequestHeaders(httpReq) + + // Apply interceptors + for _, interceptor := range c.interceptors { + if err := interceptor(httpReq); err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + } + + // Create actual http.Request from HttpRequest + req, err := http.NewRequest(httpReq.Method, httpReq.URL.String(), bytes.NewReader(httpReq.Body)) + if err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + req.Header = httpReq.Headers + + resp, err := c.httpClient.Do(req) + if err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + rs.setError(err) + return + } + defer func() { + if err := resp.Body.Close(); err != nil { + c.logHandler.logf(Debug, failedToCloseResponseBody, err.Error()) + } + }() + + reader, zlibReader, err := c.getReader(resp) + if err != nil { + c.logHandler.logf(Error, failedToReceiveResponse, err.Error()) + rs.setError(err) + return + } + if zlibReader != nil { + defer func() { + if err := zlibReader.Close(); err != nil { + c.logHandler.logf(Debug, failedToCloseDecompReader, err.Error()) + } + }() + } + + 
c.streamToResultSet(reader, rs) +} + +// setHttpRequestHeaders sets default headers on HttpRequest (for interceptors) +func (c *httpConnection) setHttpRequestHeaders(req *HttpRequest) { + req.Headers.Set(HeaderContentType, graphBinaryMimeType) + req.Headers.Set(HeaderAccept, graphBinaryMimeType) + + if c.connSettings.enableUserAgentOnConnect { + req.Headers.Set(HeaderUserAgent, userAgent) + } + if c.connSettings.enableCompression { + req.Headers.Set(HeaderAcceptEncoding, "deflate") + } +} + +func (c *httpConnection) getReader(resp *http.Response) (io.Reader, io.Closer, error) { + if resp.Header.Get("Content-Encoding") == "deflate" { + zr, err := zlib.NewReader(resp.Body) + if err != nil { + return nil, nil, err + } + return zr, zr, nil + } + return resp.Body, nil, nil +} + +func (c *httpConnection) streamToResultSet(reader io.Reader, rs ResultSet) { + d := NewStreamingDeserializer(reader) + if err := d.ReadHeader(); err != nil { + if err != io.EOF { + c.logHandler.logf(Error, failedToReceiveResponse, err.Error()) + rs.setError(err) + } + return + } + + for { + obj, err := d.ReadFullyQualified() + if err != nil { + if err != io.EOF { + c.logHandler.logf(Error, failedToReceiveResponse, err.Error()) + rs.setError(err) + } + return + } + + if marker, ok := obj.(Marker); ok && marker == EndOfStream() { + code, msg, _, err := d.ReadStatus() + if err != nil { + c.logHandler.logf(Error, failedToReceiveResponse, err.Error()) + rs.setError(err) + return + } + if code != 200 && code != 0 { + rs.setError(newError(err0502ResponseHandlerReadLoopError, msg, code)) + } + return + } + + rs.Channel() <- &Result{obj} + } +} + +func (c *httpConnection) close() { + c.httpClient.CloseIdleConnections() +} diff --git a/gremlin-go/driver/httpConnection_test.go b/gremlin-go/driver/httpConnection_test.go new file mode 100644 index 0000000000..76d9146936 --- /dev/null +++ b/gremlin-go/driver/httpConnection_test.go @@ -0,0 +1,178 @@ +/* +Licensed to the Apache Software Foundation (ASF) under 
one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +*/ + +package gremlingo + +import ( + "bytes" + "crypto/tls" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" +) + +func newTestLogHandler() *logHandler { + return newLogHandler(&defaultLogger{}, Warning, language.English) +} + +func TestNewHttpConnection(t *testing.T) { + t.Run("creates connection with default settings", func(t *testing.T) { + conn := newHttpConnection(newTestLogHandler(), "http://localhost:8182/gremlin", &connectionSettings{}) + + assert.NotNil(t, conn.httpClient) + assert.NotNil(t, conn.httpClient.Transport) + }) + + t.Run("applies TLS config", func(t *testing.T) { + tlsConfig := &tls.Config{InsecureSkipVerify: true} + conn := newHttpConnection(newTestLogHandler(), "https://localhost:8182/gremlin", &connectionSettings{ + tlsConfig: tlsConfig, + }) + + transport := conn.httpClient.Transport.(*http.Transport) + assert.Equal(t, tlsConfig, transport.TLSClientConfig) + }) +} + +func TestSetHttpRequestHeaders(t *testing.T) { + t.Run("sets content type and accept headers", func(t *testing.T) { + conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{}) + req, _ := 
NewHttpRequest(http.MethodPost, "http://localhost/gremlin") + + conn.setHttpRequestHeaders(req) + + assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Content-Type")) + assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Accept")) + }) + + t.Run("sets user agent when enabled", func(t *testing.T) { + conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{ + enableUserAgentOnConnect: true, + }) + req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") + + conn.setHttpRequestHeaders(req) + + assert.NotEmpty(t, req.Headers.Get(HeaderUserAgent)) + }) + + t.Run("sets compression header when enabled", func(t *testing.T) { + conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{ + enableCompression: true, + }) + req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") + + conn.setHttpRequestHeaders(req) + + assert.Equal(t, "deflate", req.Headers.Get("Accept-Encoding")) + }) +} + +func TestGetReader(t *testing.T) { + conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{}) + + t.Run("returns body for non-compressed response", func(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + Body: http.NoBody, + } + + reader, closer, err := conn.getReader(resp) + + assert.NoError(t, err) + assert.Nil(t, closer) + assert.Equal(t, resp.Body, reader) + }) + + t.Run("returns zlib reader for deflate response", func(t *testing.T) { + // Valid zlib compressed empty data + zlibData := []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01} + resp := &http.Response{ + Header: http.Header{"Content-Encoding": []string{"deflate"}}, + Body: io.NopCloser(bytes.NewReader(zlibData)), + } + + reader, closer, err := conn.getReader(resp) + + assert.NoError(t, err) + assert.NotNil(t, closer) + assert.NotNil(t, reader) + require.NoError(t, closer.Close()) + }) + + t.Run("returns error for invalid zlib data", func(t *testing.T) { + resp 
:= &http.Response{ + Header: http.Header{"Content-Encoding": []string{"deflate"}}, + Body: io.NopCloser(bytes.NewReader([]byte{0x00, 0x00})), + } + + _, _, err := conn.getReader(resp) + + assert.Error(t, err) + }) +} + +func TestHttpConnectionWithMockServer(t *testing.T) { + t.Run("handles connection error", func(t *testing.T) { + conn := newHttpConnection(newTestLogHandler(), "http://localhost:99999/gremlin", &connectionSettings{ + connectionTimeout: 100 * time.Millisecond, + }) + + rs, err := conn.send(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + assert.NoError(t, err) // send returns nil, error goes to ResultSet + + // All() blocks until stream closes, then we can check error + _, _ = rs.All() + assert.Error(t, rs.GetError()) + }) + + t.Run("receives headers from request", func(t *testing.T) { + headersCh := make(chan http.Header, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headersCh <- r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newHttpConnection(newTestLogHandler(), server.URL, &connectionSettings{ + enableUserAgentOnConnect: true, + enableCompression: true, + }) + + rs, err := conn.send(&request{gremlin: "g.V()", fields: map[string]interface{}{}}) + require.NoError(t, err) + + select { + case receivedHeaders := <-headersCh: + assert.Equal(t, graphBinaryMimeType, receivedHeaders.Get("Content-Type")) + assert.Equal(t, "deflate", receivedHeaders.Get("Accept-Encoding")) + assert.NotEmpty(t, receivedHeaders.Get(userAgentHeader)) + case <-time.After(time.Second): + t.Fatal("timeout waiting for request") + } + + _, _ = rs.All() // drain + }) +} diff --git a/gremlin-go/driver/httpProtocol.go b/gremlin-go/driver/httpProtocol.go deleted file mode 100644 index 9c9bba4a91..0000000000 --- a/gremlin-go/driver/httpProtocol.go +++ /dev/null @@ -1,157 +0,0 @@ -/* -Licensed to the Apache Software Foundation (ASF) under one -or more contributor license 
agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. -*/ - -package gremlingo - -import ( - "fmt" - "net/http" -) - -// responsible for serializing and sending requests and then receiving and deserializing responses -type httpProtocol struct { - serializer Serializer - logHandler *logHandler - url string - connSettings *connectionSettings - httpClient *http.Client -} - -func newHttpProtocol(handler *logHandler, url string, connSettings *connectionSettings) *httpProtocol { - transport := &http.Transport{ - TLSClientConfig: connSettings.tlsConfig, - MaxConnsPerHost: 0, // TODO - IdleConnTimeout: 0, // TODO - DisableCompression: !connSettings.enableCompression, - } - - httpClient := http.Client{ - Transport: transport, - Timeout: connSettings.connectionTimeout, - } - - httpProt := &httpProtocol{ - serializer: newGraphBinarySerializer(handler), - logHandler: handler, - url: url, - connSettings: connSettings, - httpClient: &httpClient, - } - return httpProt -} - -// sends a query request and returns a ResultSet that can be used to obtain query results -func (protocol *httpProtocol) send(request *request) (ResultSet, error) { - rs := newChannelResultSet() - fmt.Println("Serializing request") - bytes, err := protocol.serializer.SerializeMessage(request) - if err != nil { - rs.setError(err) - rs.Close() - return rs, err - } - - // one transport per request - 
transport := NewHttpTransporter(protocol.url, protocol.connSettings, protocol.httpClient, protocol.logHandler) - - // async send request - transport.wg.Add(1) - go func() { - defer transport.wg.Done() - err := transport.Write(bytes) - if err != nil { - rs.setError(err) - rs.Close() - } - }() - - // async receive response - transport.wg.Add(1) - go func() { - defer transport.wg.Done() - msg, err := transport.Read() - if err != nil { - rs.setError(err) - rs.Close() - } else { - err = protocol.receive(rs, msg) - } - transport.Close() - }() - - // Wait for both async operations to complete - transport.wg.Wait() - - return rs, rs.GetError() -} - -// receives a binary response message, deserializes, and adds results to the ResultSet -func (protocol *httpProtocol) receive(rs ResultSet, msg []byte) error { - fmt.Println("Deserializing response") - resp, err := protocol.serializer.DeserializeMessage(msg) - if err != nil { - protocol.logHandler.logf(Error, logErrorGeneric, "receive()", err.Error()) - rs.Close() - return err - } - - fmt.Println("Handling response") - err = protocol.handleResponse(rs, resp) - if err != nil { - protocol.logHandler.logf(Error, logErrorGeneric, "receive()", err.Error()) - rs.Close() - return err - } - return nil -} - -// processes a deserialized response and attempts to add results to the ResultSet -func (protocol *httpProtocol) handleResponse(rs ResultSet, response Response) error { - fmt.Println("Handling response") - - statusCode, data := response.ResponseStatus.code, response.ResponseResult.Data - if rs == nil { - return newError(err0501ResponseHandlerResultSetNotCreatedError) - } - - if statusCode == http.StatusNoContent { - rs.addResult(&Result{make([]interface{}, 0)}) - rs.Close() - protocol.logHandler.logf(Debug, readComplete) - } else if statusCode == http.StatusOK { - rs.addResult(&Result{data}) - rs.Close() - protocol.logHandler.logf(Debug, readComplete) - } else if statusCode == http.StatusUnauthorized || statusCode == 
http.StatusForbidden { - rs.Close() - err := newError(err0503ResponseHandlerAuthError, response.ResponseStatus, response.ResponseResult) - rs.setError(err) - return err - } else { - rs.Close() - err := newError(err0502ResponseHandlerReadLoopError, response.ResponseStatus, statusCode) - rs.setError(err) - return err - } - return nil -} - -func (protocol *httpProtocol) close() { - protocol.httpClient.CloseIdleConnections() -} diff --git a/gremlin-go/driver/httpTransporter.go b/gremlin-go/driver/httpTransporter.go deleted file mode 100644 index 94059a19c1..0000000000 --- a/gremlin-go/driver/httpTransporter.go +++ /dev/null @@ -1,145 +0,0 @@ -/* -Licensed to the Apache Software Foundation (ASF) under one -or more contributor license agreements. See the NOTICE file -distributed with this work for additional information -regarding copyright ownership. The ASF licenses this file -to you under the Apache License, Version 2.0 (the -"License"); you may not use this file except in compliance -with the License. You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, -software distributed under the License is distributed on an -"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -KIND, either express or implied. See the License for the -specific language governing permissions and limitations -under the License. 
-*/ - -package gremlingo - -import ( - "bytes" - "compress/zlib" - "errors" - "fmt" - "io" - "net/http" - "sync" -) - -// TODO decide channel size when chunked response handling is implemented - for now just set to 1 -const responseChannelSizeDefault = 1 - -// HttpTransporter responsible for sending and receiving bytes to/from the server -type HttpTransporter struct { - url string - isClosed bool - connSettings *connectionSettings - responseChannel chan []byte // receives response bytes from the server - httpClient *http.Client - wg *sync.WaitGroup - logHandler *logHandler -} - -func NewHttpTransporter(url string, connSettings *connectionSettings, httpClient *http.Client, logHandler *logHandler) *HttpTransporter { - wg := &sync.WaitGroup{} - - return &HttpTransporter{ - url: url, - connSettings: connSettings, - responseChannel: make(chan []byte, responseChannelSizeDefault), - httpClient: httpClient, - wg: wg, - logHandler: logHandler, - } -} - -// Write sends bytes to the server as a POST request and sends received response bytes to the responseChannel -func (transporter *HttpTransporter) Write(data []byte) error { - req, err := http.NewRequest("POST", transporter.url, bytes.NewBuffer(data)) - if err != nil { - transporter.logHandler.logf(Error, failedToSendRequest, err.Error()) - return err - } - req.Header.Set("content-type", graphBinaryMimeType) - req.Header.Set("accept", graphBinaryMimeType) - if transporter.connSettings.enableUserAgentOnConnect { - req.Header.Set(userAgentHeader, userAgent) - } - if transporter.connSettings.enableCompression { - req.Header.Set("accept-encoding", "deflate") - } - if transporter.connSettings.authInfo != nil { - // Add custom headers - if headers := transporter.connSettings.authInfo.GetHeader(); headers != nil { - for key, values := range headers { - for _, value := range values { - req.Header.Add(key, value) - } - } - } - - // Add basic auth - if ok, username, password := transporter.connSettings.authInfo.GetBasicAuth(); ok { - 
req.SetBasicAuth(username, password) - } - } - - fmt.Println("Sending request") - resp, err := transporter.httpClient.Do(req) - if err != nil { - transporter.logHandler.logf(Error, failedToSendRequest, err.Error()) - return err - } - - reader := resp.Body - if resp.Header.Get("content-encoding") == "deflate" { - reader, err = zlib.NewReader(resp.Body) - if err != nil { - transporter.logHandler.logf(Error, failedToReceiveResponse, err.Error()) - return err - } - } - - // TODO handle chunked encoding and send chunks to responseChannel - all, err := io.ReadAll(reader) - if err != nil { - transporter.logHandler.logf(Error, failedToReceiveResponse, err.Error()) - return err - } - err = reader.Close() - if err != nil { - return err - } - - // TODO for debug, remove later, and check response handling - //str := hex.EncodeToString(all) - //_, _ = fmt.Fprintf(os.Stdout, "Received response data : %s\n", str) - - fmt.Println("Sending response to responseChannel") - transporter.responseChannel <- all - return nil -} - -// Read reads bytes from the responseChannel -func (transporter *HttpTransporter) Read() ([]byte, error) { - fmt.Println("Reading from responseChannel") - msg, ok := <-transporter.responseChannel - if !ok { - return []byte{}, errors.New("failed to read from response channel") - } - return msg, nil -} - -// Close closes the transporter and its corresponding responseChannel -func (transporter *HttpTransporter) Close() { - fmt.Println("Closing http transporter") - if !transporter.isClosed { - if transporter.responseChannel != nil { - close(transporter.responseChannel) - } - transporter.isClosed = true - } -} diff --git a/gremlin-go/driver/logger.go b/gremlin-go/driver/logger.go index f1f8c6986a..388ce44a2c 100644 --- a/gremlin-go/driver/logger.go +++ b/gremlin-go/driver/logger.go @@ -109,10 +109,11 @@ const ( submitStartedString errorKey = "SUBMIT_STARTED_STRING" failedToCloseInErrorCallback errorKey = "FAILED_TO_CLOSE_IN_ERROR_CALLBACK" failedToWriteMessage 
errorKey = "FAILED_TO_WRITE_MESSAGE" - failedToSetWriteDeadline errorKey = "FAILED_TO_SET_WRITE_DEADLINE" failedToReceiveResponse errorKey = "FAILED_TO_RECEIVE_RESPONSE" failedToSendRequest errorKey = "FAILED_TO_SEND_REQUEST" logErrorGeneric errorKey = "LOG_ERROR_GENERIC" closeDriverRemoteConnection errorKey = "CLOSE_DRIVER_REMOTE_CONNECTION" closeClient errorKey = "CLOSE_CLIENT" + failedToCloseResponseBody errorKey = "FAILED_TO_CLOSE_RESPONSE_BODY" + failedToCloseDecompReader errorKey = "FAILED_TO_CLOSE_DECOMPRESSION_READER" ) diff --git a/gremlin-go/driver/performance/performanceSuite.go b/gremlin-go/driver/performance/performanceSuite.go index ba2b493b91..bcad844f67 100644 --- a/gremlin-go/driver/performance/performanceSuite.go +++ b/gremlin-go/driver/performance/performanceSuite.go @@ -378,10 +378,7 @@ func createConnection(host string, port, poolSize, buffersSize int) (*GraphTrave drc, err = gremlingo.NewDriverRemoteConnection(endpoint, func(settings *DriverRemoteConnectionSettings) { settings.LogVerbosity = GremlinWarning settings.TraversalSource = gratefulGraphAlias - settings.NewConnectionThreshold = threshold settings.MaximumConcurrentConnections = poolSize - settings.WriteBufferSize = buffersSize - settings.ReadBufferSize = buffersSize }) if err != nil { diff --git a/gremlin-go/driver/resources/error-messages/en.json b/gremlin-go/driver/resources/error-messages/en.json index cfd070af47..0120b5b77f 100644 --- a/gremlin-go/driver/resources/error-messages/en.json +++ b/gremlin-go/driver/resources/error-messages/en.json @@ -23,6 +23,7 @@ "E0501_PROTOCOL_RESPONSEHANDLER_NO_RESULTSET_ON_DATA_RECEIVE":"E0501: resultSet was not created before data was received", "E0502_PROTOCOL_RESPONSEHANDLER_ERROR": "E0502: error handling response, error message '%+v'. statusCode: %d", + "E0502_PROTOCOL_RESPONSEHANDLER_READ_LOOP_ERROR": "E0502: error in read loop, error message '%v'. 
statusCode: %d", "E0503_PROTOCOL_RESPONSEHANDLER_AUTH_ERROR":"E0503: failed to authenticate %v : %v", "E0601_RESULT_NOT_VERTEX_ERROR":"E0601: result is not a Vertex", diff --git a/gremlin-go/driver/resources/logger-messages/en.json b/gremlin-go/driver/resources/logger-messages/en.json index 82bdce46a3..f7025e837a 100644 --- a/gremlin-go/driver/resources/logger-messages/en.json +++ b/gremlin-go/driver/resources/logger-messages/en.json @@ -20,5 +20,7 @@ "POOL_NEW_CONNECTION_ERROR": "Falling back to least-used connection. Creating new connection due to least-used connection exceeding concurrent usage threshold failed: %s", "POOL_INITIAL_EXCEEDS_MAXIMUM": "InitialConcurrentConnections setting %d exceeded MaximumConcurrentConnections setting %d - limiting InitialConcurrentConnections to %d.", "FAILED_TO_RECEIVE_RESPONSE": "Failed to receive response: %s", - "FAILED_TO_SEND_REQUEST": "Failed to send request: %s" + "FAILED_TO_SEND_REQUEST": "Failed to send request: %s", + "FAILED_TO_CLOSE_RESPONSE_BODY": "Error closing response body: %s", + "FAILED_TO_CLOSE_DECOMPRESSION_READER": "Error closing decompression reader: %s" } diff --git a/gremlin-go/driver/resultSet.go b/gremlin-go/driver/resultSet.go index 59e22e66d1..cf1230cd53 100644 --- a/gremlin-go/driver/resultSet.go +++ b/gremlin-go/driver/resultSet.go @@ -152,8 +152,8 @@ func (channelResultSet *channelResultSet) All() ([]*Result, error) { func (channelResultSet *channelResultSet) addResult(r *Result) { channelResultSet.channelMutex.Lock() - if r.GetType().Kind() == reflect.Array || r.GetType().Kind() == reflect.Slice { - for _, v := range r.Data.([]interface{}) { + if data, ok := r.Data.([]interface{}); ok { + for _, v := range data { if reflect.TypeOf(v) == reflect.TypeOf(&Traverser{}) { for i := int64(0); i < (v.(*Traverser)).bulk; i++ { channelResultSet.channel <- &Result{(v.(*Traverser)).value} diff --git a/gremlin-go/driver/serializer.go b/gremlin-go/driver/serializer.go index 68be6862ac..98a5f3142c 100644 --- 
a/gremlin-go/driver/serializer.go +++ b/gremlin-go/driver/serializer.go @@ -35,7 +35,9 @@ type Serializer interface { // GraphBinarySerializer serializes/deserializes message to/from GraphBinary. type GraphBinarySerializer struct { - ser *graphBinaryTypeSerializer + ser *graphBinaryTypeSerializer + bulked bool // State maintained between chunks + buffer []byte // Buffer for incomplete objects across chunks } // CustomTypeReader user provided function to deserialize custom types @@ -56,9 +58,13 @@ func init() { initDeserializers() } -func newGraphBinarySerializer(handler *logHandler) Serializer { +func newGraphBinarySerializer(handler *logHandler) *GraphBinarySerializer { serializer := graphBinaryTypeSerializer{handler} - return GraphBinarySerializer{&serializer} + return &GraphBinarySerializer{ + ser: &serializer, + bulked: false, + buffer: make([]byte, 0), + } } // TODO change for graph binary 4.0 version is finalized @@ -86,8 +92,8 @@ const versionByte byte = 0x81 // bytes, err := serializer.(graphBinarySerializer).SerializeMessage(&req) // // Send bytes over custom transport // -// serializeMessage serializes a request message into GraphBinary. -func (gs GraphBinarySerializer) SerializeMessage(request *request) ([]byte, error) { +// SerializeMessage serializes a request message into GraphBinary. 
+func (gs *GraphBinarySerializer) SerializeMessage(request *request) ([]byte, error) { finalMessage, err := gs.buildMessage(request.gremlin, request.fields) if err != nil { return nil, err @@ -132,7 +138,7 @@ func (gs *GraphBinarySerializer) buildMessage(gremlin string, args map[string]in // serializer := newGraphBinarySerializer(nil) // resp, err := serializer.(graphBinarySerializer).DeserializeMessage(responseBytes) // results := resp.responseResult.data -func (gs GraphBinarySerializer) DeserializeMessage(message []byte) (Response, error) { +func (gs *GraphBinarySerializer) DeserializeMessage(message []byte) (Response, error) { var msg Response if message == nil || len(message) == 0 { @@ -143,26 +149,20 @@ func (gs GraphBinarySerializer) DeserializeMessage(message []byte) (Response, er //Skip version and nullable byte. i := 2 - // TODO temp serialization before fully streaming set-up for len(message) > 0 { n, err := readFullyQualifiedNullable(&message, &i, true) if err != nil { return msg, err } - // TODO for debug, remove later - //_, _ = fmt.Fprintf(os.Stdout, "Deserializing data : %v\n", n) if n == EndOfStream() { break } results = append(results, n) } - // TODO for debug, remove later - //_, _ = fmt.Fprintf(os.Stdout, "Deserialized results : %s\n", results) msg.ResponseResult.Data = results code := readUint32Safe(&message, &i) msg.ResponseStatus.code = code - // TODO read status message msg.ResponseStatus.message = "OK" statusMsg, err := readUnqualified(&message, &i, stringType, true) if err != nil { @@ -181,6 +181,94 @@ func (gs GraphBinarySerializer) DeserializeMessage(message []byte) (Response, er return msg, nil } +// readChunk processes HTTP chunks with simple buffering +func (gs *GraphBinarySerializer) readChunk(chunk []byte, isFirstChunk bool) (*Response, error) { + var msg Response + + if len(chunk) == 0 { + msg.ResponseStatus.code = 204 + return &msg, nil + } + + // Append to buffer + gs.buffer = append(gs.buffer, chunk...) 
+ + i := 0 + + if isFirstChunk { + if len(gs.buffer) < 2 { + msg.ResponseResult.Data = make([]interface{}, 0) + return &msg, nil + } + i++ // skip version + gs.bulked = (gs.buffer[i] & 1) == 1 + i++ + } + + results := make([]interface{}, 0) + startPos := i + processedPos := i // Track position after last successfully processed object + endOfStream := false + + // Process complete objects, recover from panic on incomplete data + func() { + defer func() { + if r := recover(); r != nil { + if r == ErrIncompleteData { + i = startPos // reset to start of incomplete object + } else { + panic(r) // re-panic for other errors + } + } + }() + for i < len(gs.buffer) { + // Peek before parsing - need at least 2 bytes for type + nullable + if len(gs.buffer)-i < 2 { + break + } + startPos = i + obj, err := readFullyQualifiedNullable(&gs.buffer, &i, true) + if err != nil { + return + } + + if marker, ok := obj.(Marker); ok && marker == EndOfStream() { + gs.bulked = false + // Read status after end of stream + msg.ResponseStatus.code = readUint32Safe(&gs.buffer, &i) + if statusMsg, _ := readUnqualified(&gs.buffer, &i, stringType, true); statusMsg != nil { + msg.ResponseStatus.message = statusMsg.(string) + } + if exception, _ := readUnqualified(&gs.buffer, &i, stringType, true); exception != nil { + msg.ResponseStatus.exception = exception.(string) + } + // Only set endOfStream after successfully reading status + endOfStream = true + gs.buffer = gs.buffer[:0] + return + } + + results = append(results, obj) + processedPos = i // Update after successful processing + } + }() + + msg.ResponseResult.Data = results + + // Keep unprocessed data in buffer (unless we hit end of stream) + if !endOfStream { + if processedPos < len(gs.buffer) { + remaining := make([]byte, len(gs.buffer)-processedPos) + copy(remaining, gs.buffer[processedPos:]) + gs.buffer = remaining + } else { + gs.buffer = gs.buffer[:0] + } + } + + return &msg, nil +} + func initSerializers() { serializers = 
map[dataType]writer{ stringType: stringWriter, diff --git a/gremlin-go/driver/strategies_test.go b/gremlin-go/driver/strategies_test.go index ee1435890b..2a855be2d5 100644 --- a/gremlin-go/driver/strategies_test.go +++ b/gremlin-go/driver/strategies_test.go @@ -27,11 +27,10 @@ import ( "github.com/stretchr/testify/assert" ) -func getModernGraph(t *testing.T, url string, auth AuthInfoProvider, tls *tls.Config) *GraphTraversalSource { +func getModernGraph(t *testing.T, url string, tls *tls.Config) *GraphTraversalSource { remote, err := NewDriverRemoteConnection(url, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = tls - settings.AuthInfo = auth settings.TraversalSource = testServerModernGraphAlias }) assert.Nil(t, err) @@ -45,7 +44,7 @@ func TestStrategy(t *testing.T) { testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) t.Run("Test read with ConnectiveStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(ConnectiveStrategy()).V().Count().ToList() @@ -58,7 +57,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with OptionsStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(OptionsStrategy(map[string]interface{}{"a": "b"})).V().Count().ToList() @@ -71,7 +70,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with PartitionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := PartitionStrategyConfig{ @@ -90,7 +89,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with SeedStrategy", func(t *testing.T) { - g := getModernGraph(t, 
testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := SeedStrategyConfig{Seed: 1} @@ -104,7 +103,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with SubgraphStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := SubgraphStrategyConfig{ Vertices: T__.HasLabel(testLabel), @@ -130,7 +129,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with AdjacentToIncidentStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(AdjacentToIncidentStrategy()).V().Count().ToList() @@ -143,7 +142,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with ByModulatorOptimizationStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(ByModulatorOptimizationStrategy()).V().Count().ToList() @@ -156,7 +155,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with CountStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(CountStrategy()).V().Count().ToList() @@ -169,7 +168,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with EarlyLimitStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(EarlyLimitStrategy()).V().Count().ToList() @@ -182,7 +181,7 @@ func TestStrategy(t 
*testing.T) { }) t.Run("Test read with FilterRankingStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(FilterRankingStrategy()).V().Count().ToList() @@ -195,7 +194,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with IdentityRemovalStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(IdentityRemovalStrategy()).V().Count().ToList() @@ -208,7 +207,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with IncidentToAdjacentStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(IncidentToAdjacentStrategy()).V().Count().ToList() @@ -221,7 +220,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with InlineFilterStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(InlineFilterStrategy()).V().Count().ToList() @@ -234,7 +233,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with LazyBarrierStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(LazyBarrierStrategy()).V().Count().ToList() @@ -247,7 +246,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with MatchPredicateStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, 
&tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(MatchPredicateStrategy()).V().Count().ToList() @@ -260,7 +259,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with OrderLimitStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(OrderLimitStrategy()).V().Count().ToList() @@ -273,7 +272,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with PathProcessorStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(PathProcessorStrategy()).V().Count().ToList() @@ -286,7 +285,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with PathRetractionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(PathRetractionStrategy()).V().Count().ToList() @@ -299,7 +298,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with ProductiveByStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := ProductiveByStrategyConfig{ProductiveKeys: []string{"a", "b"}} @@ -313,7 +312,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with RepeatUnrollStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(RepeatUnrollStrategy()).V().Count().ToList() @@ -326,7 +325,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with 
EdgeLabelVerificationStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := EdgeLabelVerificationStrategyConfig{ @@ -343,7 +342,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with LambdaRestrictionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(LambdaRestrictionStrategy()).V().Count().ToList() @@ -356,7 +355,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with TestReadOnlyStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithStrategies(ReadOnlyStrategy()).V().Count().ToList() @@ -369,7 +368,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test write with TestReadOnlyStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() promise := g.WithStrategies(ReadOnlyStrategy()).AddV("person").Property("name", "foo").Iterate() @@ -377,7 +376,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with ReservedKeysVerificationStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() config := ReservedKeysVerificationStrategyConfig{ @@ -396,7 +395,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test read with RepeatUnrollStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := 
g.WithStrategies(RepeatUnrollStrategy()).V().Count().ToList() @@ -409,7 +408,7 @@ func TestStrategy(t *testing.T) { }) t.Run("Test without strategies MessagePassingReductionStrategy", func(t *testing.T) { - g := getModernGraph(t, testNoAuthUrl, &AuthInfo{}, &tls.Config{}) + g := getModernGraph(t, testNoAuthUrl, &tls.Config{}) defer g.remoteConnection.Close() count, err := g.WithoutStrategies(MessagePassingReductionStrategy()).V().Count().ToList() diff --git a/gremlin-go/driver/streamingDeserializer.go b/gremlin-go/driver/streamingDeserializer.go new file mode 100644 index 0000000000..0c69460a1c --- /dev/null +++ b/gremlin-go/driver/streamingDeserializer.go @@ -0,0 +1,617 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package gremlingo + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" + "math/big" + "reflect" + "time" + + "github.com/google/uuid" +) + +// StreamingDeserializer reads GraphBinary data directly from an io.Reader, +// enabling streaming deserialization of server responses. +// +// Streaming Behavior: +// The deserializer is designed to work with HTTP chunked transfer encoding where +// the server streams results as they become available. Key behaviors: +// +// 1. 
Blocking on partial data: When reading an object, if the underlying reader +// doesn't have enough bytes available, the deserializer blocks until the data +// arrives. This is handled by io.ReadFull which waits for the exact number of +// bytes needed based on GraphBinary's self-describing format (type codes and +// length prefixes). +// +// 2. Chunk boundary independence: Go's HTTP client may receive data in chunks that +// don't align with server-sent GraphBinary object boundaries. The deserializer +// handles this transparently - it reads exactly the bytes needed for each object, +// blocking if necessary, regardless of how the data was chunked by the network. +// +// 3. Immediate object delivery: Each complete object is returned as soon as it's +// fully read, allowing the caller to process results incrementally rather than +// waiting for the entire response. +// +// The bufio.Reader wrapper provides efficient buffering without affecting the +// streaming semantics - it simply reduces the number of underlying read syscalls. +type StreamingDeserializer struct { + r *bufio.Reader + buf [8]byte + err error // sticky error +} + +// GraphBinary flag for bulked list/set +const flagBulked = 0x02 + +// NewStreamingDeserializer creates a new StreamingDeserializer that reads from the given io.Reader. +// The reader is wrapped in a buffered reader for efficient reading. 
+func NewStreamingDeserializer(r io.Reader) *StreamingDeserializer { + return &StreamingDeserializer{r: bufio.NewReaderSize(r, 8192)} +} + +func (d *StreamingDeserializer) readByte() (byte, error) { + if d.err != nil { + return 0, d.err + } + b, err := d.r.ReadByte() + if err != nil { + d.err = err + return 0, err + } + return b, nil +} + +func (d *StreamingDeserializer) readBytes(n int) ([]byte, error) { + if d.err != nil { + return nil, d.err + } + buf := make([]byte, n) + _, err := io.ReadFull(d.r, buf) + if err != nil { + d.err = err + return nil, err + } + return buf, nil +} + +func (d *StreamingDeserializer) readInt32() (int32, error) { + if d.err != nil { + return 0, d.err + } + _, err := io.ReadFull(d.r, d.buf[:4]) + if err != nil { + d.err = err + return 0, err + } + return int32(binary.BigEndian.Uint32(d.buf[:4])), nil +} + +func (d *StreamingDeserializer) readUint32() (uint32, error) { + if d.err != nil { + return 0, d.err + } + _, err := io.ReadFull(d.r, d.buf[:4]) + if err != nil { + d.err = err + return 0, err + } + return binary.BigEndian.Uint32(d.buf[:4]), nil +} + +func (d *StreamingDeserializer) readInt64() (int64, error) { + if d.err != nil { + return 0, d.err + } + _, err := io.ReadFull(d.r, d.buf[:8]) + if err != nil { + d.err = err + return 0, err + } + return int64(binary.BigEndian.Uint64(d.buf[:8])), nil +} + +// ReadHeader reads and validates the GraphBinary response header. +// This must be called before reading any objects from the stream. +func (d *StreamingDeserializer) ReadHeader() error { + if _, err := d.readByte(); err != nil { + return err + } + _, err := d.readByte() + return err +} + +// ReadFullyQualified reads the next fully-qualified GraphBinary value from the stream. +// Returns the deserialized object, or an error if reading fails. +// When the end of the result stream is reached, this returns a Marker equal to EndOfStream(). 
+func (d *StreamingDeserializer) ReadFullyQualified() (interface{}, error) {
+	dtByte, err := d.readByte()
+	if err != nil {
+		return nil, err
+	}
+	dt := dataType(dtByte)
+	if dt == nullType {
+		// A nullType code is still followed by a value_flag byte; consume it so
+		// the stream stays aligned, then report a nil value.
+		if _, err := d.readByte(); err != nil {
+			return nil, err
+		}
+		return nil, nil
+	}
+	flag, err := d.readByte()
+	if err != nil {
+		return nil, err
+	}
+	if flag == valueFlagNull {
+		// A null value has no payload after its flag byte.
+		return nil, nil
+	}
+	return d.readValue(dt, flag)
+}
+
+// readValue reads the payload of a value whose type code and non-null value
+// flag have already been consumed. flag is only consulted for lists and sets,
+// where flagBulked means each element is followed by an int64 bulk count.
+// Unknown type codes yield err0408GetSerializerToReadUnknownTypeError.
+func (d *StreamingDeserializer) readValue(dt dataType, flag byte) (interface{}, error) {
+	switch dt {
+	case intType:
+		return d.readInt32()
+	case longType:
+		return d.readInt64()
+	case stringType:
+		return d.readString()
+	case doubleType:
+		if d.err != nil {
+			return nil, d.err
+		}
+		if _, err := io.ReadFull(d.r, d.buf[:8]); err != nil {
+			d.err = err
+			return nil, err
+		}
+		return math.Float64frombits(binary.BigEndian.Uint64(d.buf[:8])), nil
+	case floatType:
+		if d.err != nil {
+			return nil, d.err
+		}
+		if _, err := io.ReadFull(d.r, d.buf[:4]); err != nil {
+			d.err = err
+			return nil, err
+		}
+		return math.Float32frombits(binary.BigEndian.Uint32(d.buf[:4])), nil
+	case booleanType:
+		b, err := d.readByte()
+		return b != 0, err
+	case byteType:
+		return d.readByte()
+	case shortType:
+		if d.err != nil {
+			return nil, d.err
+		}
+		if _, err := io.ReadFull(d.r, d.buf[:2]); err != nil {
+			d.err = err
+			return nil, err
+		}
+		return int16(binary.BigEndian.Uint16(d.buf[:2])), nil
+	case uuidType:
+		buf, err := d.readBytes(16)
+		if err != nil {
+			return nil, err
+		}
+		id, err := uuid.FromBytes(buf)
+		if err != nil {
+			return nil, err
+		}
+		return id, nil
+	case listType:
+		return d.readList(flag == flagBulked)
+	case setType:
+		list, err := d.readList(flag == flagBulked)
+		if err != nil {
+			return nil, err
+		}
+		return NewSimpleSet(list.([]interface{})...), nil
+	case mapType:
+		return d.readMap()
+	case vertexType:
+		return d.readVertex(true)
+	case edgeType:
+		return d.readEdge()
+	case pathType:
+		return d.readPath()
+	case propertyType:
+		return d.readProperty()
+	case vertexPropertyType:
+		return d.readVertexProperty()
+	case bigIntegerType:
+		return d.readBigInt()
+	case bigDecimalType:
+		return d.readBigDecimal()
+	case datetimeType:
+		return d.readDateTime()
+	case durationType:
+		return d.readDuration()
+	case markerType:
+		// A marker is a single byte; Of maps it to a Marker (e.g. EndOfStream).
+		b, err := d.readByte()
+		if err != nil {
+			return nil, err
+		}
+		return Of(b)
+	case byteBuffer:
+		return d.readByteBuffer()
+	case tType, directionType, mergeType, gTypeType:
+		return d.readEnum()
+	default:
+		return nil, newError(err0408GetSerializerToReadUnknownTypeError, dt)
+	}
+}
+
+// readString reads an int32 byte length followed by that many bytes and
+// returns them as a string. A negative length is a sticky error.
+func (d *StreamingDeserializer) readString() (string, error) {
+	length, err := d.readInt32()
+	if err != nil {
+		return "", err
+	}
+	if length < 0 {
+		d.err = fmt.Errorf("gremlingo: invalid negative string length %d", length)
+		return "", d.err
+	}
+	if length == 0 {
+		return "", nil
+	}
+	buf, err := d.readBytes(int(length))
+	if err != nil {
+		return "", err
+	}
+	return string(buf), nil
+}
+
+// readList reads an int32 element count followed by that many fully-qualified
+// elements. When bulked is true each element is followed by an int64 bulk
+// count and is repeated that many times. On success the result is always a
+// []interface{}.
+func (d *StreamingDeserializer) readList(bulked bool) (interface{}, error) {
+	const maxPrealloc = 1024
+	length, err := d.readInt32()
+	if err != nil {
+		return nil, err
+	}
+	if length < 0 {
+		d.err = fmt.Errorf("gremlingo: invalid negative list length %d", length)
+		return nil, d.err
+	}
+	// length comes off the wire; bound the up-front allocation and let append
+	// grow the slice if the list really is that large.
+	capacity := int(length)
+	if capacity > maxPrealloc {
+		capacity = maxPrealloc
+	}
+	list := make([]interface{}, 0, capacity)
+	for i := int32(0); i < length; i++ {
+		val, err := d.ReadFullyQualified()
+		if err != nil {
+			return nil, err
+		}
+		if !bulked {
+			list = append(list, val)
+			continue
+		}
+		bulk, err := d.readInt64()
+		if err != nil {
+			return nil, err
+		}
+		for j := int64(0); j < bulk; j++ {
+			list = append(list, val)
+		}
+	}
+	return list, nil
+}
+
+// readMap reads a uint32 entry count followed by that many fully-qualified
+// key/value pairs into a map[interface{}]interface{}. Keys whose dynamic type
+// is not comparable cannot be Go map keys directly: map-typed keys are stored
+// by pointer and any other non-comparable key by its fmt.Sprint form.
+func (d *StreamingDeserializer) readMap() (interface{}, error) {
+	const maxPrealloc = 1024
+	length, err := d.readUint32()
+	if err != nil {
+		return nil, err
+	}
+	// The count comes off the wire; cap the size hint so a corrupt count cannot
+	// force a huge allocation (or an out-of-range hint on 32-bit platforms).
+	hint := maxPrealloc
+	if length < maxPrealloc {
+		hint = int(length)
+	}
+	m := make(map[interface{}]interface{}, hint)
+	for i := uint32(0); i < length; i++ {
+		key, err := d.ReadFullyQualified()
+		if err != nil {
+			return nil, err
+		}
+		val, err := d.ReadFullyQualified()
+		if err != nil {
+			return nil, err
+		}
+		switch {
+		case key == nil:
+			m[nil] = val
+		case reflect.TypeOf(key).Comparable():
+			m[key] = val
+		case reflect.TypeOf(key).Kind() == reflect.Map:
+			// key is a fresh variable each iteration, so each pointer is distinct.
+			m[&key] = val
+		default:
+			m[fmt.Sprint(key)] = val
+		}
+	}
+	return m, nil
+}
+
+// readVertex reads a fully-qualified id and a label list (the first label is
+// used). When withProps is true it also reads a fully-qualified properties
+// value; edges call it with withProps=false for their in/out vertex references.
+func (d *StreamingDeserializer) readVertex(withProps bool) (*Vertex, error) {
+	id, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	labels, err := d.readList(false)
+	if err != nil {
+		return nil, err
+	}
+	labelSlice, ok := labels.([]interface{})
+	if !ok || len(labelSlice) == 0 {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	label, ok := labelSlice[0].(string)
+	if !ok {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	v := &Vertex{Element: Element{Id: id, Label: label}}
+	if withProps {
+		props, err := d.ReadFullyQualified()
+		if err != nil {
+			return nil, err
+		}
+		v.Properties = make([]interface{}, 0)
+		if props != nil {
+			v.Properties = props
+		}
+	}
+	return v, nil
+}
+
+// readEdge reads, in order: a fully-qualified id, a label list (first label
+// used), the in-vertex and out-vertex references (id + labels each), a 2-byte
+// parent field that is skipped, and a fully-qualified properties value.
+func (d *StreamingDeserializer) readEdge() (*Edge, error) {
+	id, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	labels, err := d.readList(false)
+	if err != nil {
+		return nil, err
+	}
+	labelSlice, ok := labels.([]interface{})
+	if !ok || len(labelSlice) == 0 {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	label, ok := labelSlice[0].(string)
+	if !ok {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	inV, err := d.readVertex(false)
+	if err != nil {
+		return nil, err
+	}
+	outV, err := d.readVertex(false)
+	if err != nil {
+		return nil, err
+	}
+	// Parent: presumably always a 2-byte fully-qualified null; it is discarded.
+	if _, err := d.readBytes(2); err != nil {
+		return nil, err
+	}
+	props, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	e := &Edge{
+		Element: Element{Id: id, Label: label},
+		InV:     *inV,
+		OutV:    *outV,
+	}
+	e.Properties = make([]interface{}, 0)
+	if props != nil {
+		e.Properties = props
+	}
+	return e, nil
+}
+
+// readPath reads a fully-qualified labels value (nil, or a list of sets) and
+// a fully-qualified objects list.
+func (d *StreamingDeserializer) readPath() (*Path, error) {
+	labels, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	objects, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	objectSlice, ok := objects.([]interface{})
+	if !ok {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	path := &Path{Objects: objectSlice}
+	if labels != nil {
+		labelSlice, ok := labels.([]interface{})
+		if !ok {
+			return nil, newError(err0404ReadNullTypeError)
+		}
+		for _, l := range labelSlice {
+			set, ok := l.(*SimpleSet)
+			if !ok {
+				return nil, newError(err0404ReadNullTypeError)
+			}
+			path.Labels = append(path.Labels, set)
+		}
+	}
+	return path, nil
+}
+
+// readProperty reads a bare (not fully-qualified) string key, a
+// fully-qualified value, and a 2-byte parent field that is skipped.
+func (d *StreamingDeserializer) readProperty() (*Property, error) {
+	key, err := d.readString()
+	if err != nil {
+		return nil, err
+	}
+	value, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	if _, err := d.readBytes(2); err != nil {
+		return nil, err
+	}
+	return &Property{Key: key, Value: value}, nil
+}
+
+// readVertexProperty reads a fully-qualified id, a label list (first label
+// used), a fully-qualified value, a 2-byte parent field that is skipped, and a
+// fully-qualified properties value.
+func (d *StreamingDeserializer) readVertexProperty() (*VertexProperty, error) {
+	id, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	labels, err := d.readList(false)
+	if err != nil {
+		return nil, err
+	}
+	labelSlice, ok := labels.([]interface{})
+	if !ok || len(labelSlice) == 0 {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	label, ok := labelSlice[0].(string)
+	if !ok {
+		return nil, newError(err0404ReadNullTypeError)
+	}
+	value, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	if _, err := d.readBytes(2); err != nil {
+		return nil, err
+	}
+	props, err := d.ReadFullyQualified()
+	if err != nil {
+		return nil, err
+	}
+	vp := &VertexProperty{
+		Element: Element{Id: id, Label: label},
+		Value:   value,
+	}
+	vp.Properties = make([]interface{}, 0)
+	if props != nil {
+		vp.Properties = props
+	}
+	return vp, nil
+}
+
+// readBigInt reads an int32 byte count followed by a big-endian
+// two's-complement integer of that many bytes; a zero count decodes as 0.
+func (d *StreamingDeserializer) readBigInt() (*big.Int, error) {
+	length, err := d.readInt32()
+	if err != nil {
+		return nil, err
+	}
+	if length < 0 {
+		d.err = fmt.Errorf("gremlingo: invalid negative BigInteger length %d", length)
+		return nil, d.err
+	}
+	if length == 0 {
+		return big.NewInt(0), nil
+	}
+	b, err := d.readBytes(int(length))
+	if err != nil {
+		return nil, err
+	}
+	bi := new(big.Int).SetBytes(b)
+	if b[0]&0x80 != 0 {
+		// Sign bit set: SetBytes read the two's-complement bytes as unsigned, so
+		// subtract 2^(8*len(b)) to recover the signed value (0xFF -> 255-256 = -1).
+		bi.Sub(bi, new(big.Int).Lsh(big.NewInt(1), uint(len(b))*8))
+	}
+	return bi, nil
+}
+
+// readBigDecimal reads an int32 scale followed by the unscaled value encoded
+// as by readBigInt.
+func (d *StreamingDeserializer) readBigDecimal() (*BigDecimal, error) {
+	scale, err := d.readInt32()
+	if err != nil {
+		return nil, err
+	}
+	unscaled, err := d.readBigInt()
+	if err != nil {
+		return nil, err
+	}
+	return &BigDecimal{Scale: scale, UnscaledValue: unscaled}, nil
+}
+
+// readDateTime reads an int32 year, a month byte, a day byte, an int64
+// nanoseconds-within-day count and an int32 offset that is handed to
+// GetTimezoneFromOffset (units assumed to be seconds; confirm there).
+func (d *StreamingDeserializer) readDateTime() (time.Time, error) {
+	year, err := d.readInt32()
+	if err != nil {
+		return time.Time{}, err
+	}
+	month, err := d.readByte()
+	if err != nil {
+		return time.Time{}, err
+	}
+	day, err := d.readByte()
+	if err != nil {
+		return time.Time{}, err
+	}
+	totalNS, err := d.readInt64()
+	if err != nil {
+		return time.Time{}, err
+	}
+	offset, err := d.readInt32()
+	if err != nil {
+		return time.Time{}, err
+	}
+	ns := totalNS % 1e9
+	totalS := totalNS / 1e9
+	s := totalS % 60
+	totalM := totalS / 60
+	m := totalM % 60
+	h := totalM / 60
+	return time.Date(int(year), time.Month(month), int(day), int(h), int(m), int(s), int(ns), GetTimezoneFromOffset(int(offset))), nil
+}
+
+// readDuration reads an int64 seconds count and an int32 nanosecond
+// adjustment and combines them into a time.Duration.
+func (d *StreamingDeserializer) readDuration() (time.Duration, error) {
+	seconds, err := d.readInt64()
+	if err != nil {
+		return 0, err
+	}
+	nanos, err := d.readInt32()
+	if err != nil {
+		return 0, err
+	}
+	return time.Duration(seconds*int64(time.Second) + int64(nanos)), nil
+}
+
+// readByteBuffer reads an int32 byte length followed by that many raw bytes.
+// A negative length is a sticky error.
+func (d *StreamingDeserializer) readByteBuffer() (*ByteBuffer, error) {
+	length, err := d.readInt32()
+	if err != nil {
+		return nil, err
+	}
+	if length < 0 {
+		d.err = fmt.Errorf("gremlingo: invalid negative ByteBuffer length %d", length)
+		return nil, d.err
+	}
+	data, err := d.readBytes(int(length))
+	if err != nil {
+		return nil, err
+	}
+	return &ByteBuffer{Data: data}, nil
+}
+
+// readEnum reads an enum value, which is encoded as a fully-qualified string:
+// a type code (ignored), a value flag and, unless the flag marks null, the
+// string. A null value decodes as "" without consuming a payload, keeping the
+// stream aligned.
+func (d *StreamingDeserializer) readEnum() (string, error) {
+	if _, err := d.readByte(); err != nil { // type code (string)
+		return "", err
+	}
+	flag, err := d.readByte()
+	if err != nil {
+		return "", err
+	}
+	if flag == valueFlagNull {
+		return "", nil
+	}
+	return d.readString()
+}
+
+// ReadStatus reads the response status after the EndOfStream marker.
+// The status trailer is a 4-byte status code followed by two nullable strings,
+// the message and then the exception; a null string is returned as "".
+// Call it only after ReadFullyQualified() has returned the EndOfStream marker.
+func (d *StreamingDeserializer) ReadStatus() (code uint32, message string, exception string, err error) {
+	// nullableString reads a value flag and, unless it marks null, the string
+	// that follows it.
+	nullableString := func() (string, error) {
+		flag, ferr := d.readByte()
+		if ferr != nil {
+			return "", ferr
+		}
+		if flag == valueFlagNull {
+			return "", nil
+		}
+		return d.readString()
+	}
+
+	if code, err = d.readUint32(); err != nil {
+		return 0, "", "", err
+	}
+	if message, err = nullableString(); err != nil {
+		return code, "", "", err
+	}
+	if exception, err = nullableString(); err != nil {
+		return code, message, "", err
+	}
+	return code, message, exception, nil
+}
diff --git a/gremlin-go/driver/streamingDeserializer_test.go b/gremlin-go/driver/streamingDeserializer_test.go
new file mode 100644
index 0000000000..96cc7cdbe0
--- /dev/null
+++ b/gremlin-go/driver/streamingDeserializer_test.go
@@ -0,0 +1,402 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+	"bytes"
+	"io"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// slowReader simulates a network stream that delivers data in chunks with delays.
+// This mimics how Go's HTTP client receives chunked transfer-encoded responses
+// where chunk boundaries don't align with GraphBinary object boundaries.
+// Each Read returns at most the rest of the current chunk, and sleeps for delay
+// before starting every chunk after the first.
+type slowReader struct {
+	chunks [][]byte
+	delay  time.Duration
+	index  int
+	offset int
+	mu     sync.Mutex
+}
+
+func newSlowReader(chunks [][]byte, delay time.Duration) *slowReader {
+	return &slowReader{chunks: chunks, delay: delay}
+}
+
+func (r *slowReader) Read(p []byte) (n int, err error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if r.index >= len(r.chunks) {
+		return 0, io.EOF
+	}
+
+	// Simulate network delay between chunks
+	if r.offset == 0 && r.index > 0 {
+		r.mu.Unlock()
+		time.Sleep(r.delay)
+		r.mu.Lock()
+	}
+
+	chunk := r.chunks[r.index]
+	remaining := chunk[r.offset:]
+	n = copy(p, remaining)
+	r.offset += n
+
+	if r.offset >= len(chunk) {
+		r.index++
+		r.offset = 0
+	}
+
+	return n, nil
+}
+
+func TestStreamingDeserializer(t *testing.T) {
+	t.Run("readInt32", func(t *testing.T) {
+		data := []byte{0x00, 0x00, 0x00, 0x2A} // 42
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+		val, err := d.readInt32()
+		assert.NoError(t, err)
+		assert.Equal(t, int32(42), val)
+	})
+
+	t.Run("readInt64", func(t *testing.T) {
+		data := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64} // 100
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+		val, err := d.readInt64()
+		assert.NoError(t, err)
+		assert.Equal(t, int64(100), val)
+	})
+
+	t.Run("readString", func(t *testing.T) {
+		data := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e', 'l', 'l', 'o'}
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+		val, err := d.readString()
+		assert.NoError(t, err)
+		assert.Equal(t, "hello", val)
+	})
+
+	t.Run("readString empty", func(t *testing.T) {
+		data := []byte{0x00, 0x00, 0x00, 0x00}
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+		val, err := d.readString()
+		assert.NoError(t, err)
+		assert.Equal(t, "", val)
+	})
+
+	t.Run("error on incomplete data", func(t *testing.T) {
+		data := []byte{0x00, 0x00} // incomplete int32
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+		_, err := d.readInt32()
+		assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
+	})
+
+	t.Run("sticky error", func(t *testing.T) {
+		data := []byte{0x00} // too short
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+
+		_, err1 := d.readInt32()
+		assert.Error(t, err1)
+
+		// Subsequent reads should also fail
+		_, err2 := d.readInt32()
+		assert.Error(t, err2)
+	})
+}
+
+func TestStreamingChannelDelivery(t *testing.T) {
+	t.Run("results arrive incrementally via channel", func(t *testing.T) {
+		const total = 5
+		rs := newChannelResultSet()
+
+		// The producer waits for the consumer to acknowledge each result before
+		// sending the next one, and only closes the result set after the last
+		// acknowledgement. The consumer can therefore only finish if every result
+		// is delivered while the set is still open. This checks incremental
+		// delivery without relying on wall-clock gaps between receipts, which
+		// flake when the consumer goroutine is descheduled.
+		acked := make(chan struct{})
+		go func() {
+			for i := 0; i < total; i++ {
+				rs.Channel() <- &Result{i}
+				<-acked
+			}
+			rs.Close()
+		}()
+
+		received := make(chan int, 1)
+		go func() {
+			count := 0
+			for {
+				_, ok, _ := rs.One()
+				if !ok {
+					break
+				}
+				count++
+				acked <- struct{}{}
+			}
+			received <- count
+		}()
+
+		select {
+		case count := <-received:
+			assert.Equal(t, total, count)
+		case <-time.After(5 * time.Second):
+			t.Fatal("results were not delivered before the result set was closed")
+		}
+	})
+}
+
+// TestStreamingBlocksOnPartialData verifies that the deserializer correctly blocks
+// when it receives partial data, waiting for the rest of the object to arrive.
+// This simulates the real-world scenario where Go's HTTP client receives chunks
+// that don't align with server-sent GraphBinary object boundaries.
+func TestStreamingBlocksOnPartialData(t *testing.T) {
+	t.Run("blocks until complete int32 is available", func(t *testing.T) {
+		// Split a 4-byte int32 across two chunks
+		chunk1 := []byte{0x00, 0x00} // First 2 bytes
+		chunk2 := []byte{0x00, 0x2A} // Last 2 bytes (total = 42)
+
+		reader := newSlowReader([][]byte{chunk1, chunk2}, 20*time.Millisecond)
+		d := NewStreamingDeserializer(reader)
+
+		start := time.Now()
+		val, err := d.readInt32()
+		elapsed := time.Since(start)
+
+		assert.NoError(t, err)
+		assert.Equal(t, int32(42), val)
+		// Should have blocked waiting for second chunk
+		assert.GreaterOrEqual(t, elapsed, 15*time.Millisecond,
+			"Should have blocked waiting for remaining bytes")
+	})
+
+	t.Run("blocks until complete string is available", func(t *testing.T) {
+		// String "hello" split across chunks:
+		// Chunk 1: length (4 bytes) + partial content
+		// Chunk 2: remaining content
+		chunk1 := []byte{0x00, 0x00, 0x00, 0x05, 'h', 'e'} // length=5, "he"
+		chunk2 := []byte{'l', 'l', 'o'}                      // "llo"
+
+		reader := newSlowReader([][]byte{chunk1, chunk2}, 20*time.Millisecond)
+		d := NewStreamingDeserializer(reader)
+
+		start := time.Now()
+		val, err := d.readString()
+		elapsed := time.Since(start)
+
+		assert.NoError(t, err)
+		assert.Equal(t, "hello", val)
+		assert.GreaterOrEqual(t, elapsed, 15*time.Millisecond,
+			"Should have blocked waiting for remaining string bytes")
+	})
+}
+
+// TestStreamingMultipleObjects verifies that multiple GraphBinary objects
+// can be read from a stream, with each object returned as soon as it's complete.
+func TestStreamingMultipleObjects(t *testing.T) {
+	t.Run("reads multiple objects as they arrive", func(t *testing.T) {
+		// Build a stream with 3 fully-qualified int32 values
+		// Each int32: type(1) + flag(1) + value(4) = 6 bytes
+		obj1 := []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x01} // int32 = 1
+		obj2 := []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x02} // int32 = 2
+		obj3 := []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x03} // int32 = 3
+
+		// Deliver each object as a separate chunk with delay
+		reader := newSlowReader([][]byte{obj1, obj2, obj3}, 15*time.Millisecond)
+		d := NewStreamingDeserializer(reader)
+
+		var results []int32
+		var times []time.Duration
+		start := time.Now()
+
+		for i := 0; i < 3; i++ {
+			val, err := d.ReadFullyQualified()
+			if !assert.NoError(t, err) {
+				return
+			}
+			n, ok := val.(int32)
+			if !assert.True(t, ok, "expected int32, got %T", val) {
+				return
+			}
+			results = append(results, n)
+			times = append(times, time.Since(start))
+		}
+
+		assert.Equal(t, []int32{1, 2, 3}, results)
+
+		// Objects should arrive with delays between them; the delay is a sleep
+		// inside the reader, so this lower bound is deterministic.
+		for i := 1; i < len(times); i++ {
+			gap := times[i] - times[i-1]
+			assert.GreaterOrEqual(t, gap, 10*time.Millisecond,
+				"Object %d should have arrived after a delay", i)
+		}
+	})
+
+	t.Run("handles object split across chunk boundary", func(t *testing.T) {
+		// First chunk: complete object + start of second object
+		// Second chunk: rest of second object + complete third object
+		chunk1 := []byte{
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x01, // int32 = 1 (complete)
+			0x01, 0x00, 0x00, // partial int32: type + flag + first byte of value
+		}
+		chunk2 := []byte{
+			0x00, 0x00, 0x02, // remaining 3 value bytes of int32 = 2
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x03, // int32 = 3 (complete)
+		}
+
+		reader := newSlowReader([][]byte{chunk1, chunk2}, 20*time.Millisecond)
+		d := NewStreamingDeserializer(reader)
+
+		// First object should return immediately
+		val1, err := d.ReadFullyQualified()
+		assert.NoError(t, err)
+		assert.Equal(t, int32(1), val1)
+
+		// Second object should block waiting for chunk2
+		start := time.Now()
+		val2, err := d.ReadFullyQualified()
+		elapsed := time.Since(start)
+		assert.NoError(t, err)
+		assert.Equal(t, int32(2), val2)
+		assert.GreaterOrEqual(t, elapsed, 15*time.Millisecond,
+			"Should have blocked waiting for rest of object")
+
+		// Third object should return immediately (already in buffer)
+		val3, err := d.ReadFullyQualified()
+		assert.NoError(t, err)
+		assert.Equal(t, int32(3), val3)
+	})
+}
+
+// TestStreamingWithEndOfStreamMarker verifies that the deserializer correctly
+// handles the EndOfStream marker and subsequent status reading.
+func TestStreamingWithEndOfStreamMarker(t *testing.T) {
+	t.Run("reads EndOfStream marker and status", func(t *testing.T) {
+		// Build a complete response:
+		// - Header (2 bytes): version + flags
+		// - One int32 result
+		// - EndOfStream marker
+		// - Status: code(4) + message(nullable) + exception(nullable)
+		data := []byte{
+			0x81, 0x00, // Header: version byte + no bulking
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x2A, // int32 = 42
+			0xfd, 0x00, 0x00, // Marker type + flag + value=0 (EndOfStream)
+			0x00, 0x00, 0x00, 0xC8, // Status code = 200
+			0x01, // Message is null
+			0x01, // Exception is null
+		}
+
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+
+		// Read header
+		err := d.ReadHeader()
+		assert.NoError(t, err)
+
+		// Read the result
+		val, err := d.ReadFullyQualified()
+		assert.NoError(t, err)
+		assert.Equal(t, int32(42), val)
+
+		// Read EndOfStream marker
+		marker, err := d.ReadFullyQualified()
+		assert.NoError(t, err)
+		assert.Equal(t, EndOfStream(), marker)
+
+		// Read status
+		code, msg, exc, err := d.ReadStatus()
+		assert.NoError(t, err)
+		assert.Equal(t, uint32(200), code)
+		assert.Equal(t, "", msg)
+		assert.Equal(t, "", exc)
+	})
+
+	t.Run("reads status with message", func(t *testing.T) {
+		// Status with a message
+		data := []byte{
+			0xfd, 0x00, 0x00, // EndOfStream marker
+			0x00, 0x00, 0x01, 0x90, // Status code = 400
+			0x00, // Message is not null
+			0x00, 0x00, 0x00, 0x05, 'e', 'r', 'r', 'o', 'r', // Message = "error"
+			0x01, // Exception is null
+		}
+
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+
+		marker, err := d.ReadFullyQualified()
+		assert.NoError(t, err)
+		assert.Equal(t, EndOfStream(), marker)
+
+		code, msg, exc, err := d.ReadStatus()
+		assert.NoError(t, err)
+		assert.Equal(t, uint32(400), code)
+		assert.Equal(t, "error", msg)
+		assert.Equal(t, "", exc)
+	})
+}
+
+// TestStreamingComplexTypes verifies streaming deserialization of complex types
+// like vertices, edges, and paths.
+func TestStreamingComplexTypes(t *testing.T) {
+	t.Run("reads vertex from stream", func(t *testing.T) {
+		// Vertex: id(int32) + labels(list of string) + properties(nullable)
+		data := []byte{
+			0x11, 0x00, // Vertex type + flag
+			// ID: int32 = 1
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x01,
+			// Labels: list with one string "person"
+			0x00, 0x00, 0x00, 0x01, // list length = 1
+			0x03, 0x00, // string type + flag
+			0x00, 0x00, 0x00, 0x06, 'p', 'e', 'r', 's', 'o', 'n',
+			// Properties: null
+			0xfe, 0x01,
+		}
+
+		d := NewStreamingDeserializer(bytes.NewReader(data))
+		val, err := d.ReadFullyQualified()
+		if !assert.NoError(t, err) {
+			return
+		}
+
+		v, ok := val.(*Vertex)
+		if !assert.True(t, ok, "expected *Vertex, got %T", val) {
+			return
+		}
+		assert.Equal(t, int32(1), v.Id)
+		assert.Equal(t, "person", v.Label)
+	})
+
+	t.Run("reads list of integers from chunked stream", func(t *testing.T) {
+		// List type split across chunks
+		chunk1 := []byte{
+			0x09, 0x00, // List type + flag
+			0x00, 0x00, 0x00, 0x03, // length = 3
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x0A, // int32 = 10
+		}
+		chunk2 := []byte{
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x14, // int32 = 20
+			0x01, 0x00, 0x00, 0x00, 0x00, 0x1E, // int32 = 30
+		}
+
+		reader := newSlowReader([][]byte{chunk1, chunk2}, 15*time.Millisecond)
+		d := NewStreamingDeserializer(reader)
+
+		start := time.Now()
+		val, err := d.ReadFullyQualified()
+		elapsed := time.Since(start)
+
+		if !assert.NoError(t, err) {
+			return
+		}
+		list, ok := val.([]interface{})
+		if !assert.True(t, ok, "expected []interface{}, got %T", val) || !assert.Len(t, list, 3) {
+			return
+		}
+		assert.Equal(t, int32(10), list[0])
+		assert.Equal(t, int32(20), list[1])
+		assert.Equal(t, int32(30), list[2])
+
+		// Should have blocked for second chunk
+		assert.GreaterOrEqual(t, elapsed, 10*time.Millisecond)
+	})
+}
diff --git a/gremlin-go/driver/traversal.go b/gremlin-go/driver/traversal.go
index 3ac5e934f3..40e8d79785 100644
--- a/gremlin-go/driver/traversal.go
+++ b/gremlin-go/driver/traversal.go
@@ -44,7 +44,6 @@ func (t *Traversal) ToList() ([]*Result, error) {
 		return nil, newError(err0901ToListAnonTraversalError)
 	}
 
-	// TODO update and test when connection is set up
 	results, err := t.remote.submitGremlinLang(t.GremlinLang)
 	if err != nil {
 		return nil, err
@@ -80,7 +79,6 @@ func (t *Traversal) Iterate() <-chan error {
 
 	t.GremlinLang.AddStep("discard")
 
-	// TODO update and test when connection is set up
 	res, err := t.remote.submitGremlinLang(t.GremlinLang)
 	if err != nil {
 		r <- err
@@ -124,7 +122,6 @@ func (t *Traversal) Next() (*Result, error) {
 // GetResultSet submits the traversal and returns the ResultSet.
 func (t *Traversal) GetResultSet() (ResultSet, error) {
 	if t.results == nil {
-		// TODO update and test when connection is set up
 		results, err := t.remote.submitGremlinLang(t.GremlinLang)
 		if err != nil {
 			return nil, err
@@ -749,33 +746,6 @@ var IO = ioconfig{
 	Registry: "~tinkerpop.ioconfig.registry",
 }
 
-// TODO pending update/removal
-// Metrics holds metrics data; typically for .profile()-step analysis. Metrics may be nested. Nesting enables
-// the ability to capture explicit metrics for multiple distinct operations. Annotations are used to store
-// miscellaneous notes that might be useful to a developer when examining results, such as index coverage
-// for Steps in a Traversal.
-//type Metrics struct {
-// Id string
-// Name string
-// // the duration in nanoseconds.
-// Duration int64
-// Counts map[string]int64
-// Annotations map[string]interface{}
-// NestedMetrics []Metrics
-//}
-
-// TraversalMetrics contains the Metrics gathered for a Traversal as the result of the .profile()-step.
-//type TraversalMetrics struct { -// // the duration in nanoseconds. -// Duration int64 -// Metrics []Metrics -//} - -// GremlinType represents the GraphBinary type Class which can be used to serialize a class. -//type GremlinType struct { -// Fqcn string -//} - // BigDecimal represents an arbitrary-precision signed decimal number, consisting of an arbitrary precision integer // unscaled value and a 32-bit integer scale. type BigDecimal struct { diff --git a/gremlin-go/driver/traversal_test.go b/gremlin-go/driver/traversal_test.go index 1fdf1a52d9..0efded3ac9 100644 --- a/gremlin-go/driver/traversal_test.go +++ b/gremlin-go/driver/traversal_test.go @@ -547,13 +547,11 @@ func TestTraversal(t *testing.T) { func newWithOptionsConnection(t *testing.T) *GraphTraversalSource { // No authentication integration test with graphs loaded and alias configured server testNoAuthWithAliasUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) - testNoAuthWithAliasAuthInfo := &AuthInfo{} testNoAuthWithAliasTlsConfig := &tls.Config{} remote, err := NewDriverRemoteConnection(testNoAuthWithAliasUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testNoAuthWithAliasTlsConfig - settings.AuthInfo = testNoAuthWithAliasAuthInfo settings.TraversalSource = "gmodern" }) assert.Nil(t, err) @@ -563,13 +561,11 @@ func newWithOptionsConnection(t *testing.T) *GraphTraversalSource { func newConnection(t *testing.T) *DriverRemoteConnection { testNoAuthWithAliasUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) - testNoAuthWithAliasAuthInfo := &AuthInfo{} testNoAuthWithAliasTlsConfig := &tls.Config{} remote, err := NewDriverRemoteConnection(testNoAuthWithAliasUrl, func(settings *DriverRemoteConnectionSettings) { settings.TlsConfig = testNoAuthWithAliasTlsConfig - settings.AuthInfo = testNoAuthWithAliasAuthInfo settings.TraversalSource = "gtx" }) assert.Nil(t, err) diff --git a/gremlin-go/go.mod b/gremlin-go/go.mod index c36ea3718d..169beee051 100644 --- 
a/gremlin-go/go.mod +++ b/gremlin-go/go.mod @@ -28,6 +28,20 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2 v1.41.1 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.7 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 // indirect + github.com/aws/smithy-go v1.24.0 // indirect github.com/cucumber/gherkin/go/v26 v26.2.0 // indirect github.com/cucumber/messages/go/v21 v21.0.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/gremlin-go/go.sum b/gremlin-go/go.sum index 59675a75c5..cc7f1d6a7d 100644 --- a/gremlin-go/go.sum +++ b/gremlin-go/go.sum @@ -1,5 +1,33 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU= +github.com/aws/aws-sdk-go-v2 v1.41.1/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.7 h1:vxUyWGUwmkQ2g19n7JY/9YL8MfAIl7bTesIUykECXmY= +github.com/aws/aws-sdk-go-v2/config v1.32.7/go.mod h1:2/Qm5vKUU/r7Y+zUk/Ptt2MDAEKAfUtKc1+3U1Mo3oY= +github.com/aws/aws-sdk-go-v2/credentials v1.19.7 h1:tHK47VqqtJxOymRrNtUXN5SP/zUTvZKeLx4tH6PGQc8= 
+github.com/aws/aws-sdk-go-v2/credentials v1.19.7/go.mod h1:qOZk8sPDrxhf+4Wf4oT2urYJrYt3RejHSzgAquYeppw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 h1:I0GyV8wiYrP8XpA70g1HBcQO1JlQxCMTW9npl5UbDHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17/go.mod h1:tyw7BOl5bBe/oqvoIeECFJjMdzXoa/dfVz3QQ5lgHGA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 h1:xOLELNKGp2vsiteLsvLPwxC+mYmO6OZ8PYgiuPJzF8U= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17/go.mod h1:5M5CI3D12dNOtH3/mk6minaRwI2/37ifCURZISxA/IQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 h1:WWLqlh79iO48yLkj1v3ISRNiv+3KdQoZ6JWyfcsyQik= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17/go.mod h1:EhG22vHRrvF8oXSTYStZhJc1aUgKtnJe+aOiFEV90cM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17 h1:RuNSMoozM8oXlgLG/n6WLaFGoea7/CddrCfIiSA+xdY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.17/go.mod h1:F2xxQ9TZz5gDWsclCtPQscGpP0VUOc8RqgFM3vDENmU= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.5 h1:VrhDvQib/i0lxvr3zqlUwLwJP4fpmpyD9wYG1vfSu+Y= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.5/go.mod h1:k029+U8SY30/3/ras4G/Fnv/b88N4mAfliNn08Dem4M= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.9 h1:v6EiMvhEYBoHABfbGB4alOYmCIrcgyPPiBE1wZAEbqk= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.9/go.mod h1:yifAsgBxgJWn3ggx70A3urX2AN49Y5sJTD1UQFlfqBw= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13 h1:gd84Omyu9JLriJVCbGApcLzVR3XtmC4ZDPcAI6Ftvds= 
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.13/go.mod h1:sTGThjphYE4Ohw8vJiRStAcu3rbjtXRsdNB0TvZ5wwo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.6 h1:5fFjR/ToSOzB2OQ/XqWpZBmNvmP/pJ1jOWYlFDJTjRQ= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.6/go.mod h1:qgFDZQSD/Kys7nJnVqYlWKnh0SSdMjAi0uSwON4wgYQ= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cucumber/gherkin/go/v26 v26.2.0 h1:EgIjePLWiPeslwIWmNQ3XHcypPsWAHoMCz/YEBKP4GI= github.com/cucumber/gherkin/go/v26 v26.2.0/go.mod h1:t2GAPnB8maCT4lkHL99BDCVNzCh1d7dBhCLt150Nr/0=
