This is an automated email from the ASF dual-hosted git repository.

jameshartig pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new f400b491 CASSGO-6 Accept peers with empty rack
f400b491 is described below

commit f400b49138fe84ce980a2ea3e17d5d21232e284e
Author: James Hartig <jameshar...@apache.org>
AuthorDate: Thu Jun 19 00:32:15 2025 +0000

    CASSGO-6 Accept peers with empty rack
    
    This fixes #1706.
    
    Patch by James Hartig for CASSGO-6; reviewed by João Reis for CASSGO-6
---
 CHANGELOG.md        |   1 +
 cassandra_test.go   | 105 +++++++++++++++++++++++++++++++++++++
 control.go          |   7 ++-
 host_source.go      | 120 +++++++++++++++++++++++++++++++-----------
 host_source_test.go | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 349 insertions(+), 32 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6262875b..227e729d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -62,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Don't panic in MapExecuteBatchCAS if no `[applied]` column is returned (CASSGO-42)
 - Fix deadlock in refresh debouncer stop (CASSGO-41)
 - Endless query execution fix (CASSGO-50)
+- Accept peers with empty rack (CASSGO-6)
 
 ## [1.7.0] - 2024-09-23
 
diff --git a/cassandra_test.go b/cassandra_test.go
index 2613f17d..b92c3cd6 100644
--- a/cassandra_test.go
+++ b/cassandra_test.go
@@ -4013,3 +4013,108 @@ func TestRoutingKeyCacheUsesOverriddenKeyspace(t *testing.T) {
 
        session.Query("DROP KEYSPACE IF EXISTS gocql_test_routing_key_cache").Exec()
 }
+
+func TestHostInfoFromIter(t *testing.T) {
+       session := createSession(t)
+       defer session.Close()
+
+       err := createTable(session, `CREATE TABLE IF NOT EXISTS gocql_test.system_peers(
+               peer inet PRIMARY KEY,
+               data_center text,
+               host_id uuid,
+               preferred_ip inet,
+               rack text,
+               release_version text,
+               rpc_address inet,
+               schema_version uuid,
+               tokens set<text>
+       )`)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       id1 := MustRandomUUID()
+       err = session.Query(
+               "INSERT INTO gocql_test.system_peers (peer, data_center, host_id, rack, release_version, rpc_address, tokens) VALUES (?, ?, ?, ?, ?, ?, ?)",
+               net.ParseIP("10.0.0.1"),
+               "dc1",
+               id1,
+               "rack1",
+               "4.0.0",
+               net.ParseIP("10.0.0.2"),
+               []string{"0", "1"},
+       ).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       id2 := MustRandomUUID()
+       err = session.Query(
+               "INSERT INTO gocql_test.system_peers (peer, data_center, host_id, release_version, rpc_address, tokens) VALUES (?, ?, ?, ?, ?, ?)",
+               net.ParseIP("10.0.0.2"),
+               "dc2",
+               id2,
+               "4.0.0",
+               net.ParseIP("10.0.0.3"),
+               []string{"0", "1"},
+       ).Exec()
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       iter := session.Query("SELECT * FROM gocql_test.system_peers WHERE data_center='dc1' ALLOW FILTERING").Iter()
+
+       h, err := session.hostInfoFromIter(iter, nil, 9042)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !isValidPeer(h) {
+               t.Errorf("expected %+v to be a valid peer", h)
+       }
+       if addr := h.ConnectAddressAndPort(); addr != "10.0.0.2:9042" {
+               t.Errorf("unexpected connect address: %s != '10.0.0.2:9042'", addr)
+       }
+       if h.HostID() != id1.String() {
+               t.Errorf("unexpected hostID %s != %s", h.HostID(), id1.String())
+       }
+       if h.Version().String() != "v4.0.0" {
+               t.Errorf("unexpected version %s != v4.0.0", h.Version().String())
+       }
+       if h.Rack() != "rack1" {
+               t.Errorf("unexpected rack %s != 'rack1'", h.Rack())
+       }
+       if h.DataCenter() != "dc1" {
+               t.Errorf("unexpected data center %s != 'dc1'", h.DataCenter())
+       }
+       if h.missingRack {
+               t.Errorf("unexpected missing rack")
+       }
+
+       iter = session.Query("SELECT * FROM gocql_test.system_peers WHERE data_center='dc2' ALLOW FILTERING").Iter()
+
+       h, err = session.hostInfoFromIter(iter, nil, 9042)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if isValidPeer(h) {
+               t.Errorf("expected %+v to be an invalid peer", h)
+       }
+       if addr := h.ConnectAddressAndPort(); addr != "10.0.0.3:9042" {
+               t.Errorf("unexpected connect address: %s != '10.0.0.3:9042'", addr)
+       }
+       if h.HostID() != id2.String() {
+               t.Errorf("unexpected hostID %s != %s", h.HostID(), id2.String())
+       }
+       if h.Version().String() != "v4.0.0" {
+               t.Errorf("unexpected version %s != v4.0.0", h.Version().String())
+       }
+       if h.Rack() != "" {
+               t.Errorf("unexpected rack %s != ''", h.Rack())
+       }
+       if h.DataCenter() != "dc2" {
+               t.Errorf("unexpected data center %s != 'dc2'", h.DataCenter())
+       }
+       if !h.missingRack {
+               t.Errorf("unexpected non-missing rack")
+       }
+}
diff --git a/control.go b/control.go
index fa30cd05..b521c41d 100644
--- a/control.go
+++ b/control.go
@@ -322,7 +322,12 @@ func (c *controlConn) setupConn(conn *Conn, sessionInit bool) error {
        iter := conn.querySystemLocal(context.TODO())
        host, err := c.session.hostInfoFromIter(iter, conn.host.connectAddress, conn.r.RemoteAddr().(*net.TCPAddr).Port)
        if err != nil {
-               return err
+               // just cleanup
+               iter.Close()
+               return fmt.Errorf("could not retrieve control host info: %w", err)
+       }
+       if host == nil {
+               return errors.New("could not retrieve control host info: query returned 0 rows")
        }
 
        var exists bool
diff --git a/host_source.go b/host_source.go
index fbc9134d..3ee3dd07 100644
--- a/host_source.go
+++ b/host_source.go
@@ -171,6 +171,7 @@ type HostInfo struct {
        port             int
        dataCenter       string
        rack             string
+       missingRack      bool
        hostId           string
        workload         string
        graph            bool
@@ -413,8 +414,9 @@ func (h *HostInfo) update(from *HostInfo) {
        if h.dataCenter == "" {
                h.dataCenter = from.dataCenter
        }
-       if h.rack == "" {
+       if h.missingRack {
                h.rack = from.rack
+               h.missingRack = from.missingRack
        }
        if h.hostId == "" {
                h.hostId = from.hostId
@@ -530,7 +532,7 @@ func newHostInfoFromRow(s *Session, defaultAddr net.IP, defaultPort int, row map
        const assertErrorMsg = "Assertion failed for %s, type was %T"
        var ok bool
 
-       host := &HostInfo{connectAddress: defaultAddr, port: defaultPort}
+       host := &HostInfo{connectAddress: defaultAddr, port: defaultPort, missingRack: true}
 
        // Process all fields from the row
        for key, value := range row {
@@ -541,14 +543,30 @@ func newHostInfoFromRow(s *Session, defaultAddr net.IP, defaultPort int, row map
                                return nil, fmt.Errorf(assertErrorMsg, "data_center", value)
                        }
                case "rack":
-                       host.rack, ok = value.(string)
+                       rack, ok := value.(*string)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "rack", value)
+                               if rack, ok := value.(string); !ok {
+                                       return nil, fmt.Errorf(assertErrorMsg, "rack", value)
+                               } else {
+                                       host.rack = rack
+                                       host.missingRack = false
+                               }
+                       } else if rack != nil {
+                               host.rack = *rack
+                               host.missingRack = false
                        }
                case "host_id":
                        hostId, ok := value.(UUID)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "host_id", value)
+                               if str, ok := value.(string); ok {
+                                       var err error
+                                       hostId, err = ParseUUID(str)
+                                       if err != nil {
+                                               return nil, fmt.Errorf("failed to parse host_id: %w", err)
+                                       }
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "host_id", value)
+                               }
                        }
                        host.hostId = hostId.String()
                case "release_version":
@@ -560,7 +578,11 @@ func newHostInfoFromRow(s *Session, defaultAddr net.IP, defaultPort int, row map
                case "peer":
                        ip, ok := value.(net.IP)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "peer", value)
+                               if str, ok := value.(string); ok {
+                                       ip = net.ParseIP(str)
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "peer", value)
+                               }
                        }
                        host.peer = ip
                case "cluster_name":
@@ -576,31 +598,51 @@ func newHostInfoFromRow(s *Session, defaultAddr net.IP, defaultPort int, row map
                case "broadcast_address":
                        ip, ok := value.(net.IP)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "broadcast_address", value)
+                               if str, ok := value.(string); ok {
+                                       ip = net.ParseIP(str)
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "broadcast_address", value)
+                               }
                        }
                        host.broadcastAddress = ip
                case "preferred_ip":
                        ip, ok := value.(net.IP)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "preferred_ip", value)
+                               if str, ok := value.(string); ok {
+                                       ip = net.ParseIP(str)
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "preferred_ip", value)
+                               }
                        }
                        host.preferredIP = ip
                case "rpc_address":
                        ip, ok := value.(net.IP)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "rpc_address", value)
+                               if str, ok := value.(string); ok {
+                                       ip = net.ParseIP(str)
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "rpc_address", value)
+                               }
                        }
                        host.rpcAddress = ip
                case "native_address":
                        ip, ok := value.(net.IP)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "native_address", value)
+                               if str, ok := value.(string); ok {
+                                       ip = net.ParseIP(str)
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "native_address", value)
+                               }
                        }
                        host.rpcAddress = ip
                case "listen_address":
                        ip, ok := value.(net.IP)
                        if !ok {
-                               return nil, fmt.Errorf(assertErrorMsg, "listen_address", value)
+                               if str, ok := value.(string); ok {
+                                       ip = net.ParseIP(str)
+                               } else {
+                                       return nil, fmt.Errorf(assertErrorMsg, "listen_address", value)
+                               }
                        }
                        host.listenAddress = ip
                case "native_port":
@@ -666,18 +708,23 @@ func newHostInfoFromRow(s *Session, defaultAddr net.IP, defaultPort int, row map
        }
 }
 
+// this will return nil, nil if there were no rows left in the Iter
 func (s *Session) hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int) (*HostInfo, error) {
-       rows, err := iter.SliceMap()
-       if err != nil {
-               // TODO(zariel): make typed error
-               return nil, err
+       // TODO: switch this to a new iterator method once CASSGO-36 is solved
+       m := map[string]interface{}{
+               // we set rack to a double pointer so we can know if it's NULL or not since
+               // we need to be able to filter out NULL rack hosts but not empty string hosts
+               // see CASSGO-6
+               "rack": new(*string),
        }
-
-       if len(rows) == 0 {
-               return nil, errors.New("query returned 0 rows")
+       if !iter.MapScan(m) {
+               if err := iter.Close(); err != nil {
+                       return nil, err
+               }
+               return nil, nil
        }
 
-       host, err := s.newHostInfoFromMap(connectAddress, defaultPort, rows[0])
+       host, err := s.newHostInfoFromMap(connectAddress, defaultPort, m)
        if err != nil {
                return nil, err
        }
@@ -700,8 +747,13 @@ func (r *ringDescriber) getLocalHostInfo() (*HostInfo, error) {
 
        host, err := r.session.hostInfoFromIter(iter, nil, r.session.cfg.Port)
        if err != nil {
+               // just cleanup
+               iter.Close()
                return nil, fmt.Errorf("could not retrieve local host info: %w", err)
        }
+       if host == nil {
+               return nil, errors.New("could not retrieve local host info: query returned 0 rows")
+       }
        return host, nil
 }
 
@@ -711,7 +763,6 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er
                return nil, errNoControl
        }
 
-       var peers []*HostInfo
        iter := r.session.control.withConnHost(func(ch *connHost) *Iter {
                return ch.conn.querySystemPeers(context.TODO(), localHost.version)
        })
@@ -720,18 +771,25 @@ func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo) ([]*HostInfo, er
                return nil, errNoControl
        }
 
-       rows, err := iter.SliceMap()
-       if err != nil {
-               // TODO(zariel): make typed error
-               return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
-       }
-
-       for _, row := range rows {
+       var peers []*HostInfo
+       for {
                // extract all available info about the peer
-               host, err := r.session.newHostInfoFromMap(nil, r.session.cfg.Port, row)
+               host, err := r.session.hostInfoFromIter(iter, nil, r.session.cfg.Port)
                if err != nil {
-                       return nil, err
-               } else if !isValidPeer(host) {
+                       // if the error came from the iterator then return it, otherwise ignore
+                       // and warn
+                       if iterErr := iter.Close(); iterErr != nil {
+                               return nil, fmt.Errorf("unable to fetch peer host info: %s", iterErr)
+                       }
+                       // skip over peers that we couldn't parse
+                       r.session.logger.Warning("Failed to parse peer this host will be ignored.", newLogFieldError("err", err))
+                       continue
+               }
+               // if nil then none left
+               if host == nil {
+                       break
+               }
+               if !isValidPeer(host) {
                        // If it's not a valid peer
                        r.session.logger.Warning("Found invalid peer "+
                                "likely due to a gossip or snitch issue, this host will be ignored.", newLogFieldStringer("host", host))
@@ -749,7 +807,7 @@ func isValidPeer(host *HostInfo) bool {
        return !(len(host.RPCAddress()) == 0 ||
                host.hostId == "" ||
                host.dataCenter == "" ||
-               host.rack == "" ||
+               host.missingRack ||
                len(host.tokens) == 0)
 }
 
diff --git a/host_source_test.go b/host_source_test.go
index e2454be5..07086e68 100644
--- a/host_source_test.go
+++ b/host_source_test.go
@@ -85,6 +85,153 @@ func TestCassVersionBefore(t *testing.T) {
 
 }
 
+func TestNewHostInfoFromRow(t *testing.T) {
+       id := MustRandomUUID()
+       row := map[string]interface{}{
+               "broadcast_address": "10.0.0.1",
+               "listen_address":    net.ParseIP("10.0.0.2"),
+               "rpc_address":       net.ParseIP("10.0.0.3"),
+               "data_center":       "dc",
+               "rack":              "",
+               "host_id":           id,
+               "release_version":   "4.0.0",
+               "native_port":       9042,
+               "tokens":            []string{"0", "1"},
+       }
+       s := &Session{}
+       h, err := newHostInfoFromRow(s, nil, 0, row)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !isValidPeer(h) {
+               t.Errorf("expected %+v to be a valid peer", h)
+       }
+       if addr := h.ConnectAddressAndPort(); addr != "10.0.0.3:9042" {
+               t.Errorf("unexpected connect address: %s != '10.0.0.3:9042'", addr)
+       }
+       if h.HostID() != id.String() {
+               t.Errorf("unexpected hostID %s != %s", h.HostID(), id.String())
+       }
+       if h.Version().String() != "v4.0.0" {
+               t.Errorf("unexpected version %s != v4.0.0", h.Version().String())
+       }
+       if h.Rack() != "" {
+               t.Errorf("unexpected rack %s != ''", h.Rack())
+       }
+       if h.DataCenter() != "dc" {
+               t.Errorf("unexpected data center %s != 'dc'", h.DataCenter())
+       }
+
+       row = map[string]interface{}{
+               "broadcast_address": "10.0.0.1",
+               "listen_address":    net.ParseIP("10.0.0.2"),
+               "preferred_ip":      "10.0.0.4",
+               "data_center":       "dc",
+               "rack":              "rack",
+               "host_id":           id,
+               "release_version":   "4.0.0",
+               "native_port":       9042,
+               "tokens":            []string{"0", "1"},
+       }
+       h, err = newHostInfoFromRow(s, nil, 0, row)
+       if err != nil {
+               t.Fatal(err)
+       }
+       // missing rpc_address
+       if isValidPeer(h) {
+               t.Errorf("expected %+v to be an invalid peer", h)
+       }
+       if addr := h.ConnectAddressAndPort(); addr != "10.0.0.4:9042" {
+               t.Errorf("unexpected connect address: %s != '10.0.0.4:9042'", addr)
+       }
+       if h.Rack() != "rack" {
+               t.Errorf("unexpected rack %s != 'rack'", h.Rack())
+       }
+
+       row = map[string]interface{}{
+               "broadcast_address": "10.0.0.1",
+               "data_center":       "dc",
+               "rack":              "rack",
+               "host_id":           id,
+               "native_port":       9042,
+               "tokens":            []string{"0", "1"},
+       }
+       h, err = newHostInfoFromRow(s, nil, 0, row)
+       if err != nil {
+               t.Fatal(err)
+       }
+       // missing rpc_address
+       if isValidPeer(h) {
+               t.Errorf("expected %+v to be an invalid peer", h)
+       }
+       if addr := h.ConnectAddressAndPort(); addr != "10.0.0.1:9042" {
+               t.Errorf("unexpected connect address: %s != '10.0.0.1:9042'", addr)
+       }
+
+       row = map[string]interface{}{
+               "rpc_address": "10.0.0.2",
+               "data_center": "dc",
+               "rack":        "rack",
+               "host_id":     id,
+               "tokens":      []string{"0", "1"},
+       }
+       s = &Session{
+               cfg: ClusterConfig{
+                       AddressTranslator: AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
+                               if !addr.Equal(net.ParseIP("10.0.0.2")) {
+                                       t.Errorf("unexpected ip sent to translator: %s != '10.0.0.2'", addr.String())
+                               }
+                               if port != 9042 {
+                                       t.Errorf("unexpected port sent to translator: %d != 9042", port)
+                               }
+                               return net.ParseIP("10.0.0.5"), 9043
+                       }),
+               },
+               logger: &defaultLogger{},
+       }
+       h, err = newHostInfoFromRow(s, nil, 9042, row)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !isValidPeer(h) {
+               t.Errorf("expected %+v to be a valid peer", h)
+       }
+       if addr := h.ConnectAddressAndPort(); addr != "10.0.0.5:9043" {
+               t.Errorf("unexpected connect address: %s != '10.0.0.5:9043'", addr)
+       }
+
+       // missing rack
+       row = map[string]interface{}{
+               "rpc_address": "10.0.0.2",
+               "data_center": "dc",
+               "host_id":     id,
+               "tokens":      []string{"0", "1"},
+       }
+       h, err = newHostInfoFromRow(nil, nil, 9042, row)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if isValidPeer(h) {
+               t.Errorf("expected %+v to be an invalid peer", h)
+       }
+       if h.Rack() != "" {
+               t.Errorf("unexpected rack %s != ''", h.Rack())
+       }
+
+       // invalid ip
+       row = map[string]interface{}{
+               "rpc_address": net.ParseIP("0.0.0.0"),
+               "data_center": "dc",
+               "rack":        "rack",
+               "host_id":     id,
+               "tokens":      []string{"0", "1"},
+       }
+       _, err = newHostInfoFromRow(nil, nil, 9042, row)
+       if err == nil {
+               t.Error("expected invalid ip to error")
+       }
+}
+
 func TestIsValidPeer(t *testing.T) {
        host := &HostInfo{
                rpcAddress: net.ParseIP("0.0.0.0"),
@@ -99,6 +246,7 @@ func TestIsValidPeer(t *testing.T) {
        }
 
        host.rack = ""
+       host.missingRack = true
        if isValidPeer(host) {
                t.Errorf("expected %+v to NOT be a valid peer", host)
        }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to