This is an automated email from the ASF dual-hosted git repository.

joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new f1e31a58 Protocol version negotiation doesn't work if server replies 
with stream id different than 0
f1e31a58 is described below

commit f1e31a58f7e0c25e58e2e2a0a0c6de358e643e8b
Author: Bohdan Siryk <[email protected]>
AuthorDate: Fri Nov 28 10:14:22 2025 +0200

    Protocol version negotiation doesn't work if server replies with stream id 
different than 0
    
    Previously, protocol negotiation failed when C* responded with a stream id other than 0.
    
    This patch changes the way protocol negotiation works. Instead of parsing the supported protocol version from the C* error response, the driver now tries to connect with each supported protocol version, starting from the highest and falling back to a lower version whenever the host rejects the current one.
    
    Patch by Bohdan Siryk; Reviewed by João Reis for CASSGO-98
---
 CHANGELOG.md                 |   1 +
 Makefile                     |   2 +-
 conn.go                      |  51 ++++++++-
 conn_test.go                 |  78 +++++++++++--
 control.go                   |  98 ++++++----------
 control_test.go              |  35 ------
 errors.go                    |  10 ++
 frame.go                     |  15 ++-
 protocol_negotiation_test.go | 267 +++++++++++++++++++++++++++++++++++++++++++
 9 files changed, 442 insertions(+), 115 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cd327073..872ebdd1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0
 - Prevent panic with queries during session init (CASSGO-92)
 - Return correct values from RowData (CASSGO-95)
 - Prevent setting a compression flag in a frame header when native proto v5 is 
being used (CASSGO-98)
+- Use protocol downgrading approach during protocol negotiation (CASSGO-97)
 
 ## [2.0.0]
 
diff --git a/Makefile b/Makefile
index 56ed015d..c4f6df21 100644
--- a/Makefile
+++ b/Makefile
@@ -100,7 +100,7 @@ test-integration-auth: .prepare-cassandra-cluster
 test-unit:
        @echo "Run unit tests"
        @go clean -testcache
-       go test -tags unit -timeout=5m -race ./...
+       go test -v -tags unit -timeout=5m -race ./...
 
 check: .prepare-golangci
        @echo "Build"
diff --git a/conn.go b/conn.go
index 40044565..a9bb4f5a 100644
--- a/conn.go
+++ b/conn.go
@@ -378,6 +378,13 @@ func (s *startupCoordinator) setupConn(ctx 
context.Context) error {
        select {
        case err := <-startupErr:
                if err != nil {
+                       if s.checkProtocolRelatedError(err) {
+                               return &unsupportedProtocolVersionError{
+                                       err:      err,
+                                       hostInfo: s.conn.host,
+                                       version:  protoVersion(s.conn.version),
+                               }
+                       }
                        return err
                }
        case <-ctx.Done():
@@ -387,6 +394,38 @@ func (s *startupCoordinator) setupConn(ctx 
context.Context) error {
        return nil
 }
 
+// checkProtocolRelatedError reports whether err means the host is reachable but rejected the
+// protocol version used for this connection, so the caller should retry with a lower protocol version.
+func (s *startupCoordinator) checkProtocolRelatedError(err error) bool {
+       var unwrappedFrame frame
+
+       var protocolErr *protocolError
+       if !errors.As(err, &protocolErr) {
+               var errFrame errorFrame
+               if !errors.As(err, &errFrame) {
+                       return false
+               } else {
+                       unwrappedFrame = errFrame
+               }
+       } else {
+               unwrappedFrame = protocolErr.frame
+       }
+
+       switch frame := unwrappedFrame.(type) {
+       case *supportedFrame:
+               // Conn.recv wraps a supportedFrame in a protocolError when the host answers on stream id 0.
+               // The host is reachable but did not accept this protocol version, so a retry is worthwhile.
+               return true
+       case errorFrame:
+               // ErrCodeProtocol and ErrCodeServer are treated as a protocol version mismatch:
+               // downgrade the protocol version instead of skipping the host.
+               return frame.code == ErrCodeProtocol || frame.code == 
ErrCodeServer
+       default:
+               // Anything else (e.g. an unreachable host) is not protocol related, so do not retry.
+               return false
+       }
+}
+
 func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder, 
startupCompleted *atomic.Bool) (frame, error) {
        select {
        case s.frameTicker <- struct{}{}:
@@ -408,12 +447,14 @@ func (s *startupCoordinator) options(ctx context.Context, 
startupCompleted *atom
                return err
        }
 
-       supported, ok := frame.(*supportedFrame)
-       if !ok {
-               return NewErrProtocol("Unknown type of response to startup 
frame: %T", frame)
+       switch frame := frame.(type) {
+       case *supportedFrame:
+               return s.startup(ctx, frame.supported, startupCompleted)
+       case error:
+               return frame
+       default:
+               return NewErrProtocol("Unknown type of response to startup 
frame: %T (frame=%s)", frame, frame.String())
        }
-
-       return s.startup(ctx, supported.supported, startupCompleted)
 }
 
 func (s *startupCoordinator) startup(ctx context.Context, supported 
map[string][]string, startupCompleted *atomic.Bool) error {
diff --git a/conn_test.go b/conn_test.go
index 60e4a2a8..ad4e66e5 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -1054,6 +1054,9 @@ type newTestServerOpts struct {
        addr     string
        protocol uint8
        recvHook func(*framer)
+
+       customRequestHandler       func(srv *TestServer, reqFrame, respFrame 
*framer) error
+       dontFailOnProtocolMismatch bool
 }
 
 func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) 
*TestServer {
@@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx 
context.Context) *TestS
                cancel:     cancel,
 
                onRecv: nts.recvHook,
+
+               customRequestHandler:       nts.customRequestHandler,
+               dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch,
        }
 
        go srv.closeWatch()
@@ -1142,6 +1148,10 @@ type TestServer struct {
 
        // onRecv is a hook point for tests, called in receive loop.
        onRecv func(*framer)
+
+       // customRequestHandler allows overriding the default request handling 
for testing purposes.
+       customRequestHandler       func(srv *TestServer, reqFrame, respFrame 
*framer) error
+       dontFailOnProtocolMismatch bool
 }
 
 func (srv *TestServer) closeWatch() {
@@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() {
                }
 
                go func(conn net.Conn) {
+                       var startupCompleted bool
+                       var useProtoV5 bool
+
                        defer conn.Close()
                        for !srv.isClosed() {
-                               framer, err := srv.readFrame(conn)
+                               var reader io.Reader = conn
+
+                               if useProtoV5 && startupCompleted {
+                                       frame, _, err := 
readUncompressedSegment(conn)
+                                       if err != nil {
+                                               if errors.Is(err, io.EOF) {
+                                                       return
+                                               }
+                                               srv.errorLocked(err)
+                                               return
+                                       }
+                                       reader = bytes.NewReader(frame)
+                               }
+
+                               framer, err := srv.readFrame(reader)
                                if err != nil {
                                        if err == io.EOF {
                                                return
@@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() {
                                        srv.onRecv(framer)
                                }
 
-                               go srv.process(conn, framer)
+                               srv.process(conn, framer, &useProtoV5, 
&startupCompleted)
                        }
                }(conn)
        }
@@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) {
        srv.t.Error(err)
 }
 
-func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
+func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5, 
startupCompleted *bool) {
        head := reqFrame.header
        if head == nil {
                srv.errorLocked("process frame with a nil header")
                return
        }
-       respFrame := newFramer(nil, reqFrame.proto, GlobalTypes)
+       respFrame := newFramer(nil, byte(head.version), GlobalTypes)
+
+       if srv.customRequestHandler != nil {
+               if err := srv.customRequestHandler(srv, reqFrame, respFrame); 
err != nil {
+                       srv.errorLocked(err)
+                       return
+               }
+               // Dont like this but...
+               goto finish
+       }
 
        switch head.op {
        case opStartup:
@@ -1412,26 +1448,46 @@ func (srv *TestServer) process(conn net.Conn, reqFrame 
*framer) {
                respFrame.writeString("not supported")
        }
 
-       respFrame.buf[0] = srv.protocol | 0x80
+finish:
+
+       respFrame.buf[0] |= 0x80
 
        if err := respFrame.finish(); err != nil {
                srv.errorLocked(err)
        }
 
-       if err := respFrame.writeTo(conn); err != nil {
-               srv.errorLocked(err)
+       if *useProtoV5 && *startupCompleted {
+               segment, err := newUncompressedSegment(respFrame.buf, true)
+               if err == nil {
+                       _, err = conn.Write(segment)
+               }
+               if err != nil {
+                       srv.errorLocked(err)
+                       return
+               }
+       } else {
+               if err := respFrame.writeTo(conn); err != nil {
+                       srv.errorLocked(err)
+               }
+
+               if reqFrame.header.op == opStartup {
+                       *startupCompleted = true
+                       if head.version == protoVersion5 {
+                               *useProtoV5 = true
+                       }
+               }
        }
 }
 
-func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
+func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) {
        buf := make([]byte, srv.headerSize)
-       head, err := readHeader(conn, buf)
+       head, err := readHeader(reader, buf)
        if err != nil {
                return nil, err
        }
        framer := newFramer(nil, srv.protocol, GlobalTypes)
 
-       err = framer.readFrame(conn, &head)
+       err = framer.readFrame(reader, &head)
        if err != nil {
                return nil, err
        }
@@ -1439,7 +1495,7 @@ func (srv *TestServer) readFrame(conn net.Conn) (*framer, 
error) {
        // should be a request frame
        if head.version.response() {
                return nil, fmt.Errorf("expected to read a request frame got 
version: %v", head.version)
-       } else if head.version.version() != srv.protocol {
+       } else if !srv.dontFailOnProtocolMismatch && head.version.version() != 
srv.protocol {
                return nil, fmt.Errorf("expected to read protocol version 0x%x 
got 0x%x", srv.protocol, head.version.version())
        }
 
diff --git a/control.go b/control.go
index e59acb40..c518ba68 100644
--- a/control.go
+++ b/control.go
@@ -32,7 +32,6 @@ import (
        "math/rand"
        "net"
        "os"
-       "regexp"
        "strconv"
        "sync"
        "sync/atomic"
@@ -202,56 +201,9 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
        return shuffled
 }
 
-// this is going to be version dependant and a nightmare to maintain :(
-var protocolSupportRe = regexp.MustCompile(`the lowest supported version is 
\d+ and the greatest is (\d+)$`)
-var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used 
\(.*\), but USE_BETA flag is unset`)
-
-func parseProtocolFromError(err error) int {
-       errStr := err.Error()
-
-       var errProtocol ErrProtocol
-       if errors.As(err, &errProtocol) {
-               err = errProtocol.error
-       }
-
-       // I really wish this had the actual info in the error frame...
-       matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1)
-       if len(matches) == 1 {
-               var protoErr *protocolError
-               if errors.As(err, &protoErr) {
-                       version := protoErr.frame.Header().version.version()
-                       if version > 0 {
-                               return int(version - 1)
-                       }
-               }
-               return 0
-       }
-
-       matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1)
-       if len(matches) != 1 || len(matches[0]) != 2 {
-               var protoErr *protocolError
-               if errors.As(err, &protoErr) {
-                       return int(protoErr.frame.Header().version.version())
-               }
-               return 0
-       }
-
-       max, err := strconv.Atoi(matches[0][1])
-       if err != nil {
-               return 0
-       }
-
-       return max
-}
-
-const highestProtocolVersionSupported = 5
-
 func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
        hosts = shuffleHosts(hosts)
 
-       connCfg := *c.session.connCfg
-       connCfg.ProtoVersion = highestProtocolVersionSupported
-
        handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
                // we should never get here, but if we do it means we connected 
to a
                // host successfully which means our attempted protocol version 
worked
@@ -261,30 +213,56 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) 
(int, error) {
        })
 
        var err error
+       var proto int
        for _, host := range hosts {
-               var conn *Conn
-               conn, err = c.session.dial(c.session.ctx, host, &connCfg, 
handler)
+               proto, err = c.tryProtocolVersionsForHost(host, handler)
+               if err == nil {
+                       return proto, nil
+               }
+
+               c.session.logger.Debug("Failed to discover protocol version for 
host.",
+                       NewLogFieldIP("host_addr", host.ConnectAddress()),
+                       NewLogFieldError("err", err))
+       }
+
+       return 0, err
+}
+
+func (c *controlConn) tryProtocolVersionsForHost(host *HostInfo, handler 
ConnErrorHandler) (int, error) {
+       connCfg := *c.session.connCfg
+
+       var triedVersions []int
+
+       for proto := highestProtocolVersionSupported; proto >= 
lowestProtocolVersionSupported; proto-- {
+               connCfg.ProtoVersion = proto
+
+               conn, err := c.session.dial(c.session.ctx, host, &connCfg, 
handler)
                if conn != nil {
                        conn.Close()
                }
 
                if err == nil {
-                       c.session.logger.Debug("Discovered protocol version 
using host.",
-                               NewLogFieldInt("protocol_version", 
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()))
-                       return connCfg.ProtoVersion, nil
+                       return proto, nil
                }
 
-               if proto := parseProtocolFromError(err); proto > 0 {
-                       c.session.logger.Debug("Discovered protocol version 
using host after parsing protocol error.",
-                               NewLogFieldInt("protocol_version", proto), 
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", 
host.HostID()))
-                       return proto, nil
+               var unsupportedErr *unsupportedProtocolVersionError
+               if errors.As(err, &unsupportedErr) {
+                       // the host does not support this protocol version, try 
a lower version
+                       c.session.logger.Debug("Failed to connect to host 
during protocol negotiation.",
+                               NewLogFieldIP("host_addr", 
host.ConnectAddress()),
+                               NewLogFieldInt("proto_version", proto),
+                               NewLogFieldError("err", err))
+                       triedVersions = append(triedVersions, 
connCfg.ProtoVersion)
+                       continue
                }
 
-               c.session.logger.Debug("Failed to discover protocol version 
using host.",
-                       NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err))
+               c.session.logger.Debug("Error connecting to host during 
protocol negotiation.",
+                       NewLogFieldIP("host_addr", host.ConnectAddress()),
+                       NewLogFieldError("err", err))
+               return 0, err
        }
 
-       return 0, err
+       return 0, fmt.Errorf("gocql: failed to discover protocol version for 
host %s, tried versions: %v", host.ConnectAddress(), triedVersions)
 }
 
 func (c *controlConn) connect(hosts []*HostInfo, sessionInit bool) error {
diff --git a/control_test.go b/control_test.go
index 9f83ec95..7d9311a6 100644
--- a/control_test.go
+++ b/control_test.go
@@ -57,38 +57,3 @@ func TestHostInfo_Lookup(t *testing.T) {
                }
        }
 }
-
-func TestParseProtocol(t *testing.T) {
-       tests := [...]struct {
-               err   error
-               proto int
-       }{
-               {
-                       err: &protocolError{
-                               frame: errorFrame{
-                                       code:    0x10,
-                                       message: "Invalid or unsupported 
protocol version (5); the lowest supported version is 3 and the greatest is 4",
-                               },
-                       },
-                       proto: 4,
-               },
-               {
-                       err: &protocolError{
-                               frame: errorFrame{
-                                       frameHeader: frameHeader{
-                                               version: 0x83,
-                                       },
-                                       code:    0x10,
-                                       message: "Invalid or unsupported 
protocol version: 5",
-                               },
-                       },
-                       proto: 3,
-               },
-       }
-
-       for i, test := range tests {
-               if proto := parseProtocolFromError(test.err); proto != 
test.proto {
-                       t.Errorf("%d: exepcted proto %d got %d", i, test.proto, 
proto)
-               }
-       }
-}
diff --git a/errors.go b/errors.go
index 2d1c2205..4305f78f 100644
--- a/errors.go
+++ b/errors.go
@@ -244,3 +244,20 @@ type RequestErrCASWriteUnknown struct {
        Received    int
        BlockFor    int
 }
+
+// unsupportedProtocolVersionError is returned when a host rejects the protocol
+// version used for a connection attempt during startup.
+type unsupportedProtocolVersionError struct {
+       hostInfo *HostInfo
+       version  protoVersion
+       err      error
+}
+
+func (e unsupportedProtocolVersionError) Error() string {
+       return fmt.Sprintf("unsupported protocol version %d for host %s: %v", e.version, e.hostInfo.ConnectAddress(), e.err)
+}
+
+// Unwrap exposes the underlying startup error to errors.Is and errors.As.
+func (e unsupportedProtocolVersionError) Unwrap() error {
+       return e.err
+}
diff --git a/frame.go b/frame.go
index e86c538c..0316032e 100644
--- a/frame.go
+++ b/frame.go
@@ -65,12 +65,13 @@ func NamedValue(name string, value interface{}) interface{} 
{
 const (
        protoDirectionMask = 0x80
        protoVersionMask   = 0x7F
-       protoVersion1      = 0x01
-       protoVersion2      = 0x02
        protoVersion3      = 0x03
        protoVersion4      = 0x04
        protoVersion5      = 0x05
 
+       lowestProtocolVersionSupported  = protoVersion3
+       highestProtocolVersionSupported = protoVersion5
+
        maxFrameSize = 256 * 1024 * 1024
 
        maxSegmentPayloadSize = 0x1FFFF
@@ -422,7 +423,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, 
err error) {
 
        version := p[0] & protoVersionMask
 
-       if version < protoVersion3 || version > protoVersion5 {
+       if version < lowestProtocolVersionSupported || version > 
highestProtocolVersionSupported {
                return frameHeader{}, fmt.Errorf("gocql: unsupported protocol 
response version: %d", version)
        }
 
@@ -2370,6 +2371,14 @@ func (f *framer) writeStringMap(m map[string]string) {
        }
 }
 
+func (f *framer) writeStringMultiMap(m map[string][]string) {
+       f.writeShort(uint16(len(m)))
+       for k, v := range m {
+               f.writeString(k)
+               f.writeStringList(v)
+       }
+}
+
 func (f *framer) writeBytesMap(m map[string][]byte) {
        f.writeShort(uint16(len(m)))
        for k, v := range m {
diff --git a/protocol_negotiation_test.go b/protocol_negotiation_test.go
new file mode 100644
index 00000000..567c74e3
--- /dev/null
+++ b/protocol_negotiation_test.go
@@ -0,0 +1,267 @@
+//go:build all || unit
+// +build all unit
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package gocql
+
+import (
+       "context"
+       "encoding/binary"
+       "fmt"
+       "slices"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/require"
+)
+
+type requestHandlerForProtocolNegotiationTest struct {
+       supportedProtocolVersions []protoVersion
+       supportedBetaProtocols    []protoVersion
+
+       // forces stream id to 0
+       forceZeroStreamID bool
+
+       forceCloseConnection bool
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) 
supportsBetaProtocol(version protoVersion) bool {
+       return slices.Contains(r.supportedBetaProtocols, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version 
protoVersion) bool {
+       return slices.Contains(r.supportedProtocolVersions, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header 
*frameHeader) bool {
+       return header.flags&flagBetaProtocol == flagBetaProtocol
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) 
createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string {
+       return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta), 
but USE_BETA flag is unset", version, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer, 
reqFrame, respFrame *framer) error {
+       if r.forceCloseConnection {
+               return fmt.Errorf("NEGOTIATION TEST: forcing close connection")
+       }
+
+       stream := reqFrame.header.stream
+
+       // If a client uses beta protocol, but the USE_BETA flag is not set, we 
respond with an error
+       if r.supportsBetaProtocol(reqFrame.header.version) && 
!r.hasBetaFlag(reqFrame.header) {
+               if r.forceZeroStreamID {
+                       stream = 0
+               }
+               respFrame.writeHeader(0, opError, stream)
+               respFrame.writeInt(ErrCodeProtocol)
+               
respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version))
+               return nil
+       }
+
+       // if a client uses an unsupported protocol version, we respond with an 
error
+       if !r.supportsProtocol(reqFrame.header.version) {
+               if r.forceZeroStreamID {
+                       stream = 0
+               }
+               respFrame.writeHeader(0, opError, stream)
+               respFrame.writeInt(ErrCodeProtocol)
+               respFrame.writeString(fmt.Sprintf("NEGOTIATION TEST: 
Unsupported protocol version %d", reqFrame.header.version))
+               return nil
+       }
+
+       switch reqFrame.header.op {
+       case opStartup, opRegister:
+               respFrame.writeHeader(0, opReady, stream)
+       case opOptions:
+               // Emulating C* behavior.
+               // If a client uses an unsupported protocol version, C* 
responds with supported versions to 0 stream id.
+               // If a client uses a beta protocol version, but the USE_BETA 
flag is not set, C* responds with supported versions to 0 stream id.
+               if r.forceZeroStreamID && 
!(r.supportsProtocol(reqFrame.header.version) || 
r.supportsBetaProtocol(reqFrame.header.version) && 
!r.hasBetaFlag(reqFrame.header)) {
+                       stream = 0
+               }
+               respFrame.writeHeader(0, opSupported, stream)
+               var supportedVersionsWithDesc []string
+               for _, supportedVersion := range r.supportedProtocolVersions {
+                       supportedVersionsWithDesc = 
append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d", supportedVersion, 
supportedVersion))
+               }
+               for _, betaProtocol := range r.supportedBetaProtocols {
+                       supportedVersionsWithDesc = 
append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d-beta", betaProtocol, 
betaProtocol))
+               }
+               supported := map[string][]string{
+                       "PROTOCOL_VERSIONS": supportedVersionsWithDesc,
+               }
+               respFrame.writeStringMultiMap(supported)
+       case opQuery:
+               respFrame.writeHeader(0, opResult, stream)
+               respFrame.writeInt(resultKindRows)
+               respFrame.writeInt(int32(flagGlobalTableSpec))
+               respFrame.writeInt(1)
+               respFrame.writeString("system")
+               respFrame.writeString("local")
+               respFrame.writeString("rack")
+               respFrame.writeShort(uint16(TypeVarchar))
+               respFrame.writeInt(1)
+               respFrame.writeInt(int32(len("rack-1")))
+               respFrame.writeString("rack-1")
+       case opPrepare:
+               // This doesn't really make any sense, but it's enough to test 
the protocol negotiation
+               respFrame.writeHeader(0, opResult, stream)
+               respFrame.writeInt(resultKindPrepared)
+               // <id>
+               respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 
111))
+               if respFrame.proto >= protoVersion5 {
+                       
respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 222))
+               }
+               // <metadata>
+               respFrame.writeInt(0) // <flags>
+               respFrame.writeInt(0) // <columns_count>
+               if reqFrame.header.version >= protoVersion4 {
+                       respFrame.writeInt(0) // <pk_count>
+               }
+               // <result_metadata>
+               respFrame.writeInt(int32(flagGlobalTableSpec)) // <flags>
+               respFrame.writeInt(1)                          // 
<columns_count>
+               // <global_table_spec>
+               respFrame.writeString("system")
+               respFrame.writeString("keyspaces")
+               // <col_spec_0>
+               respFrame.writeString("col0")             // <name>
+               respFrame.writeShort(uint16(TypeBoolean)) // <type>
+       case opExecute:
+               // This doesn't really make any sense, but it's enough to test 
the protocol negotiation
+               respFrame.writeHeader(0, opResult, stream)
+               respFrame.writeInt(resultKindRows)
+               // <metadata>
+               respFrame.writeInt(0) // <flags>
+               respFrame.writeInt(0) // <columns_count>
+               // <rows_count>
+               respFrame.writeInt(0)
+       }
+
+       return nil
+}
+
+func mockedErrorCodeHandler(errorCode int) func(*TestServer, *framer, *framer) error {
+       return func(_ *TestServer, reqFrame *framer, respFrame *framer) error {
+               respFrame.writeHeader(0, opError, reqFrame.header.stream) // reply on the request's stream
+               respFrame.writeInt(int32(errorCode))
+               respFrame.writeString(fmt.Sprintf("NEGOTIATION TEST: Error code %d", errorCode))
+               return nil
+       }
+}
+
+func TestProtocolNegotiation(t *testing.T) {
+       testCases := []struct {
+               name                  string
+               supportedVersions     []protoVersion
+               supportedBetaVersions []protoVersion
+               expectedVersion       protoVersion
+               expectedErrorMsg      string
+
+               forceZeroStreamID bool
+               overrideHost      string
+
+               requestHandler func(*TestServer, *framer, *framer) error
+       }{
+               {
+                       name:              "all supported versions",
+                       supportedVersions: []protoVersion{protoVersion3, 
protoVersion4, protoVersion5},
+                       expectedVersion:   protoVersion5,
+               },
+               {
+                       name:                  "v5-beta is supported",
+                       supportedVersions:     []protoVersion{protoVersion3, 
protoVersion4},
+                       supportedBetaVersions: []protoVersion{protoVersion5},
+                       expectedVersion:       protoVersion4,
+               },
+               {
+                       name:              "v5 is unsupported",
+                       supportedVersions: []protoVersion{protoVersion3, 
protoVersion4},
+                       expectedVersion:   protoVersion4,
+               },
+               {
+                       name:              "all supported versions / 0 stream 
id",
+                       supportedVersions: []protoVersion{protoVersion3, 
protoVersion4, protoVersion5},
+                       expectedVersion:   protoVersion5,
+                       forceZeroStreamID: true,
+               },
+               {
+                       name:                  "v5-beta is supported / 0 stream 
id",
+                       supportedVersions:     []protoVersion{protoVersion3, 
protoVersion4},
+                       supportedBetaVersions: []protoVersion{protoVersion5},
+                       expectedVersion:       protoVersion4,
+                       forceZeroStreamID:     true,
+               },
+               {
+                       name:              "v5 is unsupported / 0 stream id",
+                       supportedVersions: []protoVersion{protoVersion3, 
protoVersion4},
+                       expectedVersion:   protoVersion4,
+                       forceZeroStreamID: true,
+               },
+               {
+                       name:             "wrong host addr",
+                       expectedErrorMsg: "unable to discover protocol version",
+                       overrideHost:     "1.2.3.4", // totally wrong addr to 
get network related error
+               },
+       }
+
+       for _, tc := range testCases {
+               t.Run(tc.name, func(t *testing.T) {
+                       handler := &requestHandlerForProtocolNegotiationTest{
+                               supportedProtocolVersions: tc.supportedVersions,
+                               supportedBetaProtocols:    
tc.supportedBetaVersions,
+                               forceZeroStreamID:         tc.forceZeroStreamID,
+                       }
+
+                       srv := newTestServerOpts{
+                               addr:                       "127.0.0.1:0",
+                               protocol:                   5,
+                               customRequestHandler:       handler.handle,
+                               dontFailOnProtocolMismatch: true,
+                       }.newServer(t, context.Background())
+
+                       go srv.serve()
+                       defer srv.Stop()
+
+                       cluster := NewCluster(srv.Address)
+                       if tc.overrideHost != "" {
+                               cluster.Hosts = []string{tc.overrideHost}
+                       }
+
+                       cluster.Compressor = nil
+                       cluster.ProtoVersion = 0
+                       cluster.Logger = NewLogger(LogLevelDebug)
+                       cluster.ConnectTimeout = time.Second * 2
+                       cluster.Timeout = time.Second * 2
+                       cluster.DisableInitialHostLookup = true
+
+                       s, err := cluster.CreateSession()
+                       switch {
+                       case tc.expectedErrorMsg != "":
+                               require.Error(t, err)
+                               require.ErrorContains(t, err, 
tc.expectedErrorMsg)
+                       default:
+                               require.NoError(t, err)
+                               require.Equal(t, tc.expectedVersion, 
protoVersion(s.cfg.ProtoVersion))
+                       }
+               })
+       }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to