This is an automated email from the ASF dual-hosted git repository.

joaoreis pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/cassandra-gocql-driver.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 65e2caf  Refactor HostInfo creation and ConnectAddress() method
65e2caf is described below

commit 65e2cafa8c46534a1aef83d5483302d23f7cb792
Author: tengu-alt <olexandr.luzh...@gmail.com>
AuthorDate: Mon Feb 17 11:19:34 2025 +0200

    Refactor HostInfo creation and ConnectAddress() method
    
    HostInfo creation was refactored to go through a constructor (newHostInfo), which ensures that the connectAddress is valid.
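    The panic that ConnectAddress() raised when a host had no valid connect address was removed. The exported SetConnectAddress() method was also removed; use an AddressTranslator to override a host's connect address.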
    
    patch by Oleksandr Luzhniy; reviewed by João Reis and James Hartig for CASSGO-45
---
 CHANGELOG.md   |  2 ++
 conn.go        |  6 +++++-
 control.go     | 13 +++++++++++--
 host_source.go | 46 ++++++++++++++++++++++++++--------------------
 4 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5cfff09..9512c2f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,8 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0
 
 - Standardized spelling of datacenter (CASSGO-35)
 
+- Refactor HostInfo creation and ConnectAddress() method (CASSGO-45)
+
 ### Fixed
 - Cassandra version unmarshal fix (CASSGO-49)
 
diff --git a/conn.go b/conn.go
index cd3fda6..aac75e4 100644
--- a/conn.go
+++ b/conn.go
@@ -1690,7 +1690,11 @@ func (c *Conn) awaitSchemaAgreement(ctx context.Context) 
(err error) {
                }
 
                for _, row := range rows {
-                       host, err := c.session.hostInfoFromMap(row, 
&HostInfo{connectAddress: c.host.ConnectAddress(), port: c.session.cfg.Port})
+                       h, err := newHostInfo(c.host.ConnectAddress(), 
c.session.cfg.Port)
+                       if err != nil {
+                               goto cont
+                       }
+                       host, err := c.session.hostInfoFromMap(row, h)
                        if err != nil {
                                goto cont
                        }
diff --git a/control.go b/control.go
index b30b44e..0e2a859 100644
--- a/control.go
+++ b/control.go
@@ -146,7 +146,11 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, 
error) {
 
        // Check if host is a literal IP address
        if ip := net.ParseIP(host); ip != nil {
-               hosts = append(hosts, &HostInfo{hostname: host, connectAddress: 
ip, port: port})
+               h, err := newHostInfo(ip, port)
+               if err != nil {
+                       return nil, err
+               }
+               hosts = append(hosts, h)
                return hosts, nil
        }
 
@@ -172,7 +176,12 @@ func hostInfo(addr string, defaultPort int) ([]*HostInfo, 
error) {
        }
 
        for _, ip := range ips {
-               hosts = append(hosts, &HostInfo{hostname: host, connectAddress: 
ip, port: port})
+               h, err := newHostInfo(ip, port)
+               if err != nil {
+                       return nil, err
+               }
+
+               hosts = append(hosts, h)
        }
 
        return hosts, nil
diff --git a/host_source.go b/host_source.go
index 2be9e3a..ffe54cf 100644
--- a/host_source.go
+++ b/host_source.go
@@ -181,6 +181,18 @@ type HostInfo struct {
        tokens           []string
 }
 
+func newHostInfo(addr net.IP, port int) (*HostInfo, error) {
+       if !validIpAddr(addr) {
+               return nil, errors.New("invalid host address")
+       }
+       host := &HostInfo{}
+       host.hostname = addr.String()
+       host.port = port
+
+       host.connectAddress = addr
+       return host, nil
+}
+
 func (h *HostInfo) Equal(host *HostInfo) bool {
        if h == host {
                // prevent rlock reentry
@@ -213,14 +225,12 @@ func (h *HostInfo) connectAddressLocked() (net.IP, 
string) {
        } else if validIpAddr(h.rpcAddress) {
                return h.rpcAddress, "rpc_adress"
        } else if validIpAddr(h.preferredIP) {
-               // where does perferred_ip get set?
                return h.preferredIP, "preferred_ip"
        } else if validIpAddr(h.broadcastAddress) {
                return h.broadcastAddress, "broadcast_address"
-       } else if validIpAddr(h.peer) {
-               return h.peer, "peer"
        }
-       return net.IPv4zero, "invalid"
+       return h.peer, "peer"
+
 }
 
 // nodeToNodeAddress returns address broadcasted between node to nodes.
@@ -240,24 +250,13 @@ func (h *HostInfo) nodeToNodeAddress() net.IP {
 }
 
 // Returns the address that should be used to connect to the host.
-// If you wish to override this, use an AddressTranslator or
-// use a HostFilter to SetConnectAddress()
+// If you wish to override this, use an AddressTranslator
 func (h *HostInfo) ConnectAddress() net.IP {
        h.mu.RLock()
        defer h.mu.RUnlock()
 
-       if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
-               return addr
-       }
-       panic(fmt.Sprintf("no valid connect address for host: %v. Is your 
cluster configured correctly?", h))
-}
-
-func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
-       // TODO(zariel): should this not be exported?
-       h.mu.Lock()
-       defer h.mu.Unlock()
-       h.connectAddress = address
-       return h
+       addr, _ := h.connectAddressLocked()
+       return addr
 }
 
 func (h *HostInfo) BroadcastAddress() net.IP {
@@ -491,6 +490,10 @@ func checkSystemSchema(control *controlConn) (bool, error) 
{
        return true, nil
 }
 
+func (s *Session) newHostInfoFromMap(addr net.IP, port int, row 
map[string]interface{}) (*HostInfo, error) {
+       return s.hostInfoFromMap(row, &HostInfo{connectAddress: addr, port: 
port})
+}
+
 // Given a map that represents a row from either system.local or system.peers
 // return as much information as we can in *HostInfo
 func (s *Session) hostInfoFromMap(row map[string]interface{}, host *HostInfo) 
(*HostInfo, error) {
@@ -606,6 +609,9 @@ func (s *Session) hostInfoFromMap(row 
map[string]interface{}, host *HostInfo) (*
        }
 
        ip, port := s.cfg.translateAddressPort(host.ConnectAddress(), host.port)
+       if !validIpAddr(ip) {
+               return nil, fmt.Errorf("invalid host address (before 
translation: %v:%v, after translation: %v:%v)", host.ConnectAddress(), 
host.port, ip.String(), port)
+       }
        host.connectAddress = ip
        host.port = port
 
@@ -623,7 +629,7 @@ func (s *Session) hostInfoFromIter(iter *Iter, 
connectAddress net.IP, defaultPor
                return nil, errors.New("query returned 0 rows")
        }
 
-       host, err := s.hostInfoFromMap(rows[0], &HostInfo{connectAddress: 
connectAddress, port: defaultPort})
+       host, err := s.newHostInfoFromMap(connectAddress, defaultPort, rows[0])
        if err != nil {
                return nil, err
        }
@@ -674,7 +680,7 @@ func (r *ringDescriber) getClusterPeerInfo(localHost 
*HostInfo) ([]*HostInfo, er
 
        for _, row := range rows {
                // extract all available info about the peer
-               host, err := r.session.hostInfoFromMap(row, &HostInfo{port: 
r.session.cfg.Port})
+               host, err := r.session.newHostInfoFromMap(nil, 
r.session.cfg.Port, row)
                if err != nil {
                        return nil, err
                } else if !isValidPeer(host) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to