This is an automated email from the ASF dual-hosted git repository.

wangdan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pegasus.git


The following commit(s) were added to refs/heads/master by this push:
     new fcab87d96 feat(go-client): introduce Equal APIs for HostPort and RPCAddress (#2360)
fcab87d96 is described below

commit fcab87d96b7f5451dbb70c8f12ac33cfd3eca8b3
Author: Dan Wang <[email protected]>
AuthorDate: Wed Feb 11 17:42:28 2026 +0800

    feat(go-client): introduce Equal APIs for HostPort and RPCAddress (#2360)
    
    https://github.com/apache/incubator-pegasus/issues/2358
    
    Add `Equal` interfaces for the `HostPort` and `RPCAddress` structures to enable
    equality comparisons based on attributes such as host and port.
---
 go-client/idl/base/host_port.go        | 59 +++++++++++++++------
 go-client/idl/base/host_port_test.go   | 94 ++++++++++++++++++++++++++++------
 go-client/idl/base/rpc_address.go      | 34 ++++++------
 go-client/idl/base/rpc_address_test.go | 59 ++++++++++++++++++---
 4 files changed, 192 insertions(+), 54 deletions(-)

diff --git a/go-client/idl/base/host_port.go b/go-client/idl/base/host_port.go
index 22b721815..8fc1981ea 100644
--- a/go-client/idl/base/host_port.go
+++ b/go-client/idl/base/host_port.go
@@ -48,7 +48,7 @@ func NewHostPort(host string, port uint16) *HostPort {
        }
 }
 
-func (r *HostPort) Read(iprot thrift.TProtocol) error {
+func (hp *HostPort) Read(iprot thrift.TProtocol) error {
        host, err := iprot.ReadString()
        if err != nil {
                return err
@@ -62,43 +62,68 @@ func (r *HostPort) Read(iprot thrift.TProtocol) error {
                return err
        }
 
-       r.host = host
-       r.port = uint16(port)
-       r.hpType = HostPortType(hpType)
+       hp.host = host
+       hp.port = uint16(port)
+       hp.hpType = HostPortType(hpType)
        return nil
 }
 
-func (r *HostPort) Write(oprot thrift.TProtocol) error {
-       err := oprot.WriteString(r.host)
+func (hp *HostPort) Write(oprot thrift.TProtocol) error {
+       err := oprot.WriteString(hp.host)
        if err != nil {
                return err
        }
-       err = oprot.WriteI16(int16(r.port))
+       err = oprot.WriteI16(int16(hp.port))
        if err != nil {
                return err
        }
-       err = oprot.WriteByte(int8(r.hpType))
+       err = oprot.WriteByte(int8(hp.hpType))
        if err != nil {
                return err
        }
        return nil
 }
 
-func (r *HostPort) GetHost() string {
-       return r.host
+func (hp *HostPort) GetHost() string {
+       return hp.host
 }
 
-func (r *HostPort) GetPort() uint16 {
-       return r.port
+func (hp *HostPort) GetPort() uint16 {
+       return hp.port
 }
 
-func (r *HostPort) String() string {
-       if r == nil {
+func (hp *HostPort) String() string {
+       if hp == nil {
                return "<nil>"
        }
-       return fmt.Sprintf("HostPort(%s:%d)", r.host, r.port)
+       return fmt.Sprintf("HostPort(%s:%d)", hp.host, hp.port)
 }
 
-func (r *HostPort) GetHostPort() string {
-       return fmt.Sprintf("%s:%d", r.host, r.port)
+func (hp *HostPort) GetHostPort() string {
+       return fmt.Sprintf("%s:%d", hp.host, hp.port)
+}
+
+func (hp *HostPort) Equal(other *HostPort) bool {
+       if hp == other {
+               return true
+       }
+
+       if hp == nil || other == nil {
+               return false
+       }
+
+       if hp.hpType != other.hpType {
+               return false
+       }
+
+       switch hp.hpType {
+       case HOST_TYPE_IPV4:
+               return hp.host == other.host &&
+                       hp.port == other.port
+       case HOST_TYPE_GROUP:
+               // TODO(wangdan): support HOST_TYPE_GROUP.
+               return false
+       default:
+               return true
+       }
 }
diff --git a/go-client/idl/base/host_port_test.go b/go-client/idl/base/host_port_test.go
index e07d3f02f..28f42fded 100644
--- a/go-client/idl/base/host_port_test.go
+++ b/go-client/idl/base/host_port_test.go
@@ -20,32 +20,96 @@
 package base
 
 import (
+       "fmt"
        "testing"
 
        "github.com/apache/thrift/lib/go/thrift"
        "github.com/stretchr/testify/assert"
 )
 
+func testNewHostPort(t *testing.T, host string, port uint16) *HostPort {
+       hp := NewHostPort(host, port)
+       assert.Equal(t, host, hp.GetHost())
+       assert.Equal(t, port, hp.GetPort())
+       assert.True(t, hp.Equal(hp))
+
+       return hp
+}
+
+func stringify(host string, port uint16) string {
+       return fmt.Sprintf("<%s:%d>", host, port)
+}
+
 func TestHostPort(t *testing.T) {
-       testCases := map[string]uint16{
-               "localhost": 8080,
+       tests := map[string]uint16{
+               "localhost":          8080,
+               "pegasus.apache.org": 443,
        }
 
-       for host, port := range testCases {
-               hp := NewHostPort(host, port)
-               assert.Equal(t, host, hp.GetHost())
-               assert.Equal(t, port, hp.GetPort())
+       runner := func(host string, port uint16) func(t *testing.T) {
+               return func(t *testing.T) {
+                       t.Parallel()
+
+                       hp := testNewHostPort(t, host, port)
 
-               // test HostPort serialize
-               buf := thrift.NewTMemoryBuffer()
-               oprot := thrift.NewTBinaryProtocolTransport(buf)
-               hp.Write(oprot)
+                       // Test serialization.
+                       buf := thrift.NewTMemoryBuffer()
+                       oprot := thrift.NewTBinaryProtocolTransport(buf)
+                       assert.NoError(t, hp.Write(oprot))
 
-               // test HostPort deserialize
-               readHostPort := NewHostPort("", 0)
-               readHostPort.Read(oprot)
+                       // Test deserialization.
+                       peer := NewHostPort("", 0)
+                       assert.NoError(t, peer.Read(oprot))
+                       assert.True(t, peer.Equal(peer))
+
+                       // Test equality.
+                       assert.Equal(t, hp, peer)
+                       assert.True(t, hp.Equal(peer))
+                       assert.True(t, peer.Equal(hp))
+               }
+       }
+
+       for host, port := range tests {
+               t.Run(stringify(host, port), runner(host, port))
+       }
+}
+
+func TestHostPortEquality(t *testing.T) {
+       type hpCase struct {
+               host string
+               port uint16
+       }
+       type testCase struct {
+               x     hpCase
+               y     hpCase
+               equal bool
+       }
+       tests := []testCase{
+               {hpCase{"localhost", 8080}, hpCase{"localhost", 8080}, true},
+               {hpCase{"localhost", 8080}, hpCase{"pegasus.apache.org", 8080}, false},
+               {hpCase{"localhost", 8080}, hpCase{"localhost", 8081}, false},
+       }
+
+       testName := func(hpX hpCase, hpY hpCase) string {
+               hpName := func(hp hpCase) string {
+                       return stringify(hp.host, hp.port)
+               }
+               return fmt.Sprintf("%s-vs-%s", hpName(hpX), hpName(hpY))
+       }
+
+       runner := func(test testCase) func(t *testing.T) {
+               return func(t *testing.T) {
+                       t.Parallel()
+
+                       hpX := testNewHostPort(t, test.x.host, test.x.port)
+                       hpY := testNewHostPort(t, test.y.host, test.y.port)
+
+                       assert.Equal(t, test.equal, hpX.Equal(hpY))
+                       assert.Equal(t, test.equal, hpY.Equal(hpX))
+               }
+       }
 
-               // check equals
-               assert.Equal(t, readHostPort, hp)
+       for _, test := range tests {
+               t.Run(testName(test.x, test.y), runner(test))
        }
 }
diff --git a/go-client/idl/base/rpc_address.go b/go-client/idl/base/rpc_address.go
index d451c4228..16d4da83a 100644
--- a/go-client/idl/base/rpc_address.go
+++ b/go-client/idl/base/rpc_address.go
@@ -37,38 +37,42 @@ func NewRPCAddress(ip net.IP, port int) *RPCAddress {
        }
 }
 
-func (r *RPCAddress) Read(iprot thrift.TProtocol) error {
+func (a *RPCAddress) Read(iprot thrift.TProtocol) error {
        address, err := iprot.ReadI64()
        if err != nil {
                return err
        }
-       r.address = address
+       a.address = address
        return nil
 }
 
-func (r *RPCAddress) Write(oprot thrift.TProtocol) error {
-       return oprot.WriteI64(r.address)
+func (a *RPCAddress) Write(oprot thrift.TProtocol) error {
+       return oprot.WriteI64(a.address)
 }
 
-func (r *RPCAddress) String() string {
-       if r == nil {
+func (a *RPCAddress) String() string {
+       if a == nil {
                return "<nil>"
        }
-       return fmt.Sprintf("RPCAddress(%s)", r.GetAddress())
+       return fmt.Sprintf("RPCAddress(%s)", a.GetAddress())
 }
 
-func (r *RPCAddress) GetIP() net.IP {
-       return net.IPv4(byte(0xff&(r.address>>56)), byte(0xff&(r.address>>48)), byte(0xff&(r.address>>40)), byte(0xff&(r.address>>32)))
+func (a *RPCAddress) GetIP() net.IP {
+       return net.IPv4(byte(0xff&(a.address>>56)), byte(0xff&(a.address>>48)), byte(0xff&(a.address>>40)), byte(0xff&(a.address>>32)))
 }
 
-func (r *RPCAddress) GetPort() int {
-       return int(0xffff & (r.address >> 16))
+func (a *RPCAddress) GetPort() int {
+       return int(0xffff & (a.address >> 16))
 }
 
-func (r *RPCAddress) GetAddress() string {
-       return fmt.Sprintf("%s:%d", r.GetIP(), r.GetPort())
+func (a *RPCAddress) GetAddress() string {
+       return fmt.Sprintf("%s:%d", a.GetIP(), a.GetPort())
 }
 
-func (r *RPCAddress) GetRawAddress() int64 {
-       return r.address
+func (a *RPCAddress) GetRawAddress() int64 {
+       return a.address
+}
+
+func (a *RPCAddress) Equal(other *RPCAddress) bool {
+       return a.address == other.address
 }
diff --git a/go-client/idl/base/rpc_address_test.go b/go-client/idl/base/rpc_address_test.go
index fbb86210e..1df85c065 100644
--- a/go-client/idl/base/rpc_address_test.go
+++ b/go-client/idl/base/rpc_address_test.go
@@ -20,25 +20,70 @@
 package base
 
 import (
+       "fmt"
        "net"
        "testing"
 
        "github.com/stretchr/testify/assert"
 )
 
+func testNewRPCAddress(t *testing.T, addrStr string) *RPCAddress {
+       addr, err := net.ResolveTCPAddr("tcp", addrStr)
+       assert.NoError(t, err)
+
+       rpcAddr := NewRPCAddress(addr.IP, addr.Port)
+       assert.Equal(t, addrStr, rpcAddr.GetAddress())
+       assert.True(t, rpcAddr.Equal(rpcAddr))
+
+       return rpcAddr
+}
+
 func TestRPCAddress(t *testing.T) {
-       testCases := []string{
+       tests := []string{
                "127.0.0.1:8080",
                "192.168.0.1:123",
                "0.0.0.0:12345",
        }
 
-       for _, ts := range testCases {
-               tcpAddrStr := ts
-               addr, err := net.ResolveTCPAddr("tcp", tcpAddrStr)
-               assert.NoError(t, err)
+       runner := func(test string) func(t *testing.T) {
+               return func(t *testing.T) {
+                       t.Parallel()
+
+                       testNewRPCAddress(t, test)
+               }
+       }
+
+       for _, test := range tests {
+               name := fmt.Sprintf("<%s>", test)
+               t.Run(name, runner(test))
+       }
+}
+
+func TestRPCAddressEquality(t *testing.T) {
+       tests := []struct {
+               x     string
+               y     string
+               equal bool
+       }{
+               {"127.0.0.1:8080", "127.0.0.1:8080", true},
+               {"127.0.0.1:8080", "192.168.0.1:8080", false},
+               {"127.0.0.1:8080", "127.0.0.1:8081", false},
+       }
+
+       runner := func(x string, y string, equal bool) func(t *testing.T) {
+               return func(t *testing.T) {
+                       t.Parallel()
+
+                       rpcAddrX := testNewRPCAddress(t, x)
+                       rpcAddrY := testNewRPCAddress(t, y)
+
+                       assert.Equal(t, equal, rpcAddrX.Equal(rpcAddrY))
+                       assert.Equal(t, equal, rpcAddrY.Equal(rpcAddrX))
+               }
+       }
 
-               rpcAddr := NewRPCAddress(addr.IP, addr.Port)
-               assert.Equal(t, rpcAddr.GetAddress(), tcpAddrStr)
+       for _, test := range tests {
+               name := fmt.Sprintf("<%s>-vs-<%s>", test.x, test.y)
+               t.Run(name, runner(test.x, test.y, test.equal))
        }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to