joao-r-reis commented on code in PR #1920:
URL: 
https://github.com/apache/cassandra-gocql-driver/pull/1920#discussion_r2571926410


##########
protocol_negotiation_test.go:
##########
@@ -0,0 +1,179 @@
+//go:build all || unit
+// +build all unit
+
+package gocql
+
+import (
+       "context"
+       "encoding/binary"
+       "fmt"
+       "slices"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/require"
+)
+
+type requestHandlerForProtocolNegotiationTest struct {
+       supportedProtocolVersions []protoVersion
+       supportedBetaProtocols    []protoVersion
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) 
supportsBetaProtocol(version protoVersion) bool {
+       return slices.Contains(r.supportedBetaProtocols, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version 
protoVersion) bool {
+       return slices.Contains(r.supportedProtocolVersions, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header 
*frameHeader) bool {
+       return header.flags&flagBetaProtocol == flagBetaProtocol
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) 
createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string {
+       return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta), 
but USE_BETA flag is unset", version, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer, 
reqFrame, respFrame *framer) error {
+       // If a client uses beta protocol, but the USE_BETA flag is not set, we 
respond with an error
+       if r.supportsBetaProtocol(reqFrame.header.version) && 
!r.hasBetaFlag(reqFrame.header) {
+               respFrame.writeHeader(0, opError, reqFrame.header.stream)
+               respFrame.writeInt(ErrCodeProtocol)
+               
respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version))
+               return nil
+       }
+
+       // if a client uses an unsupported protocol version, we respond with an 
error
+       if !r.supportsProtocol(reqFrame.header.version) {
+               respFrame.writeHeader(0, opError, reqFrame.header.stream)
+               respFrame.writeInt(ErrCodeProtocol)
+               respFrame.writeString(fmt.Sprintf("NEGOTITATION TEST: 
Unsupported protocol version %d", reqFrame.header.version))
+               return nil
+       }
+
+       stream := reqFrame.header.stream
+
+       switch reqFrame.header.op {
+       case opStartup, opRegister:
+               respFrame.writeHeader(0, opReady, stream)
+       case opOptions:
+               respFrame.writeHeader(0, opSupported, stream)
+               var supportedVersionsWithDesc []string
+               for _, supportedVersion := range r.supportedProtocolVersions {
+                       supportedVersionsWithDesc = 
append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d", supportedVersion, 
supportedVersion))
+               }
+               for _, betaProtocol := range r.supportedBetaProtocols {
+                       supportedVersionsWithDesc = 
append(supportedVersionsWithDesc, fmt.Sprintf("%d/v%d-beta", betaProtocol, 
betaProtocol))
+               }
+               supported := map[string][]string{
+                       "PROTOCOL_VERSIONS": supportedVersionsWithDesc,
+               }
+               respFrame.writeStringMultiMap(supported)
+       case opQuery:
+               respFrame.writeHeader(0, opResult, stream)
+               respFrame.writeInt(resultKindRows)
+               respFrame.writeInt(int32(flagGlobalTableSpec))
+               respFrame.writeInt(1)
+               respFrame.writeString("system")
+               respFrame.writeString("local")
+               respFrame.writeString("rack")
+               respFrame.writeShort(uint16(TypeVarchar))
+               respFrame.writeInt(1)
+               respFrame.writeInt(int32(len("rack-1")))
+               respFrame.writeString("rack-1")
+       case opPrepare:
+               // This doesn't really make any sense, but it's enough to test 
the protocol negotiation
+               respFrame.writeHeader(0, opResult, stream)
+               respFrame.writeInt(resultKindPrepared)
+               // <id>
+               respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 
111))
+               if respFrame.proto >= protoVersion5 {
+                       
respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 222))
+               }
+               // <metadata>
+               respFrame.writeInt(0) // <flags>
+               respFrame.writeInt(0) // <columns_count>
+               if reqFrame.header.version >= protoVersion4 {
+                       respFrame.writeInt(0) // <pk_count>
+               }
+               // <result_metadata>
+               respFrame.writeInt(int32(flagGlobalTableSpec)) // <flags>
+               respFrame.writeInt(1)                          // 
<columns_count>
+               // <global_table_spec>
+               respFrame.writeString("system")
+               respFrame.writeString("keyspaces")
+               // <col_spec_0>
+               respFrame.writeString("col0")             // <name>
+               respFrame.writeShort(uint16(TypeBoolean)) // <type>
+       case opExecute:
+               // This doesn't really make any sense, but it's enough to test 
the protocol negotiation
+               respFrame.writeHeader(0, opResult, stream)
+               respFrame.writeInt(resultKindRows)
+               // <metadata>
+               respFrame.writeInt(0) // <flags>
+               respFrame.writeInt(0) // <columns_count>
+               // <rows_count>
+               respFrame.writeInt(0)
+       }
+
+       return nil
+}
+
+func TestProtocolNegotiation(t *testing.T) {

Review Comment:
   Very good job with this test, I feel like it was a weak spot in our test 
suite and now it's much better. Can you add `-v` to the `go test` command in 
the GH workflow step that runs the unit tests so we can double check which 
tests are running?



##########
control.go:
##########
@@ -262,26 +217,27 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) 
(int, error) {
 
        var err error
        for _, host := range hosts {
-               var conn *Conn
-               conn, err = c.session.dial(c.session.ctx, host, &connCfg, 
handler)
-               if conn != nil {
-                       conn.Close()
-               }
+               connCfg := *c.session.connCfg
+               for proto := highestProtocolVersionSupported; proto >= 
lowestProtocolVersionSupported; proto-- {
+                       connCfg.ProtoVersion = proto
+
+                       var conn *Conn
+                       conn, err = c.session.dial(c.session.ctx, host, 
&connCfg, handler)
+                       if conn != nil {
+                               conn.Close()
+                       }
 
-               if err == nil {
-                       c.session.logger.Debug("Discovered protocol version 
using host.",
-                               NewLogFieldInt("protocol_version", 
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()))
-                       return connCfg.ProtoVersion, nil
-               }
+                       if err == nil {
+                               c.session.logger.Debug("Discovered protocol 
version using host.",
+                                       NewLogFieldInt("protocol_version", 
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()))

Review Comment:
   Any particular reason you removed the `host_id` log field from both calls? Is 
it because it's not populated?



##########
protocol_negotiation_test.go:
##########
@@ -0,0 +1,179 @@
+//go:build all || unit
+// +build all unit
+
+package gocql

Review Comment:
   Please add the license header; we should add a check for this in CI at some point.



##########
protocol_negotiation_test.go:
##########
@@ -0,0 +1,179 @@
+//go:build all || unit
+// +build all unit
+
+package gocql
+
+import (
+       "context"
+       "encoding/binary"
+       "fmt"
+       "slices"
+       "testing"
+       "time"
+
+       "github.com/stretchr/testify/require"
+)
+
+type requestHandlerForProtocolNegotiationTest struct {
+       supportedProtocolVersions []protoVersion
+       supportedBetaProtocols    []protoVersion
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) 
supportsBetaProtocol(version protoVersion) bool {
+       return slices.Contains(r.supportedBetaProtocols, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) supportsProtocol(version 
protoVersion) bool {
+       return slices.Contains(r.supportedProtocolVersions, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) hasBetaFlag(header 
*frameHeader) bool {
+       return header.flags&flagBetaProtocol == flagBetaProtocol
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) 
createBetaFlagUnsetProtocolErrorMessage(version protoVersion) string {
+       return fmt.Sprintf("Beta version of the protocol used (%d/v%d-beta), 
but USE_BETA flag is unset", version, version)
+}
+
+func (r *requestHandlerForProtocolNegotiationTest) handle(_ *TestServer, 
reqFrame, respFrame *framer) error {
+       // If a client uses beta protocol, but the USE_BETA flag is not set, we 
respond with an error
+       if r.supportsBetaProtocol(reqFrame.header.version) && 
!r.hasBetaFlag(reqFrame.header) {
+               respFrame.writeHeader(0, opError, reqFrame.header.stream)
+               respFrame.writeInt(ErrCodeProtocol)
+               
respFrame.writeString(r.createBetaFlagUnsetProtocolErrorMessage(reqFrame.header.version))
+               return nil
+       }
+
+       // if a client uses an unsupported protocol version, we respond with an 
error
+       if !r.supportsProtocol(reqFrame.header.version) {
+               respFrame.writeHeader(0, opError, reqFrame.header.stream)

Review Comment:
   can we have a test case that replies with the stream id of the request in 
addition to this one? This behavior is not defined in the spec and I think some 
C* versions might behave differently



##########
control.go:
##########
@@ -262,26 +217,27 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) 
(int, error) {
 
        var err error
        for _, host := range hosts {
-               var conn *Conn
-               conn, err = c.session.dial(c.session.ctx, host, &connCfg, 
handler)
-               if conn != nil {
-                       conn.Close()
-               }
+               connCfg := *c.session.connCfg
+               for proto := highestProtocolVersionSupported; proto >= 
lowestProtocolVersionSupported; proto-- {
+                       connCfg.ProtoVersion = proto
+
+                       var conn *Conn
+                       conn, err = c.session.dial(c.session.ctx, host, 
&connCfg, handler)
+                       if conn != nil {
+                               conn.Close()
+                       }
 
-               if err == nil {
-                       c.session.logger.Debug("Discovered protocol version 
using host.",
-                               NewLogFieldInt("protocol_version", 
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), 
NewLogFieldString("host_id", host.HostID()))
-                       return connCfg.ProtoVersion, nil
-               }
+                       if err == nil {
+                               c.session.logger.Debug("Discovered protocol 
version using host.",
+                                       NewLogFieldInt("protocol_version", 
connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()))
+                               return connCfg.ProtoVersion, nil
+                       }
 
-               if proto := parseProtocolFromError(err); proto > 0 {
-                       c.session.logger.Debug("Discovered protocol version 
using host after parsing protocol error.",
-                               NewLogFieldInt("protocol_version", proto), 
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", 
host.HostID()))
-                       return proto, nil
+                       c.session.logger.Debug("Failed to connect to the host 
using protocol version.",

Review Comment:
   We should only attempt to reconnect to the same host with a lower protocol 
version if the error is assumed to be related to an unsupported protocol version: 
   - it's an error in response to the first request (OPTIONS)
   - the error type is PROTOCOL_ERROR or SERVER_ERROR - SERVER_ERROR is for old 
C* versions that reported this error as SERVER_ERROR instead of PROTOCOL_ERROR
   
   I propose we use a custom internal error type (unsupportedProtocolVersionErr 
for example) just for this case that we can return inside the dial method and 
we test for it here in discoverProtocol(). Reference code [in the java driver 
here](https://github.com/apache/cassandra-java-driver/blob/62eade21bfeb16a12ce71013fbaebf1c19b5ae96/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ChannelFactory.java#L238-L259).
   
   [Technically the java driver does check for the error 
string](https://github.com/apache/cassandra-java-driver/blob/62eade21bfeb16a12ce71013fbaebf1c19b5ae96/core/src/main/java/com/datastax/oss/driver/internal/core/channel/ProtocolInitHandler.java#L327-L334)
 which I didn't know, but I still believe we shouldn't do that; checking the 
above 2 conditions is enough.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to