This is an automated email from the ASF dual-hosted git repository.
joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git
The following commit(s) were added to refs/heads/trunk by this push:
new f1e31a58 Protocol version negotiation doesn't work if server replies
with stream id different than 0
f1e31a58 is described below
commit f1e31a58f7e0c25e58e2e2a0a0c6de358e643e8b
Author: Bohdan Siryk <[email protected]>
AuthorDate: Fri Nov 28 10:14:22 2025 +0200
Protocol version negotiation doesn't work if server replies with stream id
different than 0
Previously, protocol negotiation didn't work properly when C* was
responding with stream id different from 0.
This patch changes the way protocol negotiation works. Instead of parsing a
supported protocol version from C* error response, the driver tries to connect
with each supported protocol starting from the latest.
Patch by Bohdan Siryk; Reviewed by João Reis for CASSGO-98
---
CHANGELOG.md | 1 +
Makefile | 2 +-
conn.go | 51 ++++++++-
conn_test.go | 78 +++++++++++--
control.go | 98 ++++++----------
control_test.go | 35 ------
errors.go | 10 ++
frame.go | 15 ++-
protocol_negotiation_test.go | 267 +++++++++++++++++++++++++++++++++++++++++++
9 files changed, 442 insertions(+), 115 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cd327073..872ebdd1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic
Versioning](https://semver.org/spec/v2.0.0
- Prevent panic with queries during session init (CASSGO-92)
- Return correct values from RowData (CASSGO-95)
- Prevent setting a compression flag in a frame header when native proto v5 is
being used (CASSGO-98)
+- Use protocol downgrading approach during protocol negotiation (CASSGO-97)
## [2.0.0]
diff --git a/Makefile b/Makefile
index 56ed015d..c4f6df21 100644
--- a/Makefile
+++ b/Makefile
@@ -100,7 +100,7 @@ test-integration-auth: .prepare-cassandra-cluster
test-unit:
@echo "Run unit tests"
@go clean -testcache
- go test -tags unit -timeout=5m -race ./...
+ go test -v -tags unit -timeout=5m -race ./...
check: .prepare-golangci
@echo "Build"
diff --git a/conn.go b/conn.go
index 40044565..a9bb4f5a 100644
--- a/conn.go
+++ b/conn.go
@@ -378,6 +378,13 @@ func (s *startupCoordinator) setupConn(ctx
context.Context) error {
select {
case err := <-startupErr:
if err != nil {
+ if s.checkProtocolRelatedError(err) {
+ return &unsupportedProtocolVersionError{
+ err: err,
+ hostInfo: s.conn.host,
+ version: protoVersion(s.conn.version),
+ }
+ }
return err
}
case <-ctx.Done():
@@ -387,6 +394,38 @@ func (s *startupCoordinator) setupConn(ctx
context.Context) error {
return nil
}
+// checkProtocolRelatedError reports whether err is protocol related, i.e.
+// whether startup should be retried with a lower protocol version.
+func (s *startupCoordinator) checkProtocolRelatedError(err error) bool {
+ var unwrappedFrame frame
+
+ var protocolErr *protocolError
+ if !errors.As(err, &protocolErr) {
+ var errFrame errorFrame
+ if !errors.As(err, &errFrame) {
+ return false
+ } else {
+ unwrappedFrame = errFrame
+ }
+ } else {
+ unwrappedFrame = protocolErr.frame
+ }
+
+ switch frame := unwrappedFrame.(type) {
+ case *supportedFrame:
+ // We can receive a supportedFrame wrapped in protocolError
from Conn.recv if the host responds to a 0 stream id.
+ // If we receive a supportedFrame then we know that the host is
not compatible with the protocol version, but it is reachable, so we can retry
+ return true
+ case errorFrame:
+ // If we receive an errorFrame with codes ErrCodeProtocol or
ErrCodeServer,
+ // then we should try to downgrade a protocol version, so do
not skip the host
+ return frame.code == ErrCodeProtocol || frame.code ==
ErrCodeServer
+ default:
+ // In any other case we should not retry as it means the host
is not reachable or some other error happened
+ return false
+ }
+}
+
func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder,
startupCompleted *atomic.Bool) (frame, error) {
select {
case s.frameTicker <- struct{}{}:
@@ -408,12 +447,14 @@ func (s *startupCoordinator) options(ctx context.Context,
startupCompleted *atom
return err
}
- supported, ok := frame.(*supportedFrame)
- if !ok {
- return NewErrProtocol("Unknown type of response to startup
frame: %T", frame)
+ switch frame := frame.(type) {
+ case *supportedFrame:
+ return s.startup(ctx, frame.supported, startupCompleted)
+ case error:
+ return frame
+ default:
+ return NewErrProtocol("Unknown type of response to startup
frame: %T (frame=%s)", frame, frame.String())
}
-
- return s.startup(ctx, supported.supported, startupCompleted)
}
func (s *startupCoordinator) startup(ctx context.Context, supported
map[string][]string, startupCompleted *atomic.Bool) error {
diff --git a/conn_test.go b/conn_test.go
index 60e4a2a8..ad4e66e5 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -1054,6 +1054,9 @@ type newTestServerOpts struct {
addr string
protocol uint8
recvHook func(*framer)
+
+ customRequestHandler func(srv *TestServer, reqFrame, respFrame
*framer) error
+ dontFailOnProtocolMismatch bool
}
func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context)
*TestServer {
@@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx
context.Context) *TestS
cancel: cancel,
onRecv: nts.recvHook,
+
+ customRequestHandler: nts.customRequestHandler,
+ dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch,
}
go srv.closeWatch()
@@ -1142,6 +1148,10 @@ type TestServer struct {
// onRecv is a hook point for tests, called in receive loop.
onRecv func(*framer)
+
+ // customRequestHandler allows overriding the default request handling
for testing purposes.
+ customRequestHandler func(srv *TestServer, reqFrame, respFrame
*framer) error
+ dontFailOnProtocolMismatch bool
}
func (srv *TestServer) closeWatch() {
@@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() {
}
go func(conn net.Conn) {
+ var startupCompleted bool
+ var useProtoV5 bool
+
defer conn.Close()
for !srv.isClosed() {
- framer, err := srv.readFrame(conn)
+ var reader io.Reader = conn
+
+ if useProtoV5 && startupCompleted {
+ frame, _, err :=
readUncompressedSegment(conn)
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ return
+ }
+ srv.errorLocked(err)
+ return
+ }
+ reader = bytes.NewReader(frame)
+ }
+
+ framer, err := srv.readFrame(reader)
if err != nil {
if err == io.EOF {
return
@@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() {
srv.onRecv(framer)
}
- go srv.process(conn, framer)
+ srv.process(conn, framer, &useProtoV5,
&startupCompleted)
}
}(conn)
}
@@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) {
srv.t.Error(err)
}
-func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
+func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5,
startupCompleted *bool) {
head := reqFrame.header
if head == nil {
srv.errorLocked("process frame with a nil header")
return
}
- respFrame := newFramer(nil, reqFrame.proto, GlobalTypes)
+ respFrame := newFramer(nil, byte(head.version), GlobalTypes)
+
+ if srv.customRequestHandler != nil {
+ if err := srv.customRequestHandler(srv, reqFrame, respFrame);
err != nil {
+ srv.errorLocked(err)
+ return
+ }
+ // Don't like this but...
+ goto finish
+ }
switch head.op {
case opStartup:
@@ -1412,26 +1448,46 @@ func (srv *TestServer) process(conn net.Conn, reqFrame
*framer) {
respFrame.writeString("not supported")
}
- respFrame.buf[0] = srv.protocol | 0x80
+finish:
+
+ respFrame.buf[0] |= 0x80
if err := respFrame.finish(); err != nil {
srv.errorLocked(err)
}
- if err := respFrame.writeTo(conn); err != nil {
- srv.errorLocked(err)
+ if *useProtoV5 && *startupCompleted {
+ segment, err := newUncompressedSegment(respFrame.buf, true)
+ if err == nil {
+ _, err = conn.Write(segment)
+ }
+ if err != nil {
+ srv.errorLocked(err)
+ return
+ }
+ } else {
+ if err := respFrame.writeTo(conn); err != nil {
+ srv.errorLocked(err)
+ }
+
+ if reqFrame.header.op == opStartup {
+ *startupCompleted = true
+ if head.version == protoVersion5 {
+ *useProtoV5 = true
+ }
+ }
}
}
-func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
+func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) {
buf := make([]byte, srv.headerSize)
- head, err := readHeader(conn, buf)
+ head, err := readHeader(reader, buf)
if err != nil {
return nil, err
}
framer := newFramer(nil, srv.protocol, GlobalTypes)
- err = framer.readFrame(conn, &head)
+ err = framer.readFrame(reader, &head)
if err != nil {
return nil, err
}
@@ -1439,7 +1495,7 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer,
error) {
// should be a request frame
if head.version.response() {
return nil, fmt.Errorf("expected to read a request frame got
version: %v", head.version)
- } else if head.version.version() != srv.protocol {
+ } else if !srv.dontFailOnProtocolMismatch && head.version.version() !=
srv.protocol {
return nil, fmt.Errorf("expected to read protocol version 0x%x
got 0x%x", srv.protocol, head.version.version())
}
diff --git a/control.go b/control.go
index e59acb40..c518ba68 100644
--- a/control.go
+++ b/control.go
@@ -32,7 +32,6 @@ import (
"math/rand"
"net"
"os"
- "regexp"
"strconv"
"sync"
"sync/atomic"
@@ -202,56 +201,9 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
return shuffled
}
-// this is going to be version dependant and a nightmare to maintain :(
-var protocolSupportRe = regexp.MustCompile(`the lowest supported version is
\d+ and the greatest is (\d+)$`)
-var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used
\(.*\), but USE_BETA flag is unset`)
-
-func parseProtocolFromError(err error) int {
- errStr := err.Error()
-
- var errProtocol ErrProtocol
- if errors.As(err, &errProtocol) {
- err = errProtocol.error
- }
-
- // I really wish this had the actual info in the error frame...
- matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1)
- if len(matches) == 1 {
- var protoErr *protocolError
- if errors.As(err, &protoErr) {
- version := protoErr.frame.Header().version.version()
- if version > 0 {
- return int(version - 1)
- }
- }
- return 0
- }
-
- matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1)
- if len(matches) != 1 || len(matches[0]) != 2 {
- var protoErr *protocolError
- if errors.As(err, &protoErr) {
- return int(protoErr.frame.Header().version.version())
- }
- return 0
- }
-
- max, err := strconv.Atoi(matches[0][1])
- if err != nil {
- return 0
- }
-
- return max
-}
-
-const highestProtocolVersionSupported = 5
-
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
hosts = shuffleHosts(hosts)
- connCfg := *c.session.connCfg
- connCfg.ProtoVersion = highestProtocolVersionSupported
-
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
// we should never get here, but if we do it means we connected
to a
// host successfully which means our attempted protocol version
worked
@@ -261,30 +213,56 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo)
(int, error) {
})
var err error
+ var proto int
for _, host := range hosts {
- var conn *Conn
- conn, err = c.session.dial(c.session.ctx, host, &connCfg,
handler)
+ proto, err = c.tryProtocolVersionsForHost(host, handler)
+ if err == nil {
+ return proto, nil
+ }
+
+ c.session.logger.Debug("Failed to discover protocol version for
host.",
+ NewLogFieldIP("host_addr", host.ConnectAddress()),
+ NewLogFieldError("err", err))
+ }
+
+ return 0, err
+}
+
+func (c *controlConn) tryProtocolVersionsForHost(host *HostInfo, handler
ConnErrorHandler) (int, error) {
+ connCfg := *c.session.connCfg
+
+ var triedVersions []int
+
+ for proto := highestProtocolVersionSupported; proto >=
lowestProtocolVersionSupported; proto-- {
+ connCfg.ProtoVersion = proto
+
+ conn, err := c.session.dial(c.session.ctx, host, &connCfg,
handler)
if conn != nil {
conn.Close()
}
if err == nil {
- c.session.logger.Debug("Discovered protocol version
using host.",
- NewLogFieldInt("protocol_version",
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()),
NewLogFieldString("host_id", host.HostID()))
- return connCfg.ProtoVersion, nil
+ return proto, nil
}
- if proto := parseProtocolFromError(err); proto > 0 {
- c.session.logger.Debug("Discovered protocol version
using host after parsing protocol error.",
- NewLogFieldInt("protocol_version", proto),
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id",
host.HostID()))
- return proto, nil
+ var unsupportedErr *unsupportedProtocolVersionError
+ if errors.As(err, &unsupportedErr) {
+ // the host does not support this protocol version, try
a lower version
+ c.session.logger.Debug("Failed to connect to host
during protocol negotiation.",
+ NewLogFieldIP("host_addr",
host.ConnectAddress()),
+ NewLogFieldInt("proto_version", proto),
+ NewLogFieldError("err", err))
+ triedVersions = append(triedVersions,
connCfg.ProtoVersion)
+ continue
}
- c.session.logger.Debug("Failed to discover protocol version
using host.",
- NewLogFieldIP("host_addr", host.ConnectAddress()),
NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err))
+ c.session.logger.Debug("Error connecting to host during
protocol negotiation.",
+ NewLogFieldIP("host_addr", host.ConnectAddress()),
+ NewLogFieldError("err", err))
+ return 0, err
}
- return 0, err
+ return 0, fmt.Errorf("gocql: failed to discover protocol version for
host %s, tried versions: %v", host.ConnectAddress(), triedVersions)
}
func (c *controlConn) connect(hosts []*HostInfo, sessionInit bool) error {
diff --git a/control_test.go b/control_test.go
index 9f83ec95..7d9311a6 100644
--- a/control_test.go
+++ b/control_test.go
@@ -57,38 +57,3 @@ func TestHostInfo_Lookup(t *testing.T) {
}
}
}
-
-func TestParseProtocol(t *testing.T) {
- tests := [...]struct {
- err error
- proto int
- }{
- {
- err: &protocolError{
- frame: errorFrame{
- code: 0x10,
- message: "Invalid or unsupported
protocol version (5); the lowest supported version is 3 and the greatest is 4",
- },
- },
- proto: 4,
- },
- {
- err: &protocolError{
- frame: errorFrame{
- frameHeader: frameHeader{
- version: 0x83,
- },
- code: 0x10,
- message: "Invalid or unsupported
protocol version: 5",
- },
- },
- proto: 3,
- },
- }
-
- for i, test := range tests {
- if proto := parseProtocolFromError(test.err); proto !=
test.proto {
- t.Errorf("%d: exepcted proto %d got %d", i, test.proto,
proto)
- }
- }
-}
diff --git a/errors.go b/errors.go
index 2d1c2205..4305f78f 100644
--- a/errors.go
+++ b/errors.go
@@ -244,3 +244,13 @@ type RequestErrCASWriteUnknown struct {
Received int
BlockFor int
}
+
+type unsupportedProtocolVersionError struct {
+ hostInfo *HostInfo
+ version protoVersion
+ err error
+}
+
+func (e unsupportedProtocolVersionError) Error() string {
+ return fmt.Sprintf("unsupported protocol version %d for host %s",
e.version, e.hostInfo.ConnectAddress())
+}
diff --git a/frame.go b/frame.go
index e86c538c..0316032e 100644
--- a/frame.go
+++ b/frame.go
@@ -65,12 +65,13 @@ func NamedValue(name string, value interface{}) interface{}
{
const (
protoDirectionMask = 0x80
protoVersionMask = 0x7F
- protoVersion1 = 0x01
- protoVersion2 = 0x02
protoVersion3 = 0x03
protoVersion4 = 0x04
protoVersion5 = 0x05
+ lowestProtocolVersionSupported = protoVersion3
+ highestProtocolVersionSupported = protoVersion5
+
maxFrameSize = 256 * 1024 * 1024
maxSegmentPayloadSize = 0x1FFFF
@@ -422,7 +423,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader,
err error) {
version := p[0] & protoVersionMask
- if version < protoVersion3 || version > protoVersion5 {
+ if version < lowestProtocolVersionSupported || version >
highestProtocolVersionSupported {
return frameHeader{}, fmt.Errorf("gocql: unsupported protocol
response version: %d", version)
}
@@ -2370,6 +2371,14 @@ func (f *framer) writeStringMap(m map[string]string) {
}
}
+func (f *framer) writeStringMultiMap(m map[string][]string) {
+ f.writeShort(uint16(len(m)))
+ for k, v := range m {
+ f.writeString(k)
+ f.writeStringList(v)
+ }
+}
+
func (f *framer) writeBytesMap(m map[string][]byte) {
f.writeShort(uint16(len(m)))
for k, v := range m {
diff --git a/protocol_negotiation_test.go b/protocol_negotiation_test.go
new file mode 100644
index 00000000..567c74e3
--- /dev/null
+++ b/protocol_negotiation_test.go
@@ -0,0 +1,267 @@
+//go:build all || unit
+// +build all unit
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package gocql
+
+import (
+ "context"
+ "encoding/binary"
+ "fmt"
+ "slices"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type requestHandlerForProtocolNegotiationTest struct {
+ supportedProtocolVersions []protoVersion
+ supportedBetaProtocols []protoVersion
+
+ // forces stream id to 0
+ forceZeroStreamID bool
+
+ forceCloseConnection bool
+}
+
+func (r *requestHandlerForProtocolNegotiationTest)
supportsBetaProtocol(version protoVersion) bool {
+ return slices.Contains(r.supportedBetaProtocols, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version
protoVersion) bool {
+ return slices.Contains(r.supportedProtocolVersions, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header
*frameHeader) bool {
+ return header.flags&flagBetaProtocol == flagBetaProtocol
+}
+
+func (r *requestHandlerForProtocolNegotiationTest)
createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string {
+ return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta),
but USE_BETA flag is unset", version, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer,
reqFrame, respFrame *framer) error {
+ if r.forceCloseConnection {
+ return fmt.Errorf("NEGOTIATION TEST: forcing close connection")
+ }
+
+ stream := reqFrame.header.stream
+
+ // If a client uses beta protocol, but the USE_BETA flag is not set, we
respond with an error
+ if r.supportsBetaProtocol(reqFrame.header.version) &&
!r.hasBetaFlag(reqFrame.header) {
+ if r.forceZeroStreamID {
+ stream = 0
+ }
+ respFrame.writeHeader(0, opError, stream)
+ respFrame.writeInt(ErrCodeProtocol)
+
respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version))
+ return nil
+ }
+
+ // if a client uses an unsupported protocol version, we respond with an
error
+ if !r.supportsProtocol(reqFrame.header.version) {
+ if r.forceZeroStreamID {
+ stream = 0
+ }
+ respFrame.writeHeader(0, opError, stream)
+ respFrame.writeInt(ErrCodeProtocol)
+ respFrame.writeString(fmt.Sprintf("NEGOTIATION TEST:
Unsupported protocol version %d", reqFrame.header.version))
+ return nil
+ }
+
+ switch reqFrame.header.op {
+ case opStartup, opRegister:
+ respFrame.writeHeader(0, opReady, stream)
+ case opOptions:
+ // Emulating C* behavior.
+ // If a client uses an unsupported protocol version, C*
responds with supported versions to 0 stream id.
+ // If a client uses a beta protocol version, but the USE_BETA
flag is not set, C* responds with supported versions to 0 stream id.
+ if r.forceZeroStreamID &&
!(r.supportsProtocol(reqFrame.header.version) ||
r.supportsBetaProtocol(reqFrame.header.version) &&
!r.hasBetaFlag(reqFrame.header)) {
+ stream = 0
+ }
+ respFrame.writeHeader(0, opSupported, stream)
+ var supportedVersionsWithDesc []string
+ for _, supportedVersion := range r.supportedProtocolVersions {
+ supportedVersionsWithDesc =
append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d", supportedVersion,
supportedVersion))
+ }
+ for _, betaProtocol := range r.supportedBetaProtocols {
+ supportedVersionsWithDesc =
append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d-beta", betaProtocol,
betaProtocol))
+ }
+ supported := map[string][]string{
+ "PROTOCOL_VERSIONS": supportedVersionsWithDesc,
+ }
+ respFrame.writeStringMultiMap(supported)
+ case opQuery:
+ respFrame.writeHeader(0, opResult, stream)
+ respFrame.writeInt(resultKindRows)
+ respFrame.writeInt(int32(flagGlobalTableSpec))
+ respFrame.writeInt(1)
+ respFrame.writeString("system")
+ respFrame.writeString("local")
+ respFrame.writeString("rack")
+ respFrame.writeShort(uint16(TypeVarchar))
+ respFrame.writeInt(1)
+ respFrame.writeInt(int32(len("rack-1")))
+ respFrame.writeString("rack-1")
+ case opPrepare:
+ // This doesn't really make any sense, but it's enough to test
the protocol negotiation
+ respFrame.writeHeader(0, opResult, stream)
+ respFrame.writeInt(resultKindPrepared)
+ // <id>
+ respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil,
111))
+ if respFrame.proto >= protoVersion5 {
+
respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 222))
+ }
+ // <metadata>
+ respFrame.writeInt(0) // <flags>
+ respFrame.writeInt(0) // <columns_count>
+ if reqFrame.header.version >= protoVersion4 {
+ respFrame.writeInt(0) // <pk_count>
+ }
+ // <result_metadata>
+ respFrame.writeInt(int32(flagGlobalTableSpec)) // <flags>
+ respFrame.writeInt(1) //
<columns_count>
+ // <global_table_spec>
+ respFrame.writeString("system")
+ respFrame.writeString("keyspaces")
+ // <col_spec_0>
+ respFrame.writeString("col0") // <name>
+ respFrame.writeShort(uint16(TypeBoolean)) // <type>
+ case opExecute:
+ // This doesn't really make any sense, but it's enough to test
the protocol negotiation
+ respFrame.writeHeader(0, opResult, stream)
+ respFrame.writeInt(resultKindRows)
+ // <metadata>
+ respFrame.writeInt(0) // <flags>
+ respFrame.writeInt(0) // <columns_count>
+ // <rows_count>
+ respFrame.writeInt(0)
+ }
+
+ return nil
+}
+
+func mockedErrorCodeHandler(errorCode int) func(*TestServer, *framer, *framer)
error {
+ return func(_ *TestServer, reqFrame *framer, respFrame *framer) error {
+ reqFrame.writeHeader(0, opError, reqFrame.header.stream)
+ reqFrame.writeInt(int32(errorCode))
+ reqFrame.writeString(fmt.Sprintf("NEGOTIATION TEST: Error code
%d", errorCode))
+ return nil
+ }
+}
+
+func TestProtocolNegotiation(t *testing.T) {
+ testCases := []struct {
+ name string
+ supportedVersions []protoVersion
+ supportedBetaVersions []protoVersion
+ expectedVersion protoVersion
+ expectedErrorMsg string
+
+ forceZeroStreamID bool
+ overrideHost string
+
+ requestHandler func(*TestServer, *framer, *framer) error
+ }{
+ {
+ name: "all supported versions",
+ supportedVersions: []protoVersion{protoVersion3,
protoVersion4, protoVersion5},
+ expectedVersion: protoVersion5,
+ },
+ {
+ name: "v5-beta is supported",
+ supportedVersions: []protoVersion{protoVersion3,
protoVersion4},
+ supportedBetaVersions: []protoVersion{protoVersion5},
+ expectedVersion: protoVersion4,
+ },
+ {
+ name: "v5 is unsupported",
+ supportedVersions: []protoVersion{protoVersion3,
protoVersion4},
+ expectedVersion: protoVersion4,
+ },
+ {
+ name: "all supported versions / 0 stream
id",
+ supportedVersions: []protoVersion{protoVersion3,
protoVersion4, protoVersion5},
+ expectedVersion: protoVersion5,
+ forceZeroStreamID: true,
+ },
+ {
+ name: "v5-beta is supported / 0 stream
id",
+ supportedVersions: []protoVersion{protoVersion3,
protoVersion4},
+ supportedBetaVersions: []protoVersion{protoVersion5},
+ expectedVersion: protoVersion4,
+ forceZeroStreamID: true,
+ },
+ {
+ name: "v5 is unsupported / 0 stream id",
+ supportedVersions: []protoVersion{protoVersion3,
protoVersion4},
+ expectedVersion: protoVersion4,
+ forceZeroStreamID: true,
+ },
+ {
+ name: "wrong host addr",
+ expectedErrorMsg: "unable to discover protocol version",
+ overrideHost: "1.2.3.4", // totally wrong addr to
get network related error
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ handler := &requestHandlerForProtocolNegotiationTest{
+ supportedProtocolVersions: tc.supportedVersions,
+ supportedBetaProtocols:
tc.supportedBetaVersions,
+ forceZeroStreamID: tc.forceZeroStreamID,
+ }
+
+ srv := newTestServerOpts{
+ addr: "127.0.0.1:0",
+ protocol: 5,
+ customRequestHandler: handler.handle,
+ dontFailOnProtocolMismatch: true,
+ }.newServer(t, context.Background())
+
+ go srv.serve()
+ defer srv.Stop()
+
+ cluster := NewCluster(srv.Address)
+ if tc.overrideHost != "" {
+ cluster.Hosts = []string{tc.overrideHost}
+ }
+
+ cluster.Compressor = nil
+ cluster.ProtoVersion = 0
+ cluster.Logger = NewLogger(LogLevelDebug)
+ cluster.ConnectTimeout = time.Second * 2
+ cluster.Timeout = time.Second * 2
+ cluster.DisableInitialHostLookup = true
+
+ s, err := cluster.CreateSession()
+ switch {
+ case tc.expectedErrorMsg != "":
+ require.Error(t, err)
+ require.ErrorContains(t, err,
tc.expectedErrorMsg)
+ default:
+ require.NoError(t, err)
+ require.Equal(t, tc.expectedVersion,
protoVersion(s.cfg.ProtoVersion))
+ }
+ })
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]