This is an automated email from the ASF dual-hosted git repository. xiazcy pushed a commit to branch go-http-streaming in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit c3d66cb0c134459992493662181f1ba9c5d33b4c Author: Yang Xia <[email protected]> AuthorDate: Thu Jan 15 14:22:27 2026 -0800 combined httpProtocol & Transport --- gremlin-go/driver/client.go | 16 +- gremlin-go/driver/driverRemoteConnection.go | 4 +- gremlin-go/driver/httpConnection.go | 196 +++++++++++++++++++++++++ gremlin-go/driver/httpConnection_test.go | 217 ++++++++++++++++++++++++++++ gremlin-go/driver/streamingDeserializer.go | 7 +- 5 files changed, 431 insertions(+), 9 deletions(-) diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index d8777d873d..afc8924ce0 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -60,13 +60,19 @@ type ClientSettings struct { EnableUserAgentOnConnect bool } +// protocol defines the interface for HTTP communication with Gremlin server +type protocol interface { + send(request *request) (ResultSet, error) + close() +} + // Client is used to connect and interact with a Gremlin-supported server. type Client struct { url string traversalSource string logHandler *logHandler connectionSettings *connectionSettings - httpProtocol *httpProtocol + protocol protocol } // NewClient creates a Client and configures it with the given parameters. @@ -111,14 +117,14 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C logHandler := newLogHandler(settings.Logger, settings.LogVerbosity, settings.Language) - httpProt := newHttpProtocol(logHandler, url, connSettings) + conn := newHttpConnection(logHandler, url, connSettings) client := &Client{ url: url, traversalSource: settings.TraversalSource, logHandler: logHandler, connectionSettings: connSettings, - httpProtocol: httpProt, + protocol: conn, } return client, nil @@ -142,7 +148,7 @@ func (client *Client) SubmitWithOptions(traversalString string, requestOptions R // TODO interceptors (ie. 
auth) - rs, err := client.httpProtocol.send(&request) + rs, err := client.protocol.send(&request) return rs, err } @@ -171,7 +177,7 @@ func (client *Client) submitGremlinLang(gremlinLang *GremlinLang) (ResultSet, er } request := MakeStringRequest(gremlinLang.GetGremlin(), client.traversalSource, requestOptionsBuilder.Create()) - return client.httpProtocol.send(&request) + return client.protocol.send(&request) } func applyOptionsConfig(builder *RequestOptionsBuilder, config map[string]interface{}) *RequestOptionsBuilder { diff --git a/gremlin-go/driver/driverRemoteConnection.go b/gremlin-go/driver/driverRemoteConnection.go index e4ec91d7bf..2d0c09a099 100644 --- a/gremlin-go/driver/driverRemoteConnection.go +++ b/gremlin-go/driver/driverRemoteConnection.go @@ -100,14 +100,14 @@ func NewDriverRemoteConnection( logHandler := newLogHandler(settings.Logger, settings.LogVerbosity, settings.Language) - httpProt := newHttpProtocol(logHandler, url, connSettings) + conn := newHttpConnection(logHandler, url, connSettings) client := &Client{ url: url, traversalSource: settings.TraversalSource, logHandler: logHandler, connectionSettings: connSettings, - httpProtocol: httpProt, + protocol: conn, } return &DriverRemoteConnection{client: client, isClosed: false, settings: settings}, nil diff --git a/gremlin-go/driver/httpConnection.go b/gremlin-go/driver/httpConnection.go new file mode 100644 index 0000000000..2a0d203744 --- /dev/null +++ b/gremlin-go/driver/httpConnection.go @@ -0,0 +1,196 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. 
You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+	"bytes"
+	"compress/zlib"
+	"fmt"
+	"io"
+	"mime"
+	"net/http"
+	"strings"
+	"time"
+)
+
+// httpConnection handles HTTP request/response for Gremlin queries. Client
+// and DriverRemoteConnection hold it through the protocol interface.
+type httpConnection struct {
+	url          string
+	httpClient   *http.Client
+	connSettings *connectionSettings
+	logHandler   *logHandler
+	serializer   *GraphBinarySerializer
+}
+
+// Connection pool defaults aligned with Java driver
+const (
+	defaultMaxConnsPerHost     = 128               // Java: ConnectionPool.MAX_POOL_SIZE
+	defaultMaxIdleConnsPerHost = 8                 // Keep some connections warm
+	defaultIdleConnTimeout     = 180 * time.Second // Java: CONNECTION_IDLE_TIMEOUT_MILLIS
+	defaultConnectionTimeout   = 15 * time.Second  // Java: CONNECTION_SETUP_TIMEOUT_MILLIS
+)
+
+// maxHTTPErrorBodyBytes bounds how much of a non-GraphBinary error response
+// body is copied into the error reported by checkStatus.
+const maxHTTPErrorBodyBytes = 1024
+
+// newHttpConnection builds an httpConnection that POSTs to url. A zero
+// connSettings.connectionTimeout falls back to defaultConnectionTimeout, and
+// the transport's own compression handling is disabled unless
+// connSettings.enableCompression is set.
+func newHttpConnection(handler *logHandler, url string, connSettings *connectionSettings) *httpConnection {
+	timeout := connSettings.connectionTimeout
+	if timeout == 0 {
+		timeout = defaultConnectionTimeout
+	}
+
+	transport := &http.Transport{
+		TLSClientConfig:     connSettings.tlsConfig,
+		MaxConnsPerHost:     defaultMaxConnsPerHost,
+		MaxIdleConnsPerHost: defaultMaxIdleConnsPerHost,
+		IdleConnTimeout:     defaultIdleConnTimeout,
+		DisableCompression:  !connSettings.enableCompression,
+	}
+
+	// NOTE(review): http.Client.Timeout covers the whole exchange, including
+	// reading the response body, so a streamed result that takes longer than
+	// `timeout` to finish is aborted mid-stream. The Java setting this mirrors
+	// (CONNECTION_SETUP_TIMEOUT_MILLIS) only bounds connection setup; consider
+	// moving it to a net.Dialer / TLSHandshakeTimeout on the transport.
+	return &httpConnection{
+		url:          url,
+		httpClient:   &http.Client{Transport: transport, Timeout: timeout},
+		connSettings: connSettings,
+		logHandler:   handler,
+		serializer:   newGraphBinarySerializer(handler),
+	}
+}
+
+// send serializes req and returns a ResultSet that is filled asynchronously
+// as the response streams in. A serialization failure is returned directly
+// (alongside an already-closed ResultSet); transport, HTTP and server errors
+// are recorded on the ResultSet instead.
+func (c *httpConnection) send(req *request) (ResultSet, error) {
+	rs := newChannelResultSet()
+
+	data, err := c.serializer.SerializeMessage(req)
+	if err != nil {
+		rs.Close()
+		return rs, err
+	}
+
+	go c.executeAndStream(data, rs)
+
+	return rs, nil
+}
+
+// executeAndStream POSTs the serialized request to c.url and streams the
+// decoded response into rs. rs is always closed on return; transport,
+// HTTP-status and decoding failures are logged and recorded on rs.
+func (c *httpConnection) executeAndStream(data []byte, rs ResultSet) {
+	defer rs.Close()
+
+	req, err := http.NewRequest(http.MethodPost, c.url, bytes.NewReader(data))
+	if err != nil {
+		c.logHandler.logf(Error, failedToSendRequest, err.Error())
+		rs.setError(err)
+		return
+	}
+
+	c.setHeaders(req)
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		c.logHandler.logf(Error, failedToSendRequest, err.Error())
+		rs.setError(err)
+		return
+	}
+	defer resp.Body.Close()
+
+	err = c.checkStatus(resp)
+	if err != nil {
+		c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
+		rs.setError(err)
+		return
+	}
+
+	reader, zlibReader, err := c.getReader(resp)
+	if err != nil {
+		c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
+		rs.setError(err)
+		return
+	}
+	if zlibReader != nil {
+		defer zlibReader.Close()
+	}
+
+	c.streamToResultSet(reader, rs)
+}
+
+// checkStatus returns an error for a non-2xx response whose body is not
+// GraphBinary (for example a plain-text 401 or an HTML 502 from a proxy).
+// The streaming deserializer could only report such a body as a misleading
+// decode error, or as success if it is empty. It returns nil for 2xx
+// responses and for GraphBinary bodies, whose status streamToResultSet reads
+// from the stream footer.
+func (c *httpConnection) checkStatus(resp *http.Response) error {
+	if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+		return nil
+	}
+	mediaType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
+	if err == nil && strings.EqualFold(mediaType, graphBinaryMimeType) {
+		return nil
+	}
+
+	msg := resp.Status
+	// Quote the start of the body only when it is not content-encoded; a
+	// compressed body would be unreadable inside an error message.
+	if resp.Header.Get("Content-Encoding") == "" {
+		// Best effort: if the read fails, the status line alone is reported.
+		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxHTTPErrorBodyBytes))
+		if text := strings.TrimSpace(string(body)); text != "" {
+			msg += ": " + text
+		}
+	}
+	return newError(err0502ResponseHandlerReadLoopError, msg, resp.StatusCode)
+}
+
+// setHeaders applies the GraphBinary content-negotiation headers plus the
+// optional user-agent, compression and authentication headers selected by
+// connSettings.
+func (c *httpConnection) setHeaders(req *http.Request) {
+	req.Header.Set("Content-Type", graphBinaryMimeType)
+	req.Header.Set("Accept", graphBinaryMimeType)
+
+	if c.connSettings.enableUserAgentOnConnect {
+		req.Header.Set(userAgentHeader, userAgent)
+	}
+	if c.connSettings.enableCompression {
+		// Setting Accept-Encoding explicitly stops net/http from decoding the
+		// response transparently; getReader unwraps deflate itself.
+		req.Header.Set("Accept-Encoding", "deflate")
+	}
+	if c.connSettings.authInfo != nil {
+		if headers := c.connSettings.authInfo.GetHeader(); headers != nil {
+			for k, vals := range headers {
+				for _, v := range vals {
+					req.Header.Add(k, v)
+				}
+			}
+		}
+		if ok, user, pass := c.connSettings.authInfo.GetBasicAuth(); ok {
+			req.SetBasicAuth(user, pass)
+		}
+	}
+}
+
+// getReader returns the reader the response body should be decoded from.
+// For a deflate Content-Encoding (zlib format) it wraps the body in a zlib
+// reader and also returns it as the io.Closer the caller must close;
+// otherwise it returns resp.Body and a nil closer.
+func (c *httpConnection) getReader(resp *http.Response) (io.Reader, io.Closer, error) {
+	// Content-coding names are case-insensitive (RFC 9110, section 8.4.1).
+	if strings.EqualFold(resp.Header.Get("Content-Encoding"), "deflate") {
+		zr, err := zlib.NewReader(resp.Body)
+		if err != nil {
+			return nil, nil, err
+		}
+		return zr, zr, nil
+	}
+	return resp.Body, nil, nil
+}
+
+// streamToResultSet decodes a GraphBinary response stream from reader and
+// forwards each value to rs until the end-of-stream marker, then reads the
+// trailing status. A status other than 200 or 0 is recorded on rs, using the
+// exception text when present and the status message otherwise. An empty
+// stream (io.EOF before the header) yields no results and no error. A stream
+// that ends before the marker is reported as io.ErrUnexpectedEOF.
+func (c *httpConnection) streamToResultSet(reader io.Reader, rs ResultSet) {
+	fail := func(err error) {
+		c.logHandler.logf(Error, failedToReceiveResponse, err.Error())
+		rs.setError(err)
+	}
+
+	d := newStreamingDeserializer(reader)
+	if err := d.readHeader(); err != nil {
+		if err != io.EOF {
+			fail(err)
+		}
+		return
+	}
+
+	for {
+		obj, err := d.readFullyQualified()
+		if err != nil {
+			if err == io.EOF {
+				// A complete response ends with the end-of-stream marker and
+				// status handled below. Reaching EOF first means the body was
+				// cut short, so do not present partial results as complete.
+				err = fmt.Errorf("response ended before end-of-stream marker: %w", io.ErrUnexpectedEOF)
+			}
+			fail(err)
+			return
+		}
+
+		if marker, ok := obj.(Marker); ok && marker == EndOfStream() {
+			code, msg, exc, err := d.readStatus()
+			if err != nil {
+				fail(err)
+				return
+			}
+			if code != 200 && code != 0 {
+				if exc != "" {
+					rs.setError(newError(err0502ResponseHandlerReadLoopError, exc, code))
+				} else {
+					rs.setError(newError(err0502ResponseHandlerReadLoopError, msg, code))
+				}
+			}
+			return
+		}
+
+		// NOTE(review): this send blocks until the consumer receives (unless
+		// the channel has spare buffer). If a caller abandons the ResultSet
+		// without draining it, this goroutine and its HTTP connection stay
+		// parked here.
+		rs.Channel() <- &Result{obj}
+	}
+}
+
+// close releases the idle pooled connections held by the HTTP client; it does
+// not interrupt requests that are still in flight.
+func (c *httpConnection) close() {
+	c.httpClient.CloseIdleConnections()
+}
diff --git a/gremlin-go/driver/httpConnection_test.go b/gremlin-go/driver/httpConnection_test.go
new file mode 100644
index 0000000000..f106cac66e
--- /dev/null
+++ b/gremlin-go/driver/httpConnection_test.go
@@ -0,0 +1,217 @@
+/*
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+*/
+
+package gremlingo
+
+import (
+	"bytes"
+	"crypto/tls"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/text/language"
+)
+
+// newTestLogHandler returns a logHandler backed by defaultLogger at Warning
+// verbosity with English messages.
+func newTestLogHandler() *logHandler {
+	return newLogHandler(&defaultLogger{}, Warning, language.English)
+}
+
+// TestNewHttpConnection checks the client timeout and TLS configuration that
+// newHttpConnection derives from connectionSettings.
+func TestNewHttpConnection(t *testing.T) {
+	t.Run("applies default timeout when not set", func(t *testing.T) {
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost:8182/gremlin", &connectionSettings{})
+
+		assert.Equal(t, defaultConnectionTimeout, conn.httpClient.Timeout)
+	})
+
+	t.Run("uses provided timeout", func(t *testing.T) {
+		customTimeout := 30 * time.Second
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost:8182/gremlin", &connectionSettings{
+			connectionTimeout: customTimeout,
+		})
+
+		assert.Equal(t, customTimeout, conn.httpClient.Timeout)
+	})
+
+	t.Run("applies TLS config", func(t *testing.T) {
+		tlsConfig := &tls.Config{InsecureSkipVerify: true}
+		conn := newHttpConnection(newTestLogHandler(), "https://localhost:8182/gremlin", &connectionSettings{
+			tlsConfig: tlsConfig,
+		})
+
+		// Unchecked type assertion: panics if the client's Transport is not a *http.Transport.
+		transport := conn.httpClient.Transport.(*http.Transport)
+		assert.Equal(t, tlsConfig, transport.TLSClientConfig)
+	})
+}
+
+// TestSetHeaders checks the request headers setHeaders adds for each relevant
+// connectionSettings option.
+func TestSetHeaders(t *testing.T) {
+	t.Run("sets content type and accept headers", func(t *testing.T) {
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{})
+		req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil)
+		require.NoError(t, err)
+
+		conn.setHeaders(req)
+
+		assert.Equal(t, graphBinaryMimeType, req.Header.Get("Content-Type"))
+		assert.Equal(t, graphBinaryMimeType, req.Header.Get("Accept"))
+	})
+
+	t.Run("sets user agent when enabled", func(t *testing.T) {
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{
+			enableUserAgentOnConnect: true,
+		})
+		req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil)
+		require.NoError(t, err)
+
+		conn.setHeaders(req)
+
+		assert.NotEmpty(t, req.Header.Get(userAgentHeader))
+	})
+
+	t.Run("sets compression header when enabled", func(t *testing.T) {
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{
+			enableCompression: true,
+		})
+		req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil)
+		require.NoError(t, err)
+
+		conn.setHeaders(req)
+
+		assert.Equal(t, "deflate", req.Header.Get("Accept-Encoding"))
+	})
+
+	t.Run("sets basic auth when provided", func(t *testing.T) {
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{
+			authInfo: BasicAuthInfo("user", "pass"),
+		})
+		req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil)
+		require.NoError(t, err)
+
+		conn.setHeaders(req)
+
+		user, pass, ok := req.BasicAuth()
+		assert.True(t, ok)
+		assert.Equal(t, "user", user)
+		assert.Equal(t, "pass", pass)
+	})
+
+	t.Run("handles nil authInfo", func(t *testing.T) {
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{
+			authInfo: nil,
+		})
+		req, err := http.NewRequest(http.MethodPost, "http://localhost/gremlin", nil)
+		require.NoError(t, err)
+
+		conn.setHeaders(req)
+
+		_, _, ok := req.BasicAuth()
+		assert.False(t, ok)
+	})
+}
+
+// TestGetReader checks that getReader passes plain bodies through unchanged
+// and wraps deflate-encoded bodies in a zlib reader.
+func TestGetReader(t *testing.T) {
+	conn := newHttpConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{})
+
+	t.Run("returns body for non-compressed response", func(t *testing.T) {
+		resp := &http.Response{
+			Header: http.Header{},
+			Body:   http.NoBody,
+		}
+
+		reader, closer, err := conn.getReader(resp)
+
+		assert.NoError(t, err)
+		assert.Nil(t, closer)
+		assert.Equal(t, resp.Body, reader)
+	})
+
+	t.Run("returns zlib reader for deflate response", func(t *testing.T) {
+		// Valid zlib compressed empty data
+		zlibData := []byte{0x78, 0x9c, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01}
+		resp := &http.Response{
+			Header: http.Header{"Content-Encoding": []string{"deflate"}},
+			Body:   io.NopCloser(bytes.NewReader(zlibData)),
+		}
+
+		reader, closer, err := conn.getReader(resp)
+
+		assert.NoError(t, err)
+		assert.NotNil(t, closer)
+		assert.NotNil(t, reader)
+		require.NoError(t, closer.Close())
+	})
+
+	t.Run("returns error for invalid zlib data", func(t *testing.T) {
+		// 0x00 0x00 is not a valid zlib header (compression method must be 8),
+		// so creating the zlib reader fails.
+		resp := &http.Response{
+			Header: http.Header{"Content-Encoding": []string{"deflate"}},
+			Body:   io.NopCloser(bytes.NewReader([]byte{0x00, 0x00})),
+		}
+
+		_, _, err := conn.getReader(resp)
+
+		assert.Error(t, err)
+	})
+}
+
+// TestHttpConnectionWithMockServer exercises send end to end: a failing
+// request surfaced through the ResultSet, and the headers a real HTTP server
+// receives.
+func TestHttpConnectionWithMockServer(t *testing.T) {
+	t.Run("handles connection error", func(t *testing.T) {
+		// NOTE(review): 99999 is outside the valid TCP port range (0-65535), so
+		// the request fails on the address itself rather than with a refused
+		// connection; a closed httptest server would exercise a real dial failure.
+		conn := newHttpConnection(newTestLogHandler(), "http://localhost:99999/gremlin", &connectionSettings{
+			connectionTimeout: 100 * time.Millisecond,
+		})
+
+		rs, err := conn.send(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+		assert.NoError(t, err) // send returns a nil error; the request failure is reported on the ResultSet
+
+		// All() blocks until stream closes, then we can check error
+		_, _ = rs.All()
+		assert.Error(t, rs.GetError())
+	})
+
+	t.Run("receives headers from request", func(t *testing.T) {
+		// The handler captures the request headers and replies 200 with an
+		// empty body; only the captured headers are asserted.
+		headersCh := make(chan http.Header, 1)
+		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			headersCh <- r.Header.Clone()
+			w.WriteHeader(http.StatusOK)
+		}))
+		defer server.Close()
+
+		conn := newHttpConnection(newTestLogHandler(), server.URL, &connectionSettings{
+			enableUserAgentOnConnect: true,
+			enableCompression:        true,
+		})
+
+		rs, err := conn.send(&request{gremlin: "g.V()", fields: map[string]interface{}{}})
+		require.NoError(t, err)
+
+		select {
+		case receivedHeaders := <-headersCh:
+			assert.Equal(t, graphBinaryMimeType, receivedHeaders.Get("Content-Type"))
+			assert.Equal(t, "deflate", receivedHeaders.Get("Accept-Encoding"))
+			assert.NotEmpty(t, receivedHeaders.Get(userAgentHeader))
+		case <-time.After(time.Second):
+			t.Fatal("timeout waiting for request")
+		}
+
+		_, _ = rs.All() // drain
+	})
+}
diff --git a/gremlin-go/driver/streamingDeserializer.go b/gremlin-go/driver/streamingDeserializer.go
index 949c3e3b00..e5541625f5 100644
--- a/gremlin-go/driver/streamingDeserializer.go
+++ b/gremlin-go/driver/streamingDeserializer.go
@@ -40,6 +40,9 @@ type streamingDeserializer struct {
 	err error // sticky error
 }
 
+// GraphBinary flag for bulked list/set
+const flagBulked = 0x02
+
 func newStreamingDeserializer(r io.Reader) *streamingDeserializer {
 	return &streamingDeserializer{r: bufio.NewReaderSize(r, 8192)}
 }
@@ -186,9 +189,9 @@ func (d *streamingDeserializer) readValue(dt dataType, flag byte) (interface{},
 		}
 		return id, nil
 	case listType:
-		return d.readList(flag == 0x02)
+		return d.readList(flag == flagBulked)
 	case setType:
-		list, err := d.readList(flag == 0x02)
+		list, err := d.readList(flag == flagBulked)
 		if err != nil {
 			return nil, err
 		}
