This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new 74334e027 Refactor part of the pub.Client as common library (#972)
74334e027 is described below

commit 74334e027b7fad414c9773f803163b5c97a9655b
Author: mrproliu <[email protected]>
AuthorDate: Fri Feb 13 07:40:40 2026 +0800

    Refactor part of the pub.Client as common library (#972)
    
    * Refactor part of the pub.Client as common library
---
 banyand/internal/storage/segment.go                |   8 +-
 banyand/internal/storage/tsdb.go                   |   2 +-
 banyand/queue/pub/batch.go                         |  38 +-
 banyand/queue/pub/client.go                        | 317 +------------
 banyand/queue/pub/client_test.go                   |  10 +-
 banyand/queue/pub/pub.go                           | 246 ++++------
 banyand/queue/pub/pub_suite_test.go                |  26 +-
 banyand/queue/pub/pub_test.go                      |  26 +-
 banyand/queue/pub/pub_tls_test.go                  |   4 +-
 banyand/queue/pub/retry_test.go                    | 202 --------
 .../queue/pub => pkg/grpchelper}/circuitbreaker.go |  46 +-
 .../pub => pkg/grpchelper}/circuitbreaker_test.go  | 459 +++++++++---------
 pkg/grpchelper/connmanager.go                      | 526 +++++++++++++++++++++
 pkg/grpchelper/connmanager_test.go                 | 115 +++++
 pkg/grpchelper/helpers_test.go                     |  67 +++
 {banyand/queue/pub => pkg/grpchelper}/retry.go     |  61 ++-
 pkg/grpchelper/retry_test.go                       | 500 ++++++++++++++++++++
 17 files changed, 1655 insertions(+), 998 deletions(-)

diff --git a/banyand/internal/storage/segment.go 
b/banyand/internal/storage/segment.go
index 74dff1518..2953ce3b5 100644
--- a/banyand/internal/storage/segment.go
+++ b/banyand/internal/storage/segment.go
@@ -74,7 +74,7 @@ type segment[T TSTable, O any] struct {
        suffix        string
        location      string
        lastAccessed  atomic.Int64
-       mu            sync.Mutex
+       mu            sync.RWMutex
        refCount      int32
        mustBeDeleted uint32
        id            segmentID
@@ -194,6 +194,12 @@ func (s *segment[T, O]) initialize(ctx context.Context) 
error {
        return nil
 }
 
+func (s *segment[T, O]) collectMetrics() {
+       s.mu.RLock()
+       defer s.mu.RUnlock()
+       s.index.store.CollectMetrics(s.index.p.SegLabelValues()...)
+}
+
 func (s *segment[T, O]) DecRef() {
        shouldCleanup := false
 
diff --git a/banyand/internal/storage/tsdb.go b/banyand/internal/storage/tsdb.go
index f341b5aec..a0ad00e28 100644
--- a/banyand/internal/storage/tsdb.go
+++ b/banyand/internal/storage/tsdb.go
@@ -365,7 +365,7 @@ func (d *database[T, O]) collect() {
                for _, t := range tables {
                        t.Collect(d.segmentController.metrics)
                }
-               s.index.store.CollectMetrics(s.index.p.SegLabelValues()...)
+               s.collectMetrics()
                s.DecRef()
                refCount += atomic.LoadInt32(&s.refCount)
        }
diff --git a/banyand/queue/pub/batch.go b/banyand/queue/pub/batch.go
index 127bec113..b755d300a 100644
--- a/banyand/queue/pub/batch.go
+++ b/banyand/queue/pub/batch.go
@@ -31,9 +31,17 @@ import (
        modelv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
        "github.com/apache/skywalking-banyandb/banyand/queue"
        "github.com/apache/skywalking-banyandb/pkg/bus"
+       "github.com/apache/skywalking-banyandb/pkg/grpchelper"
        "github.com/apache/skywalking-banyandb/pkg/logger"
 )
 
+const (
+       defaultMaxRetries        = 3
+       defaultPerRequestTimeout = 2 * time.Second
+       defaultBackoffBase       = 500 * time.Millisecond
+       defaultBackoffMax        = 30 * time.Second
+)
+
 type writeStream struct {
        client    clusterv1.Service_SendClient
        ctxDoneCh <-chan struct{}
@@ -72,7 +80,7 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic 
bus.Topic, messages
                node := m.Node()
 
                // Check circuit breaker before attempting send
-               if !bp.pub.isRequestAllowed(node) {
+               if !bp.pub.connMgr.IsRequestAllowed(node) {
                        err = multierr.Append(err, fmt.Errorf("circuit breaker 
open for node %s", node))
                        continue
                }
@@ -95,11 +103,11 @@ func (bp *batchPublisher) Publish(ctx context.Context, 
topic bus.Topic, messages
                                if errSend != nil {
                                        err = multierr.Append(err, 
fmt.Errorf("failed to send message to node %s: %w", node, errSend))
                                        // Record failure for circuit breaker 
(only for transient/internal errors)
-                                       bp.pub.recordFailure(node, errSend)
+                                       bp.pub.connMgr.RecordFailure(node, 
errSend)
                                        return false
                                }
                                // Record success for circuit breaker
-                               bp.pub.recordSuccess(node)
+                               bp.pub.connMgr.RecordSuccess(node)
                                return true
                        }
                        return false
@@ -119,14 +127,12 @@ func (bp *batchPublisher) Publish(ctx context.Context, 
topic bus.Topic, messages
                        }
                        continue
                }
-               var client *client
+               var nodeClient *client
                // nolint: contextcheck
                if func() bool {
-                       bp.pub.mu.RLock()
-                       defer bp.pub.mu.RUnlock()
-                       var ok bool
-                       client, ok = bp.pub.active[node]
-                       if !ok {
+                       var clientOK bool
+                       nodeClient, clientOK = bp.pub.connMgr.GetClient(node)
+                       if !clientOK {
                                err = multierr.Append(err, fmt.Errorf("failed 
to get client for node %s", node))
                                return true
                        }
@@ -147,7 +153,7 @@ func (bp *batchPublisher) Publish(ctx context.Context, 
topic bus.Topic, messages
                streamCtx, cancel := context.WithTimeout(ctx, bp.timeout)
                // this assignment is for getting around the go vet lint
                deferFn := cancel
-               stream, errCreateStream := client.client.Send(streamCtx)
+               stream, errCreateStream := nodeClient.client.Send(streamCtx)
                if errCreateStream != nil {
                        err = multierr.Append(err, fmt.Errorf("failed to get 
stream for node %s: %w", node, errCreateStream))
                        continue
@@ -171,9 +177,9 @@ func (bp *batchPublisher) Publish(ctx context.Context, 
topic bus.Topic, messages
                        }
                        resp, errRecv := s.Recv()
                        if errRecv != nil {
-                               if isFailoverError(errRecv) {
+                               if grpchelper.IsFailoverError(errRecv) {
                                        // Record circuit breaker failure 
before creating failover event
-                                       bp.pub.recordFailure(curNode, errRecv)
+                                       bp.pub.connMgr.RecordFailure(curNode, 
errRecv)
                                        bc <- batchEvent{n: curNode, e: 
common.NewErrorWithStatus(modelv1.Status_STATUS_INTERNAL_ERROR, 
errRecv.Error())}
                                }
                                return
@@ -187,7 +193,7 @@ func (bp *batchPublisher) Publish(ctx context.Context, 
topic bus.Topic, messages
                        if isFailoverStatus(resp.Status) {
                                ce := common.NewErrorWithStatus(resp.Status, 
resp.Error)
                                // Record circuit breaker failure before 
creating failover event
-                               bp.pub.recordFailure(curNode, ce)
+                               bp.pub.connMgr.RecordFailure(curNode, ce)
                                bc <- batchEvent{n: curNode, e: ce}
                        }
                }(stream, deferFn, bp.f.events[len(bp.f.events)-1], nodeName)
@@ -211,7 +217,7 @@ func (bp *batchPublisher) Close() (cee 
map[string]*common.Error, err error) {
                        defer bp.pub.closer.Done()
                        for n, e := range batchEvents {
                                // Record circuit breaker failure before 
failover
-                               bp.pub.recordFailure(n, e.e)
+                               bp.pub.connMgr.RecordFailure(n, e.e)
                                if bp.topic == nil {
                                        bp.pub.failover(n, e.e, 
data.TopicCommon)
                                        continue
@@ -311,7 +317,7 @@ func (bp *batchPublisher) retrySend(ctx context.Context, 
stream clusterv1.Servic
                lastErr = sendErr
 
                // Check if error is retryable
-               if !isTransientError(sendErr) {
+               if !grpchelper.IsTransientError(sendErr) {
                        // Non-transient error, don't retry
                        return sendErr
                }
@@ -322,7 +328,7 @@ func (bp *batchPublisher) retrySend(ctx context.Context, 
stream clusterv1.Servic
                }
 
                // Calculate backoff with jitter
-               backoff := jitteredBackoff(defaultBackoffBase, 
defaultBackoffMax, attempt, defaultJitterFactor)
+               backoff := grpchelper.JitteredBackoff(defaultBackoffBase, 
defaultBackoffMax, attempt, grpchelper.DefaultJitterFactor)
 
                // Sleep with backoff, but respect context cancellation
                select {
diff --git a/banyand/queue/pub/client.go b/banyand/queue/pub/client.go
index 488b9f98d..ad4ab9567 100644
--- a/banyand/queue/pub/client.go
+++ b/banyand/queue/pub/client.go
@@ -19,11 +19,9 @@ package pub
 
 import (
        "context"
-       "fmt"
        "time"
 
        "google.golang.org/grpc"
-       "google.golang.org/grpc/health/grpc_health_v1"
 
        "github.com/apache/skywalking-banyandb/api/common"
        clusterv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/cluster/v1"
@@ -32,7 +30,6 @@ import (
        "github.com/apache/skywalking-banyandb/banyand/metadata/schema"
        "github.com/apache/skywalking-banyandb/pkg/bus"
        "github.com/apache/skywalking-banyandb/pkg/grpchelper"
-       "github.com/apache/skywalking-banyandb/pkg/logger"
 )
 
 const (
@@ -118,6 +115,11 @@ type client struct {
        md     schema.Metadata
 }
 
+// Close implements grpchelper.Client.
+func (*client) Close() error {
+       return nil
+}
+
 func (p *pub) OnAddOrUpdate(md schema.Metadata) {
        if md.Kind != schema.KindNode {
                return
@@ -142,78 +144,7 @@ func (p *pub) OnAddOrUpdate(md schema.Metadata) {
        if !okRole {
                return
        }
-
-       address := node.GrpcAddress
-       if address == "" {
-               p.log.Warn().Stringer("node", node).Msg("grpc address is empty")
-               return
-       }
-       name := node.Metadata.GetName()
-       if name == "" {
-               p.log.Warn().Stringer("node", node).Msg("node name is empty")
-               return
-       }
-       p.mu.Lock()
-       defer p.mu.Unlock()
-
-       p.registerNode(node)
-
-       if _, ok := p.active[name]; ok {
-               return
-       }
-       if _, ok := p.evictable[name]; ok {
-               return
-       }
-       credOpts, err := p.getClientTransportCredentials()
-       if err != nil {
-               p.log.Error().Err(err).Msg("failed to load client TLS 
credentials")
-               return
-       }
-       conn, err := grpc.NewClient(address, append(credOpts,
-               grpc.WithDefaultServiceConfig(p.retryPolicy),
-               
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxReceiveMessageSize)))...)
-       if err != nil {
-               p.log.Error().Err(err).Msg("failed to connect to grpc server")
-               return
-       }
-
-       if !p.checkClientHealthAndReconnect(conn, md) {
-               p.log.Info().Str("status", p.dump()).Stringer("node", 
node).Msg("node is unhealthy in the register flow, move it to evict queue")
-               return
-       }
-
-       c := clusterv1.NewServiceClient(conn)
-       p.active[name] = &client{conn: conn, client: c, md: md}
-       p.addClient(md)
-       // Initialize or reset circuit breaker state to closed
-       p.recordSuccess(name)
-       p.log.Info().Str("status", p.dump()).Stringer("node", node).Msg("new 
node is healthy, add it to active queue")
-}
-
-func (p *pub) registerNode(node *databasev1.Node) {
-       name := node.Metadata.GetName()
-       defer func() {
-               p.registered[name] = node
-       }()
-
-       n, ok := p.registered[name]
-       if !ok {
-               return
-       }
-       if n.GrpcAddress == node.GrpcAddress {
-               return
-       }
-       if en, ok := p.evictable[name]; ok {
-               close(en.c)
-               delete(p.evictable, name)
-               p.log.Info().Str("node", name).Str("status", 
p.dump()).Msg("node is removed from evict queue by the new gRPC address updated 
event")
-       }
-       if client, ok := p.active[name]; ok {
-               _ = client.conn.Close()
-               delete(p.active, name)
-               p.deleteClient(client.md)
-               p.log.Info().Str("status", p.dump()).Str("node", 
name).Msg("node is removed from active queue by the new gRPC address updated 
event")
-       }
+       p.connMgr.OnAddOrUpdate(node)
 }
 
 func (p *pub) OnDelete(md schema.Metadata) {
@@ -225,163 +156,14 @@ func (p *pub) OnDelete(md schema.Metadata) {
                p.log.Warn().Msg("failed to cast node spec")
                return
        }
-       name := node.Metadata.GetName()
-       if name == "" {
-               p.log.Warn().Stringer("node", node).Msg("node name is empty")
-               return
-       }
-       p.mu.Lock()
-       defer p.mu.Unlock()
-       delete(p.registered, name)
-       if en, ok := p.evictable[name]; ok {
-               close(en.c)
-               delete(p.evictable, name)
-               p.log.Info().Str("status", p.dump()).Stringer("node", 
node).Msg("node is removed from evict queue by delete event")
-               return
-       }
-
-       if client, ok := p.active[name]; ok {
-               if p.removeNodeIfUnhealthy(md, node, client) {
-                       p.log.Info().Str("status", p.dump()).Stringer("node", 
node).Msg("remove node from active queue by delete event")
-                       return
-               }
-               if !p.closer.AddRunning() {
-                       return
-               }
-               go func() {
-                       defer p.closer.Done()
-                       var elapsed time.Duration
-                       attempt := 0
-                       for {
-                               backoff := jitteredBackoff(initBackoff, 
maxBackoff, attempt, defaultJitterFactor)
-                               select {
-                               case <-time.After(backoff):
-                                       if func() bool {
-                                               elapsed += backoff
-                                               p.mu.Lock()
-                                               defer p.mu.Unlock()
-                                               if _, ok := p.registered[name]; 
ok {
-                                                       // The client has been 
added back to registered clients map, just return
-                                                       return true
-                                               }
-                                               if p.removeNodeIfUnhealthy(md, 
node, client) {
-                                                       
p.log.Info().Str("status", p.dump()).Stringer("node", node).Dur("after", 
elapsed).Msg("remove node from active queue by delete event")
-                                                       return true
-                                               }
-                                               return false
-                                       }() {
-                                               return
-                                       }
-                               case <-p.closer.CloseNotify():
-                                       return
-                               }
-                               attempt++
-                       }
-               }()
-       }
-}
-
-func (p *pub) removeNodeIfUnhealthy(md schema.Metadata, node *databasev1.Node, 
client *client) bool {
-       if p.healthCheck(node.String(), client.conn) {
-               return false
-       }
-       _ = client.conn.Close()
-       name := node.Metadata.GetName()
-       delete(p.active, name)
-       p.deleteClient(md)
-       return true
-}
-
-func (p *pub) checkClientHealthAndReconnect(conn *grpc.ClientConn, md 
schema.Metadata) bool {
-       node, ok := md.Spec.(*databasev1.Node)
-       if !ok {
-               logger.Panicf("failed to cast node spec")
-               return false
-       }
-       if p.healthCheck(node.String(), conn) {
-               return true
-       }
-       _ = conn.Close()
-       if !p.closer.AddRunning() {
-               return false
-       }
-       name := node.Metadata.Name
-       p.evictable[name] = evictNode{n: node, c: make(chan struct{})}
-       p.deleteClient(md)
-       go func(p *pub, name string, en evictNode, md schema.Metadata) {
-               defer p.closer.Done()
-               attempt := 0
-               for {
-                       backoff := jitteredBackoff(initBackoff, maxBackoff, 
attempt, defaultJitterFactor)
-                       select {
-                       case <-time.After(backoff):
-                               credOpts, errEvict := 
p.getClientTransportCredentials()
-                               if errEvict != nil {
-                                       p.log.Error().Err(errEvict).Msg("failed 
to load client TLS credentials (evict)")
-                                       return
-                               }
-                               connEvict, errEvict := 
grpc.NewClient(node.GrpcAddress, append(credOpts,
-                                       
grpc.WithDefaultServiceConfig(p.retryPolicy),
-                                       
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxReceiveMessageSize)))...)
-                               if errEvict == nil && 
p.healthCheck(en.n.String(), connEvict) {
-                                       func() {
-                                               p.mu.Lock()
-                                               defer p.mu.Unlock()
-                                               if _, ok := p.evictable[name]; 
!ok {
-                                                       // The client has been 
removed from evict clients map, just return
-                                                       return
-                                               }
-                                               c := 
clusterv1.NewServiceClient(connEvict)
-                                               p.active[name] = &client{conn: 
connEvict, client: c, md: md}
-                                               p.addClient(md)
-                                               delete(p.evictable, name)
-                                               // Reset circuit breaker state 
to closed
-                                               p.recordSuccess(name)
-                                               p.log.Info().Str("status", 
p.dump()).Stringer("node", en.n).Msg("node is healthy, move it back to active 
queue")
-                                       }()
-                                       return
-                               }
-                               _ = connEvict.Close()
-                               if _, ok := p.registered[name]; !ok {
-                                       return
-                               }
-                               p.log.Error().Err(errEvict).Msgf("failed to 
re-connect to grpc server %s after waiting for %s", node.GrpcAddress, backoff)
-                       case <-en.c:
-                               return
-                       case <-p.closer.CloseNotify():
-                               return
-                       }
-                       attempt++
-               }
-       }(p, name, p.evictable[name], md)
-       return false
-}
-
-func (p *pub) healthCheck(node string, conn *grpc.ClientConn) bool {
-       var resp *grpc_health_v1.HealthCheckResponse
-       if err := grpchelper.Request(context.Background(), rpcTimeout, 
func(rpcCtx context.Context) (err error) {
-               resp, err = grpc_health_v1.NewHealthClient(conn).Check(rpcCtx,
-                       &grpc_health_v1.HealthCheckRequest{
-                               Service: "",
-                       })
-               return err
-       }); err != nil {
-               if e := p.log.Debug(); e.Enabled() {
-                       e.Err(err).Str("node", node).Msg("service unhealthy")
-               }
-               return false
-       }
-       if resp.GetStatus() == grpc_health_v1.HealthCheckResponse_SERVING {
-               return true
-       }
-       return false
+       p.connMgr.OnDelete(node)
 }
 
 func (p *pub) checkServiceHealth(svc string, conn *grpc.ClientConn) 
*common.Error {
-       client := clusterv1.NewServiceClient(conn)
+       serviceClient := clusterv1.NewServiceClient(conn)
        ctx, cancel := context.WithTimeout(context.Background(), rpcTimeout)
        defer cancel()
-       resp, err := client.HealthCheck(ctx, &clusterv1.HealthCheckRequest{
+       resp, err := serviceClient.HealthCheck(ctx, 
&clusterv1.HealthCheckRequest{
                ServiceName: svc,
        })
        if err != nil {
@@ -394,27 +176,11 @@ func (p *pub) checkServiceHealth(svc string, conn 
*grpc.ClientConn) *common.Erro
 }
 
 func (p *pub) failover(node string, ce *common.Error, topic bus.Topic) {
-       p.mu.Lock()
-       defer p.mu.Unlock()
        if ce.Status() != modelv1.Status_STATUS_INTERNAL_ERROR {
                _, _ = p.checkWritable(node, topic)
                return
        }
-       if en, evictable := p.evictable[node]; evictable {
-               if _, registered := p.registered[node]; !registered {
-                       close(en.c)
-                       delete(p.evictable, node)
-                       p.log.Info().Str("node", node).Str("status", 
p.dump()).Msg("node is removed from evict queue by wire event")
-               }
-               return
-       }
-
-       if client, ok := p.active[node]; ok && 
!p.checkClientHealthAndReconnect(client.conn, client.md) {
-               _ = client.conn.Close()
-               delete(p.active, node)
-               p.deleteClient(client.md)
-               p.log.Info().Str("status", p.dump()).Str("node", 
node).Msg("node is unhealthy in the failover flow, move it to evict queue")
-       }
+       p.connMgr.FailoverNode(node)
 }
 
 func (p *pub) checkWritable(n string, topic bus.Topic) (bool, *common.Error) {
@@ -422,16 +188,16 @@ func (p *pub) checkWritable(n string, topic bus.Topic) 
(bool, *common.Error) {
        if !ok {
                return false, nil
        }
-       node, ok := p.active[n]
+       c, ok := p.connMgr.GetClient(n)
        if !ok {
                return false, nil
        }
        topicStr := topic.String()
-       err := p.checkServiceHealth(topicStr, node.conn)
+       err := p.checkServiceHealth(topicStr, c.conn)
        if err == nil {
                return true, nil
        }
-       h.OnDelete(node.md)
+       h.OnDelete(c.md)
        if !p.closer.AddRunning() {
                return false, err
        }
@@ -461,28 +227,18 @@ func (p *pub) checkWritable(n string, topic bus.Topic) 
(bool, *common.Error) {
                }()
                attempt := 0
                for {
-                       backoff := jitteredBackoff(initBackoff, maxBackoff, 
attempt, defaultJitterFactor)
+                       backoff := 
grpchelper.JitteredBackoff(grpchelper.InitBackoff, grpchelper.MaxBackoff, 
attempt, grpchelper.DefaultJitterFactor)
                        select {
                        case <-time.After(backoff):
-                               p.mu.RLock()
-                               nodeCur, okCur := p.active[nodeName]
-                               p.mu.RUnlock()
+                               nodeCur, okCur := p.connMgr.GetClient(nodeName)
                                if !okCur {
                                        return
                                }
                                errInternal := p.checkServiceHealth(t, 
nodeCur.conn)
                                if errInternal == nil {
-                                       func() {
-                                               p.mu.Lock()
-                                               defer p.mu.Unlock()
-                                               nodeCur, okCur := 
p.active[nodeName]
-                                               if !okCur {
-                                                       return
-                                               }
-                                               // Record success for circuit 
breaker
-                                               p.recordSuccess(nodeName)
-                                               h.OnAddOrUpdate(nodeCur.md)
-                                       }()
+                                       // Record success for circuit breaker
+                                       p.connMgr.RecordSuccess(nodeName)
+                                       h.OnAddOrUpdate(nodeCur.md)
                                        return
                                }
                                p.log.Warn().Str("topic", 
t).Err(errInternal).Str("node", nodeName).Dur("backoff", backoff).Msg("data 
node can not ingest data")
@@ -494,40 +250,3 @@ func (p *pub) checkWritable(n string, topic bus.Topic) 
(bool, *common.Error) {
        }(n, topicStr)
        return false, err
 }
-
-func (p *pub) deleteClient(md schema.Metadata) {
-       if len(p.handlers) > 0 {
-               for _, h := range p.handlers {
-                       h.OnDelete(md)
-               }
-       }
-}
-
-func (p *pub) addClient(md schema.Metadata) {
-       if len(p.handlers) > 0 {
-               for _, h := range p.handlers {
-                       h.OnAddOrUpdate(md)
-               }
-       }
-}
-
-func (p *pub) dump() string {
-       keysRegistered := make([]string, 0, len(p.registered))
-       for k := range p.registered {
-               keysRegistered = append(keysRegistered, k)
-       }
-       keysActive := make([]string, 0, len(p.active))
-       for k := range p.active {
-               keysActive = append(keysActive, k)
-       }
-       keysEvictable := make([]string, 0, len(p.evictable))
-       for k := range p.evictable {
-               keysEvictable = append(keysEvictable, k)
-       }
-       return fmt.Sprintf("registered: %v, active :%v, evictable :%v", 
keysRegistered, keysActive, keysEvictable)
-}
-
-type evictNode struct {
-       n *databasev1.Node
-       c chan struct{}
-}
diff --git a/banyand/queue/pub/client_test.go b/banyand/queue/pub/client_test.go
index 0b797cb18..84807010c 100644
--- a/banyand/queue/pub/client_test.go
+++ b/banyand/queue/pub/client_test.go
@@ -81,9 +81,7 @@ var _ = ginkgo.Describe("publish clients 
register/unregister", func() {
                closeFn := setup(addr1, codes.OK, 200*time.Millisecond)
                defer closeFn()
                gomega.Eventually(func() int {
-                       p.mu.RLock()
-                       defer p.mu.RUnlock()
-                       return len(p.active)
+                       return p.connMgr.ActiveCount()
                }, flags.EventuallyTimeout).Should(gomega.Equal(1))
                verifyClients(p, 1, 0, 1, 1)
        })
@@ -171,10 +169,8 @@ func verifyClients(p *pub, active, evict, onAdd, onDelete 
int) {
 }
 
 func verifyClientsWithGomega(g gomega.Gomega, p *pub, topic bus.Topic, active, 
evict, onAdd, onDelete int) {
-       p.mu.RLock()
-       defer p.mu.RUnlock()
-       g.Expect(len(p.active)).Should(gomega.Equal(active))
-       g.Expect(len(p.evictable)).Should(gomega.Equal(evict))
+       g.Expect(p.connMgr.ActiveCount()).Should(gomega.Equal(active))
+       g.Expect(p.connMgr.EvictableCount()).Should(gomega.Equal(evict))
        for t, eh := range p.handlers {
                if topic != data.TopicCommon && t != topic {
                        continue
diff --git a/banyand/queue/pub/pub.go b/banyand/queue/pub/pub.go
index 2ce624ab5..05043ac0b 100644
--- a/banyand/queue/pub/pub.go
+++ b/banyand/queue/pub/pub.go
@@ -20,18 +20,17 @@ package pub
 
 import (
        "context"
+       "errors"
        "fmt"
        "io"
        "strings"
        "sync"
        "time"
 
-       "github.com/pkg/errors"
+       pkgerrors "github.com/pkg/errors"
        "go.uber.org/multierr"
        "google.golang.org/grpc"
-       "google.golang.org/grpc/codes"
        "google.golang.org/grpc/credentials"
-       "google.golang.org/grpc/status"
        "google.golang.org/protobuf/proto"
 
        "github.com/apache/skywalking-banyandb/api/common"
@@ -63,6 +62,8 @@ var (
        _ run.PreRunner = (*pub)(nil)
        _ run.Service   = (*pub)(nil)
        _ run.Config    = (*pub)(nil)
+
+       _ grpchelper.ConnectionHandler[*client] = (*pub)(nil)
 )
 
 type pub struct {
@@ -70,23 +71,61 @@ type pub struct {
        metadata        metadata.Repo
        handlers        map[bus.Topic]schema.EventHandler
        log             *logger.Logger
-       registered      map[string]*databasev1.Node
-       active          map[string]*client
-       evictable       map[string]evictNode
+       connMgr         *grpchelper.ConnManager[*client]
        closer          *run.Closer
        writableProbe   map[string]map[string]struct{}
-       cbStates        map[string]*circuitState
        caCertPath      string
        caCertReloader  *pkgtls.Reloader
        prefix          string
        retryPolicy     string
        allowedRoles    []databasev1.Role
-       mu              sync.RWMutex
-       cbMu            sync.RWMutex
        writableProbeMu sync.Mutex
        tlsEnabled      bool
 }
 
+// AddressOf implements grpchelper.ConnectionHandler.
+func (p *pub) AddressOf(node *databasev1.Node) string {
+       return node.GrpcAddress
+}
+
+// GetDialOptions implements grpchelper.ConnectionHandler.
+func (p *pub) GetDialOptions() ([]grpc.DialOption, error) {
+       return p.getClientTransportCredentials()
+}
+
+// NewClient implements grpchelper.ConnectionHandler.
+func (p *pub) NewClient(conn *grpc.ClientConn, node *databasev1.Node) 
(*client, error) {
+       md := schema.Metadata{
+               TypeMeta: schema.TypeMeta{
+                       Name: node.Metadata.GetName(),
+                       Kind: schema.KindNode,
+               },
+               Spec: node,
+       }
+       return &client{
+               client: clusterv1.NewServiceClient(conn),
+               conn:   conn,
+               md:     md,
+       }, nil
+}
+
+// OnActive implements grpchelper.ConnectionHandler.
+func (p *pub) OnActive(_ string, c *client) {
+       for _, h := range p.handlers {
+               h.OnAddOrUpdate(c.md)
+       }
+}
+
+// OnInactive implements grpchelper.ConnectionHandler.
+func (p *pub) OnInactive(name string, c *client) {
+       for _, h := range p.handlers {
+               h.OnDelete(c.md)
+       }
+       p.writableProbeMu.Lock()
+       delete(p.writableProbe, name)
+       p.writableProbeMu.Unlock()
+}
+
 func (p *pub) FlagSet() *run.FlagSet {
        prefixFlag := func(name string) string {
                if p.prefix == "" {
@@ -113,23 +152,12 @@ func (p *pub) Register(topic bus.Topic, handler 
schema.EventHandler) {
 }
 
 func (p *pub) GracefulStop() {
-       // Stop CA certificate reloader if enabled
        if p.caCertReloader != nil {
                p.caCertReloader.Stop()
        }
-
-       p.mu.Lock()
-       defer p.mu.Unlock()
-       for i := range p.evictable {
-               close(p.evictable[i].c)
-       }
-       p.evictable = nil
        p.closer.Done()
        p.closer.CloseThenWait()
-       for _, c := range p.active {
-               _ = c.conn.Close()
-       }
-       p.active = nil
+       p.connMgr.GracefulStop()
 }
 
 // Serve implements run.Service.
@@ -153,7 +181,7 @@ func (p *pub) Serve() run.StopNotify {
                                        select {
                                        case <-certUpdateCh:
                                                p.log.Info().Msg("CA 
certificate updated, reconnecting clients")
-                                               p.reconnectAllClients()
+                                               p.connMgr.ReconnectAll()
                                        case <-stopCh:
                                                return
                                        }
@@ -171,14 +199,7 @@ var bypassMatches = []MatchFunc{bypassMatch}
 func bypassMatch(_ map[string]string) bool { return true }
 
 func (p *pub) Broadcast(timeout time.Duration, topic bus.Topic, messages 
bus.Message) ([]bus.Future, error) {
-       var nodes []*databasev1.Node
-       p.mu.RLock()
-       for k := range p.active {
-               if n := p.registered[k]; n != nil {
-                       nodes = append(nodes, n)
-               }
-       }
-       p.mu.RUnlock()
+       nodes := p.connMgr.ActiveRegisteredNodes()
        if len(nodes) == 0 {
                return nil, errors.New("no active nodes")
        }
@@ -219,7 +240,6 @@ func (p *pub) Broadcast(timeout time.Duration, topic 
bus.Topic, messages bus.Mes
        if len(names) == 0 {
                return nil, fmt.Errorf("no nodes match the selector %v", 
messages.NodeSelectors())
        }
-
        futureCh := make(chan publishResult, len(names))
        var wg sync.WaitGroup
        for n := range names {
@@ -238,8 +258,8 @@ func (p *pub) Broadcast(timeout time.Duration, topic 
bus.Topic, messages bus.Mes
        var errs error
        for f := range futureCh {
                if f.e != nil {
-                       errs = multierr.Append(errs, errors.Wrapf(f.e, "failed 
to publish message to %s", f.n))
-                       if isFailoverError(f.e) {
+                       errs = multierr.Append(errs, pkgerrors.Wrapf(f.e, 
"failed to publish message to %s", f.n))
+                       if grpchelper.IsFailoverError(f.e) {
                                if p.closer.AddRunning() {
                                        go func() {
                                                defer p.closer.Done()
@@ -275,37 +295,25 @@ func (p *pub) publish(timeout time.Duration, topic 
bus.Topic, messages ...bus.Me
                        return multierr.Append(err, fmt.Errorf("failed to 
marshal message[%d]: %w", m.ID(), errSend))
                }
                node := m.Node()
-
-               // Check circuit breaker before attempting send
-               if !p.isRequestAllowed(node) {
-                       return multierr.Append(err, fmt.Errorf("circuit breaker 
open for node %s", node))
-               }
-
-               p.mu.RLock()
-               client, ok := p.active[node]
-               p.mu.RUnlock()
-               if !ok {
-                       return multierr.Append(err, fmt.Errorf("failed to get 
client for node %s", node))
-               }
-               ctx, cancel := context.WithTimeout(context.Background(), 
timeout)
-               f.cancelFn = append(f.cancelFn, cancel)
-               stream, errCreateStream := client.client.Send(ctx)
-               if errCreateStream != nil {
-                       // Record failure for circuit breaker (only for 
transient/internal errors)
-                       p.recordFailure(node, errCreateStream)
-                       return multierr.Append(err, fmt.Errorf("failed to get 
stream for node %s: %w", node, errCreateStream))
-               }
-               errSend = stream.Send(r)
-               if errSend != nil {
-                       // Record failure for circuit breaker (only for 
transient/internal errors)
-                       p.recordFailure(node, errSend)
-                       return multierr.Append(err, fmt.Errorf("failed to send 
message to node %s: %w", node, errSend))
+               execErr := p.connMgr.Execute(node, func(c *client) error {
+                       ctx, cancel := 
context.WithTimeout(context.Background(), timeout)
+                       f.cancelFn = append(f.cancelFn, cancel)
+                       stream, errCreateStream := c.client.Send(ctx)
+                       if errCreateStream != nil {
+                               // Record failure for circuit breaker (only for 
transient/internal errors)
+                               return fmt.Errorf("failed to get stream for 
node %s: %w", node, errCreateStream)
+                       }
+                       if sendErr := stream.Send(r); sendErr != nil {
+                               return fmt.Errorf("failed to send message to 
node %s: %w", node, sendErr)
+                       }
+                       f.clients = append(f.clients, stream)
+                       f.topics = append(f.topics, topic)
+                       f.nodes = append(f.nodes, node)
+                       return nil
+               })
+               if execErr != nil {
+                       err = multierr.Append(err, execErr)
                }
-               // Record success for circuit breaker
-               p.recordSuccess(node)
-               f.clients = append(f.clients, stream)
-               f.topics = append(f.topics, topic)
-               f.nodes = append(f.nodes, node)
                return err
        }
        for _, m := range messages {
@@ -322,35 +330,7 @@ func (p *pub) Publish(_ context.Context, topic bus.Topic, 
messages ...bus.Messag
 // GetRouteTable implements RouteTableProvider interface.
 // Returns a RouteTable with all registered nodes and their health states.
 func (p *pub) GetRouteTable() *databasev1.RouteTable {
-       p.mu.RLock()
-       defer p.mu.RUnlock()
-
-       registered := make([]*databasev1.Node, 0, len(p.registered))
-       for _, node := range p.registered {
-               if node != nil {
-                       registered = append(registered, node)
-               }
-       }
-
-       active := make([]string, 0, len(p.active))
-       for nodeID := range p.active {
-               if node := p.registered[nodeID]; node != nil && node.Metadata 
!= nil {
-                       active = append(active, node.Metadata.Name)
-               }
-       }
-
-       evictable := make([]string, 0, len(p.evictable))
-       for nodeID := range p.evictable {
-               if node := p.registered[nodeID]; node != nil && node.Metadata 
!= nil {
-                       evictable = append(evictable, node.Metadata.Name)
-               }
-       }
-
-       return &databasev1.RouteTable{
-               Registered: registered,
-               Active:     active,
-               Evictable:  evictable,
-       }
+       return p.connMgr.GetRouteTable()
 }
 
 // New returns a new queue client targeting the given node roles.
@@ -372,15 +352,11 @@ func New(metadata metadata.Repo, roles 
...databasev1.Role) queue.Client {
        }
        p := &pub{
                metadata:      metadata,
-               active:        make(map[string]*client),
-               evictable:     make(map[string]evictNode),
-               registered:    make(map[string]*databasev1.Node),
                handlers:      make(map[bus.Topic]schema.EventHandler),
                closer:        run.NewCloser(1),
                allowedRoles:  roles,
                prefix:        strBuilder.String(),
                writableProbe: make(map[string]map[string]struct{}),
-               cbStates:      make(map[string]*circuitState),
                retryPolicy:   retryPolicy,
        }
        return p
@@ -389,7 +365,14 @@ func New(metadata metadata.Repo, roles ...databasev1.Role) 
queue.Client {
 // NewWithoutMetadata returns a new queue client without metadata, defaulting 
to data nodes.
 func NewWithoutMetadata() queue.Client {
        p := New(nil, databasev1.Role_ROLE_DATA)
-       p.(*pub).log = logger.GetLogger("queue-client")
+       pp := p.(*pub)
+       pp.log = logger.GetLogger("queue-client")
+       pp.connMgr = 
grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
+               Handler:        pp,
+               Logger:         pp.log,
+               RetryPolicy:    pp.retryPolicy,
+               MaxRecvMsgSize: maxReceiveMessageSize,
+       })
        return p
 }
 
@@ -404,12 +387,20 @@ func (p *pub) PreRun(context.Context) error {
 
        p.log = logger.GetLogger("server-queue-pub-" + p.prefix)
 
+       // Initialize connection manager with the pub as the handler
+       p.connMgr = 
grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
+               Handler:        p,
+               Logger:         p.log,
+               RetryPolicy:    p.retryPolicy,
+               MaxRecvMsgSize: maxReceiveMessageSize,
+       })
+
        // Initialize CA certificate reloader if TLS is enabled and CA cert 
path is provided
        if p.tlsEnabled && p.caCertPath != "" {
                var err error
                p.caCertReloader, err = 
pkgtls.NewClientCertReloader(p.caCertPath, p.log)
                if err != nil {
-                       return errors.Wrapf(err, "failed to initialize CA 
certificate reloader for %s", p.prefix)
+                       return pkgerrors.Wrapf(err, "failed to initialize CA 
certificate reloader for %s", p.prefix)
                }
                p.log.Info().Str("caCertPath", p.caCertPath).Msg("Initialized 
CA certificate reloader")
        }
@@ -514,14 +505,6 @@ func (l *future) GetAll() ([]bus.Message, error) {
        }
 }
 
-func isFailoverError(err error) bool {
-       s, ok := status.FromError(err)
-       if !ok {
-               return false
-       }
-       return s.Code() == codes.Unavailable || s.Code() == 
codes.DeadlineExceeded
-}
-
 func (p *pub) getClientTransportCredentials() ([]grpc.DialOption, error) {
        if !p.tlsEnabled {
                return grpchelper.SecureOptions(nil, false, false, "")
@@ -547,39 +530,6 @@ func (p *pub) getClientTransportCredentials() 
([]grpc.DialOption, error) {
        return opts, nil
 }
 
-// reconnectAllClients reconnects all active clients when CA certificate is 
updated.
-func (p *pub) reconnectAllClients() {
-       // Collect nodes and close connections
-       p.mu.Lock()
-       nodesToReconnect := make([]schema.Metadata, 0, len(p.registered))
-       for name, node := range p.registered {
-               // Handle evictable nodes: close channel and remove from 
evictable
-               if en, ok := p.evictable[name]; ok {
-                       close(en.c)
-                       delete(p.evictable, name)
-               }
-               // Handle active nodes: close connection and remove from active
-               if client, ok := p.active[name]; ok {
-                       _ = client.conn.Close()
-                       delete(p.active, name)
-                       p.deleteClient(client.md)
-               }
-               md := schema.Metadata{
-                       TypeMeta: schema.TypeMeta{
-                               Kind: schema.KindNode,
-                       },
-                       Spec: node,
-               }
-               nodesToReconnect = append(nodesToReconnect, md)
-       }
-       p.mu.Unlock()
-
-       // Reconnect with new credentials
-       for _, md := range nodesToReconnect {
-               p.OnAddOrUpdate(md)
-       }
-}
-
 // NewChunkedSyncClient implements queue.Client.
 func (p *pub) NewChunkedSyncClient(node string, chunkSize uint32) 
(queue.ChunkedSyncClient, error) {
        return p.NewChunkedSyncClientWithConfig(node, &ChunkedSyncClientConfig{
@@ -592,9 +542,7 @@ func (p *pub) NewChunkedSyncClient(node string, chunkSize 
uint32) (queue.Chunked
 
 // NewChunkedSyncClientWithConfig creates a chunked sync client with advanced 
configuration.
 func (p *pub) NewChunkedSyncClientWithConfig(node string, config 
*ChunkedSyncClientConfig) (queue.ChunkedSyncClient, error) {
-       p.mu.RLock()
-       client, ok := p.active[node]
-       p.mu.RUnlock()
+       c, ok := p.connMgr.GetClient(node)
        if !ok {
                return nil, fmt.Errorf("no active client for node %s", node)
        }
@@ -602,10 +550,9 @@ func (p *pub) NewChunkedSyncClientWithConfig(node string, 
config *ChunkedSyncCli
        if config.ChunkSize == 0 {
                config.ChunkSize = defaultChunkSize
        }
-
        return &chunkedSyncClient{
-               client:    client.client,
-               conn:      client.conn,
+               client:    c.client,
+               conn:      c.conn,
                node:      node,
                log:       p.log,
                chunkSize: config.ChunkSize,
@@ -615,12 +562,5 @@ func (p *pub) NewChunkedSyncClientWithConfig(node string, 
config *ChunkedSyncCli
 
 // HealthyNodes returns a list of node names that are currently healthy and 
connected.
 func (p *pub) HealthyNodes() []string {
-       p.mu.RLock()
-       defer p.mu.RUnlock()
-
-       nodes := make([]string, 0, len(p.active))
-       for name := range p.active {
-               nodes = append(nodes, name)
-       }
-       return nodes
+       return p.connMgr.ActiveNames()
 }
diff --git a/banyand/queue/pub/pub_suite_test.go 
b/banyand/queue/pub/pub_suite_test.go
index fab13eb4d..91479eb20 100644
--- a/banyand/queue/pub/pub_suite_test.go
+++ b/banyand/queue/pub/pub_suite_test.go
@@ -42,6 +42,7 @@ import (
        modelv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
        "github.com/apache/skywalking-banyandb/banyand/metadata/schema"
        "github.com/apache/skywalking-banyandb/pkg/bus"
+       "github.com/apache/skywalking-banyandb/pkg/grpchelper"
        "github.com/apache/skywalking-banyandb/pkg/logger"
        "github.com/apache/skywalking-banyandb/pkg/test"
        "github.com/apache/skywalking-banyandb/pkg/test/flags"
@@ -214,18 +215,30 @@ func (m *mockHandler) OnDelete(_ schema.Metadata) {
        m.deleteCount++
 }
 
+func initConnMgr(pp *pub) {
+       pp.connMgr = 
grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
+               Handler:        pp,
+               Logger:         pp.log,
+               RetryPolicy:    pp.retryPolicy,
+               MaxRecvMsgSize: maxReceiveMessageSize,
+       })
+}
+
 func newPub(roles ...databasev1.Role) *pub {
        p := New(nil, roles...)
-       p.(*pub).log = logger.GetLogger("queue-client")
+       pp := p.(*pub)
+       pp.log = logger.GetLogger("queue-client")
+       initConnMgr(pp)
        p.Register(data.TopicStreamWrite, &mockHandler{})
        p.Register(data.TopicMeasureWrite, &mockHandler{})
-       return p.(*pub)
+       return pp
 }
 
 // newPubWithNoRetry creates a pub with a retry policy that doesn't retry 
Unavailable errors.
 func newPubWithNoRetry(roles ...databasev1.Role) *pub {
        p := New(nil, roles...)
-       p.(*pub).log = logger.GetLogger("queue-client")
+       pp := p.(*pub)
+       pp.log = logger.GetLogger("queue-client")
        p.Register(data.TopicStreamWrite, &mockHandler{})
        p.Register(data.TopicMeasureWrite, &mockHandler{})
 
@@ -244,10 +257,9 @@ func newPubWithNoRetry(roles ...databasev1.Role) *pub {
                    }
                  }
                ]}`
-
-       // Store the original retry policy and replace it
-       p.(*pub).retryPolicy = noRetryPolicy
-       return p.(*pub)
+       pp.retryPolicy = noRetryPolicy
+       initConnMgr(pp)
+       return pp
 }
 
 func getDataNode(name string, address string) schema.Metadata {
diff --git a/banyand/queue/pub/pub_test.go b/banyand/queue/pub/pub_test.go
index d6c8237df..680f3eee6 100644
--- a/banyand/queue/pub/pub_test.go
+++ b/banyand/queue/pub/pub_test.go
@@ -112,17 +112,13 @@ var _ = ginkgo.Describe("Publish and Broadcast", func() {
                        gomega.Expect(cee).Should(gomega.HaveLen(1))
                        gomega.Expect(cee).Should(gomega.HaveKey("node2"))
                        gomega.Eventually(func() int {
-                               p.mu.RLock()
-                               defer p.mu.RUnlock()
-                               return len(p.active)
+                               return p.connMgr.ActiveCount()
                        }, flags.EventuallyTimeout).Should(gomega.Equal(1))
-                       func() {
-                               p.mu.RLock()
-                               defer p.mu.RUnlock()
-                               
gomega.Expect(p.evictable).Should(gomega.HaveLen(1))
-                               
gomega.Expect(p.evictable).Should(gomega.HaveKey("node2"))
-                               
gomega.Expect(p.active).Should(gomega.HaveKey("node1"))
-                       }()
+                       
gomega.Expect(p.connMgr.EvictableCount()).Should(gomega.Equal(1))
+                       _, node1Active := p.connMgr.GetClient("node1")
+                       gomega.Expect(node1Active).Should(gomega.BeTrue())
+                       _, node2Active := p.connMgr.GetClient("node2")
+                       gomega.Expect(node2Active).Should(gomega.BeFalse())
                })
 
                ginkgo.It("should go to evict queue when node's disk is full", 
func() {
@@ -155,10 +151,8 @@ var _ = ginkgo.Describe("Publish and Broadcast", func() {
                        gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
                        gomega.Expect(cee).Should(gomega.BeNil())
                        gomega.Consistently(func(g gomega.Gomega) {
-                               p.mu.RLock()
-                               defer p.mu.RUnlock()
-                               g.Expect(p.active).Should(gomega.HaveLen(2))
-                               g.Expect(p.evictable).Should(gomega.HaveLen(0))
+                               
g.Expect(p.connMgr.ActiveCount()).Should(gomega.Equal(2))
+                               
g.Expect(p.connMgr.EvictableCount()).Should(gomega.Equal(0))
                        }).Should(gomega.Succeed())
                })
 
@@ -192,9 +186,7 @@ var _ = ginkgo.Describe("Publish and Broadcast", func() {
                        gomega.Expect(cee).Should(gomega.HaveLen(1))
                        gomega.Expect(cee).Should(gomega.HaveKey("node2"))
                        gomega.Consistently(func() int {
-                               p.mu.RLock()
-                               defer p.mu.RUnlock()
-                               return len(p.active)
+                               return p.connMgr.ActiveCount()
                        }, "1s").Should(gomega.Equal(2))
                })
        })
diff --git a/banyand/queue/pub/pub_tls_test.go 
b/banyand/queue/pub/pub_tls_test.go
index 0a708a9ba..a11264a1e 100644
--- a/banyand/queue/pub/pub_tls_test.go
+++ b/banyand/queue/pub/pub_tls_test.go
@@ -162,9 +162,7 @@ var _ = ginkgo.Describe("Broadcast over one-way TLS", 
func() {
                p.OnAddOrUpdate(node)
 
                gomega.Eventually(func() int {
-                       p.mu.RLock()
-                       defer p.mu.RUnlock()
-                       return len(p.active)
+                       return p.connMgr.ActiveCount()
                }, flags.EventuallyTimeout).Should(gomega.Equal(1))
 
                futures, err := p.Broadcast(
diff --git a/banyand/queue/pub/retry_test.go b/banyand/queue/pub/retry_test.go
index f7a6ca820..b073c3470 100644
--- a/banyand/queue/pub/retry_test.go
+++ b/banyand/queue/pub/retry_test.go
@@ -20,7 +20,6 @@ package pub
 import (
        "context"
        "fmt"
-       "math"
        "sync"
        "testing"
        "time"
@@ -98,168 +97,6 @@ func (m *MockSendClient) SetCloseSendFunc(f func() error) {
        m.closeSendFunc = f
 }
 
-func TestJitteredBackoff(t *testing.T) {
-       tests := []struct {
-               name         string
-               baseBackoff  time.Duration
-               maxBackoff   time.Duration
-               attempt      int
-               jitterFactor float64
-               expectedMin  time.Duration
-               expectedMax  time.Duration
-       }{
-               {
-                       name:         "zero_attempt",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   10 * time.Second,
-                       attempt:      0,
-                       jitterFactor: 0.2,
-                       expectedMin:  80 * time.Millisecond,
-                       expectedMax:  120 * time.Millisecond,
-               },
-               {
-                       name:         "first_attempt",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   10 * time.Second,
-                       attempt:      1,
-                       jitterFactor: 0.2,
-                       expectedMin:  160 * time.Millisecond,
-                       expectedMax:  240 * time.Millisecond,
-               },
-               {
-                       name:         "max_backoff_reached",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   300 * time.Millisecond,
-                       attempt:      10,
-                       jitterFactor: 0.2,
-                       expectedMin:  240 * time.Millisecond,
-                       expectedMax:  360 * time.Millisecond,
-               },
-               {
-                       name:         "no_jitter",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   10 * time.Second,
-                       attempt:      0,
-                       jitterFactor: 0.0,
-                       expectedMin:  100 * time.Millisecond,
-                       expectedMax:  100 * time.Millisecond,
-               },
-               {
-                       name:         "max_jitter",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   10 * time.Second,
-                       attempt:      0,
-                       jitterFactor: 1.0,
-                       expectedMin:  0,
-                       expectedMax:  200 * time.Millisecond,
-               },
-       }
-
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       // Test multiple times to account for randomness
-                       for i := 0; i < 100; i++ {
-                               result := jitteredBackoff(tt.baseBackoff, 
tt.maxBackoff, tt.attempt, tt.jitterFactor)
-
-                               assert.GreaterOrEqual(t, result, tt.expectedMin,
-                                       "result %v should be >= expected min 
%v", result, tt.expectedMin)
-                               assert.LessOrEqual(t, result, tt.expectedMax,
-                                       "result %v should be <= expected max 
%v", result, tt.expectedMax)
-                       }
-               })
-       }
-}
-
-func TestJitteredBackoffEdgeCases(t *testing.T) {
-       tests := []struct {
-               name         string
-               baseBackoff  time.Duration
-               maxBackoff   time.Duration
-               attempt      int
-               jitterFactor float64
-       }{
-               {
-                       name:         "negative_jitter_factor",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   10 * time.Second,
-                       attempt:      0,
-                       jitterFactor: -0.5,
-               },
-               {
-                       name:         "jitter_factor_greater_than_one",
-                       baseBackoff:  100 * time.Millisecond,
-                       maxBackoff:   10 * time.Second,
-                       attempt:      0,
-                       jitterFactor: 1.5,
-               },
-       }
-
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       // Should not panic and should return a reasonable value
-                       result := jitteredBackoff(tt.baseBackoff, 
tt.maxBackoff, tt.attempt, tt.jitterFactor)
-                       assert.Greater(t, result, time.Duration(0), "result 
should be positive")
-                       assert.LessOrEqual(t, result, tt.maxBackoff, "result 
should not exceed max backoff")
-               })
-       }
-}
-
-func TestIsTransientError(t *testing.T) {
-       tests := []struct {
-               err         error
-               name        string
-               expectRetry bool
-       }{
-               {
-                       name:        "nil_error",
-                       err:         nil,
-                       expectRetry: false,
-               },
-               {
-                       name:        "unavailable_error",
-                       err:         status.Error(codes.Unavailable, "service 
unavailable"),
-                       expectRetry: true,
-               },
-               {
-                       name:        "deadline_exceeded_error",
-                       err:         status.Error(codes.DeadlineExceeded, 
"deadline exceeded"),
-                       expectRetry: true,
-               },
-               {
-                       name:        "resource_exhausted_error",
-                       err:         status.Error(codes.ResourceExhausted, 
"rate limited"),
-                       expectRetry: true,
-               },
-               {
-                       name:        "not_found_error",
-                       err:         status.Error(codes.NotFound, "not found"),
-                       expectRetry: false,
-               },
-               {
-                       name:        "permission_denied_error",
-                       err:         status.Error(codes.PermissionDenied, 
"permission denied"),
-                       expectRetry: false,
-               },
-               {
-                       name:        "invalid_argument_error",
-                       err:         status.Error(codes.InvalidArgument, 
"invalid argument"),
-                       expectRetry: false,
-               },
-               {
-                       name:        "generic_error",
-                       err:         fmt.Errorf("generic error"),
-                       expectRetry: false,
-               },
-       }
-
-       for _, tt := range tests {
-               t.Run(tt.name, func(t *testing.T) {
-                       result := isTransientError(tt.err)
-                       assert.Equal(t, tt.expectRetry, result, "transient 
error classification mismatch")
-               })
-       }
-}
-
 func TestRetrySendSuccess(t *testing.T) {
        ctx := context.Background()
        mockStream := NewMockSendClient(ctx)
@@ -485,42 +322,3 @@ func TestRetrySendConcurrency(t *testing.T) {
 
        assert.Equal(t, numGoroutines, successCount, "all goroutines should 
succeed eventually (got %d successes, %d failures)", successCount, failureCount)
 }
-
-func TestBackoffDistribution(t *testing.T) {
-       // Test that jitter produces a reasonable distribution
-       const numSamples = 1000
-       baseBackoff := 100 * time.Millisecond
-       maxBackoff := 10 * time.Second
-       attempt := 0
-       jitterFactor := 0.2
-
-       durations := make([]time.Duration, numSamples)
-       for i := 0; i < numSamples; i++ {
-               durations[i] = jitteredBackoff(baseBackoff, maxBackoff, 
attempt, jitterFactor)
-       }
-
-       // Calculate mean and standard deviation
-       var sum time.Duration
-       for _, d := range durations {
-               sum += d
-       }
-       mean := sum / time.Duration(numSamples)
-
-       var sumSquaredDiffs float64
-       for _, d := range durations {
-               diff := float64(d - mean)
-               sumSquaredDiffs += diff * diff
-       }
-       variance := sumSquaredDiffs / float64(numSamples)
-       stdDev := time.Duration(math.Sqrt(variance))
-
-       // Expected mean should be around baseBackoff (100ms)
-       expectedMean := baseBackoff
-       assert.InDelta(t, float64(expectedMean), float64(mean), 
float64(baseBackoff)/10,
-               "mean should be close to base backoff")
-
-       // Standard deviation should be reasonable (not too low, indicating 
proper jitter)
-       minStdDev := time.Duration(float64(baseBackoff) * jitterFactor / 4)
-       assert.Greater(t, stdDev, minStdDev,
-               "standard deviation should indicate proper jitter is applied")
-}
diff --git a/banyand/queue/pub/circuitbreaker.go 
b/pkg/grpchelper/circuitbreaker.go
similarity index 80%
rename from banyand/queue/pub/circuitbreaker.go
rename to pkg/grpchelper/circuitbreaker.go
index 7bb64e296..3393a9eb5 100644
--- a/banyand/queue/pub/circuitbreaker.go
+++ b/pkg/grpchelper/circuitbreaker.go
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-package pub
+package grpchelper
 
 import (
        "time"
@@ -45,13 +45,13 @@ type circuitState struct {
        halfOpenProbeInFlight bool
 }
 
-// isRequestAllowed checks if a request to the given node is allowed based on 
circuit breaker state.
+// IsRequestAllowed checks if a request to the given node is allowed based on 
circuit breaker state.
 // It also handles state transitions from Open to Half-Open when cooldown 
expires.
-func (p *pub) isRequestAllowed(node string) bool {
-       p.cbMu.Lock()
-       defer p.cbMu.Unlock()
+func (m *ConnManager[C]) IsRequestAllowed(node string) bool {
+       m.cbMu.Lock()
+       defer m.cbMu.Unlock()
 
-       cb, exists := p.cbStates[node]
+       cb, exists := m.cbStates[node]
        if !exists {
                return true // No circuit breaker state, allow request
        }
@@ -65,7 +65,7 @@ func (p *pub) isRequestAllowed(node string) bool {
                        // Transition to Half-Open to allow a single probe 
request
                        cb.state = StateHalfOpen
                        cb.halfOpenProbeInFlight = true // Set token for the 
probe
-                       p.log.Info().Str("node", node).Msg("circuit breaker 
transitioned to half-open")
+                       m.log.Info().Str("node", node).Msg("circuit breaker 
transitioned to half-open")
                        return true
                }
                return false // Still in cooldown period
@@ -83,23 +83,21 @@ func (p *pub) isRequestAllowed(node string) bool {
        }
 }
 
-// recordSuccess resets the circuit breaker state to Closed on successful 
operation.
+// RecordSuccess resets the circuit breaker state to Closed on successful 
operation.
 // This handles Half-Open -> Closed transitions.
-func (p *pub) recordSuccess(node string) {
-       p.cbMu.Lock()
-       defer p.cbMu.Unlock()
-
-       cb, exists := p.cbStates[node]
+func (m *ConnManager[C]) RecordSuccess(node string) {
+       m.cbMu.Lock()
+       defer m.cbMu.Unlock()
+       cb, exists := m.cbStates[node]
        if !exists {
                // Initialize circuit breaker state
-               p.cbStates[node] = &circuitState{
+               m.cbStates[node] = &circuitState{
                        state:               StateClosed,
                        consecutiveFailures: 0,
                }
                return
        }
 
-       // Reset to closed state
        cb.state = StateClosed
        cb.consecutiveFailures = 0
        cb.lastFailureTime = time.Time{}
@@ -107,17 +105,17 @@ func (p *pub) recordSuccess(node string) {
        cb.halfOpenProbeInFlight = false // Clear probe token
 }
 
-// recordFailure updates the circuit breaker state on failed operation.
+// RecordFailure updates the circuit breaker state on failed operation.
 // Only records failures for transient/internal errors that should count 
toward opening the circuit.
-func (p *pub) recordFailure(node string, err error) {
+func (m *ConnManager[C]) RecordFailure(node string, err error) {
        // Only record failure if the error is transient or internal
-       if !isTransientError(err) && !isInternalError(err) {
+       if !IsTransientError(err) && !IsInternalError(err) {
                return
        }
-       p.cbMu.Lock()
-       defer p.cbMu.Unlock()
+       m.cbMu.Lock()
+       defer m.cbMu.Unlock()
 
-       cb, exists := p.cbStates[node]
+       cb, exists := m.cbStates[node]
        if !exists {
                // Initialize circuit breaker state
                cb = &circuitState{
@@ -125,7 +123,7 @@ func (p *pub) recordFailure(node string, err error) {
                        consecutiveFailures: 1,
                        lastFailureTime:     time.Now(),
                }
-               p.cbStates[node] = cb
+               m.cbStates[node] = cb
        } else {
                cb.consecutiveFailures++
                cb.lastFailureTime = time.Now()
@@ -136,12 +134,12 @@ func (p *pub) recordFailure(node string, err error) {
        if cb.consecutiveFailures >= threshold && cb.state == StateClosed {
                cb.state = StateOpen
                cb.openTime = time.Now()
-               p.log.Warn().Str("node", node).Int("failures", 
cb.consecutiveFailures).Msg("circuit breaker opened")
+               m.log.Warn().Str("node", node).Int("failures", 
cb.consecutiveFailures).Msg("circuit breaker opened")
        } else if cb.state == StateHalfOpen {
                // Failed during half-open, go back to open
                cb.state = StateOpen
                cb.openTime = time.Now()
                cb.halfOpenProbeInFlight = false // Clear probe token
-               p.log.Warn().Str("node", node).Msg("circuit breaker reopened 
after half-open failure")
+               m.log.Warn().Str("node", node).Msg("circuit breaker reopened 
after half-open failure")
        }
 }
diff --git a/banyand/queue/pub/circuitbreaker_test.go 
b/pkg/grpchelper/circuitbreaker_test.go
similarity index 64%
rename from banyand/queue/pub/circuitbreaker_test.go
rename to pkg/grpchelper/circuitbreaker_test.go
index 8bc902230..dda620cdc 100644
--- a/banyand/queue/pub/circuitbreaker_test.go
+++ b/pkg/grpchelper/circuitbreaker_test.go
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-package pub
+package grpchelper
 
 import (
        "sync"
@@ -41,96 +41,110 @@ var (
        errNonRetry  = status.Error(codes.InvalidArgument, "invalid argument")
 )
 
+func init() {
+       _ = logger.Init(logger.Logging{
+               Env:   "dev",
+               Level: "warn",
+       })
+}
+
+func newTestConnManager() *ConnManager[*mockClient] {
+       return NewConnManager(ConnManagerConfig[*mockClient]{
+               Handler: &mockHandler{},
+               Logger:  logger.GetLogger("test-cb"),
+       })
+}
+
 func TestCircuitBreakerStateTransitions(t *testing.T) {
        tests := []struct {
                name           string
-               setup          func(*pub)
-               actions        []func(*pub, string)
+               setup          func(*ConnManager[*mockClient])
+               actions        []func(*ConnManager[*mockClient], string)
                expectedState  CircuitState
                allowsRequests bool
        }{
                {
                        name: "closed_to_open_after_failures",
-                       setup: func(p *pub) {
-                               p.recordSuccess(testNodeName)
+                       setup: func(m *ConnManager[*mockClient]) {
+                               m.RecordSuccess(testNodeName)
                        },
-                       actions: []func(*pub, string){
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
+                       actions: []func(*ConnManager[*mockClient], string){
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
                        },
                        expectedState:  StateOpen,
                        allowsRequests: false,
                },
                {
                        name: "closed_remains_closed_below_threshold",
-                       setup: func(p *pub) {
-                               p.recordSuccess(testNodeName)
+                       setup: func(m *ConnManager[*mockClient]) {
+                               m.RecordSuccess(testNodeName)
                        },
-                       actions: []func(*pub, string){
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
+                       actions: []func(*ConnManager[*mockClient], string){
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
                        },
                        expectedState:  StateClosed,
                        allowsRequests: true,
                },
                {
                        name: "open_to_half_open_after_cooldown",
-                       setup: func(p *pub) {
-                               p.recordSuccess(testNodeName)
+                       setup: func(m *ConnManager[*mockClient]) {
+                               m.RecordSuccess(testNodeName)
                                // Trip the circuit breaker
                                for i := 0; i < defaultCBThreshold; i++ {
-                                       p.recordFailure(testNodeName, 
errTransient)
+                                       m.RecordFailure(testNodeName, 
errTransient)
                                }
                                // Simulate cooldown period has passed
-                               p.cbMu.Lock()
-                               cb := p.cbStates[testNodeName]
+                               m.cbMu.Lock()
+                               cb := m.cbStates[testNodeName]
                                cb.openTime = 
time.Now().Add(-defaultCBResetTimeout - time.Second)
-                               p.cbMu.Unlock()
+                               m.cbMu.Unlock()
                        },
-                       actions:        []func(*pub, string){},
-                       expectedState:  StateOpen, // Will transition to 
half-open in isRequestAllowed
+                       actions:        []func(*ConnManager[*mockClient], 
string){},
+                       expectedState:  StateOpen, // Will transition to 
half-open in IsRequestAllowed
                        allowsRequests: true,      // Should allow requests 
after cooldown
                },
                {
                        name: "half_open_to_closed_on_success",
-                       setup: func(p *pub) {
-                               p.recordSuccess(testNodeName)
+                       setup: func(m *ConnManager[*mockClient]) {
+                               m.RecordSuccess(testNodeName)
                                // Trip the circuit breaker
                                for i := 0; i < defaultCBThreshold; i++ {
-                                       p.recordFailure(testNodeName, 
errTransient)
+                                       m.RecordFailure(testNodeName, 
errTransient)
                                }
                                // Set to half-open state
-                               p.cbMu.Lock()
-                               cb := p.cbStates[testNodeName]
+                               m.cbMu.Lock()
+                               cb := m.cbStates[testNodeName]
                                cb.state = StateHalfOpen
-                               p.cbMu.Unlock()
+                               m.cbMu.Unlock()
                        },
-                       actions: []func(*pub, string){
-                               func(p *pub, node string) { 
p.recordSuccess(node) },
+                       actions: []func(*ConnManager[*mockClient], string){
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordSuccess(node) },
                        },
                        expectedState:  StateClosed,
                        allowsRequests: true,
                },
                {
                        name: "half_open_to_open_on_failure",
-                       setup: func(p *pub) {
-                               p.recordSuccess(testNodeName)
+                       setup: func(m *ConnManager[*mockClient]) {
+                               m.RecordSuccess(testNodeName)
                                // Trip the circuit breaker
                                for i := 0; i < defaultCBThreshold; i++ {
-                                       p.recordFailure(testNodeName, 
errTransient)
+                                       m.RecordFailure(testNodeName, 
errTransient)
                                }
                                // Set to half-open state
-                               p.cbMu.Lock()
-                               cb := p.cbStates[testNodeName]
+                               m.cbMu.Lock()
+                               cb := m.cbStates[testNodeName]
                                cb.state = StateHalfOpen
-                               p.cbMu.Unlock()
+                               m.cbMu.Unlock()
                        },
-                       actions: []func(*pub, string){
-                               func(p *pub, node string) { 
p.recordFailure(node, errTransient) },
+                       actions: []func(*ConnManager[*mockClient], string){
+                               func(m *ConnManager[*mockClient], node string) 
{ m.RecordFailure(node, errTransient) },
                        },
                        expectedState:  StateOpen,
                        allowsRequests: false,
@@ -139,43 +153,37 @@ func TestCircuitBreakerStateTransitions(t *testing.T) {
 
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       p := &pub{
-                               cbStates: make(map[string]*circuitState),
-                               cbMu:     sync.RWMutex{},
-                               log:      logger.GetLogger("test"),
-                       }
+                       m := newTestConnManager()
+                       defer m.GracefulStop()
 
                        // Setup
                        if tt.setup != nil {
-                               tt.setup(p)
+                               tt.setup(m)
                        }
 
                        // Execute actions
                        for _, action := range tt.actions {
-                               action(p, testNodeName)
+                               action(m, testNodeName)
                        }
 
                        // Check final state
-                       p.cbMu.RLock()
-                       cb, exists := p.cbStates[testNodeName]
-                       p.cbMu.RUnlock()
+                       m.cbMu.RLock()
+                       cb, exists := m.cbStates[testNodeName]
+                       m.cbMu.RUnlock()
 
                        require.True(t, exists, "circuit breaker state should 
exist")
                        assert.Equal(t, tt.expectedState, cb.state, "circuit 
breaker state mismatch")
 
                        // Check if requests are allowed
-                       allowed := p.isRequestAllowed(testNodeName)
+                       allowed := m.IsRequestAllowed(testNodeName)
                        assert.Equal(t, tt.allowsRequests, allowed, "request 
allowance mismatch")
                })
        }
 }
 
 func TestCircuitBreakerConcurrency(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
+       m := newTestConnManager()
+       defer m.GracefulStop()
 
        const numGoroutines = 100
        const numOperations = 50
@@ -189,8 +197,8 @@ func TestCircuitBreakerConcurrency(t *testing.T) {
                go func() {
                        defer wg.Done()
                        for j := 0; j < numOperations; j++ {
-                               p.recordSuccess(node)
-                               p.isRequestAllowed(node)
+                               m.RecordSuccess(node)
+                               m.IsRequestAllowed(node)
                        }
                }()
        }
@@ -200,8 +208,8 @@ func TestCircuitBreakerConcurrency(t *testing.T) {
                go func() {
                        defer wg.Done()
                        for j := 0; j < numOperations; j++ {
-                               p.recordFailure(node, errTransient)
-                               p.isRequestAllowed(node)
+                               m.RecordFailure(node, errTransient)
+                               m.IsRequestAllowed(node)
                        }
                }()
        }
@@ -209,126 +217,115 @@ func TestCircuitBreakerConcurrency(t *testing.T) {
        wg.Wait()
 
        // Verify circuit breaker state exists and is in a valid state
-       p.cbMu.RLock()
-       cb, exists := p.cbStates[node]
-       p.cbMu.RUnlock()
+       m.cbMu.RLock()
+       cb, exists := m.cbStates[node]
+       m.cbMu.RUnlock()
 
        require.True(t, exists, "circuit breaker state should exist")
        assert.Contains(t, []CircuitState{StateClosed, StateOpen, 
StateHalfOpen}, cb.state, "circuit breaker should be in a valid state")
 }
 
 func TestCircuitBreakerMultipleNodes(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
+       m := newTestConnManager()
+       defer m.GracefulStop()
 
        nodes := []string{"node1", "node2", "node3"}
 
        // Initialize all nodes
        for _, node := range nodes {
-               p.recordSuccess(node)
+               m.RecordSuccess(node)
        }
 
        // Trip circuit breaker for node1 only
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure("node1", errTransient)
+               m.RecordFailure("node1", errTransient)
        }
 
        // Add some failures to node2 but below threshold
        for i := 0; i < defaultCBThreshold-1; i++ {
-               p.recordFailure("node2", errTransient)
+               m.RecordFailure("node2", errTransient)
        }
 
        // Keep node3 healthy
-       p.recordSuccess("node3")
+       m.RecordSuccess("node3")
 
        // Verify states
-       assert.False(t, p.isRequestAllowed("node1"), "node1 should have circuit 
breaker open")
-       assert.True(t, p.isRequestAllowed("node2"), "node2 should still allow 
requests")
-       assert.True(t, p.isRequestAllowed("node3"), "node3 should allow 
requests")
+       assert.False(t, m.IsRequestAllowed("node1"), "node1 should have circuit 
breaker open")
+       assert.True(t, m.IsRequestAllowed("node2"), "node2 should still allow 
requests")
+       assert.True(t, m.IsRequestAllowed("node3"), "node3 should allow 
requests")
 
        // Check circuit breaker states
-       p.cbMu.RLock()
-       defer p.cbMu.RUnlock()
+       m.cbMu.RLock()
+       defer m.cbMu.RUnlock()
 
-       cb1, exists1 := p.cbStates["node1"]
+       cb1, exists1 := m.cbStates["node1"]
        require.True(t, exists1)
        assert.Equal(t, StateOpen, cb1.state)
 
-       cb2, exists2 := p.cbStates["node2"]
+       cb2, exists2 := m.cbStates["node2"]
        require.True(t, exists2)
        assert.Equal(t, StateClosed, cb2.state)
 
-       cb3, exists3 := p.cbStates["node3"]
+       cb3, exists3 := m.cbStates["node3"]
        require.True(t, exists3)
        assert.Equal(t, StateClosed, cb3.state)
 }
 
 func TestCircuitBreakerRecoveryAfterCooldown(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Initialize node
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
 
        // Trip circuit breaker
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
        // Verify circuit is open
-       assert.False(t, p.isRequestAllowed(node), "circuit should be open")
+       assert.False(t, m.IsRequestAllowed(node), "circuit should be open")
 
        // Simulate cooldown period passage
-       p.cbMu.Lock()
-       cb := p.cbStates[node]
+       m.cbMu.Lock()
+       cb := m.cbStates[node]
        cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
-       p.cbMu.Unlock()
+       m.cbMu.Unlock()
 
        // Check that circuit allows requests (transitions to half-open)
-       allowed := p.isRequestAllowed(node)
+       allowed := m.IsRequestAllowed(node)
        assert.True(t, allowed, "circuit should allow requests after cooldown")
 
        // Verify state transitioned to half-open
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateHalfOpen, cb.state, "circuit should be in 
half-open state")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // Successful request should close the circuit
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
 
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateClosed, cb.state, "circuit should be closed after 
success")
        assert.Equal(t, 0, cb.consecutiveFailures, "failure count should be 
reset")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 }
 
 func TestCircuitBreakerInitialization(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := "new-node"
 
        // First request to non-existent circuit breaker should be allowed
-       allowed := p.isRequestAllowed(node)
+       allowed := m.IsRequestAllowed(node)
        assert.True(t, allowed, "requests should be allowed for non-existent 
circuit breaker")
 
        // Record success should initialize the circuit breaker
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
 
-       p.cbMu.RLock()
-       cb, exists := p.cbStates[node]
-       p.cbMu.RUnlock()
+       m.cbMu.RLock()
+       cb, exists := m.cbStates[node]
+       m.cbMu.RUnlock()
 
        require.True(t, exists, "circuit breaker should be initialized")
        assert.Equal(t, StateClosed, cb.state, "new circuit breaker should be 
closed")
@@ -336,184 +333,164 @@ func TestCircuitBreakerInitialization(t *testing.T) {
 }
 
 func TestCircuitBreakerFailureThresholdEdgeCase(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Initialize node
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
 
        // Add failures just below threshold
        for i := 0; i < defaultCBThreshold-1; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
        // Circuit should still be closed
-       p.cbMu.RLock()
-       cb := p.cbStates[node]
+       m.cbMu.RLock()
+       cb := m.cbStates[node]
        assert.Equal(t, StateClosed, cb.state, "circuit should still be closed")
        assert.Equal(t, defaultCBThreshold-1, cb.consecutiveFailures, "failure 
count should be at threshold-1")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // One more failure should open the circuit
-       p.recordFailure(node, errTransient)
+       m.RecordFailure(node, errTransient)
 
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateOpen, cb.state, "circuit should be open after 
reaching threshold")
        assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures, "failure 
count should be at threshold")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 }
 
 func TestCircuitBreakerSingleProbeEnforcement(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Initialize and trip circuit breaker
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
        // Verify circuit is open
-       assert.False(t, p.isRequestAllowed(node), "circuit should be open")
+       assert.False(t, m.IsRequestAllowed(node), "circuit should be open")
 
        // Simulate cooldown period passage
-       p.cbMu.Lock()
-       cb := p.cbStates[node]
+       m.cbMu.Lock()
+       cb := m.cbStates[node]
        cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
-       p.cbMu.Unlock()
+       m.cbMu.Unlock()
 
        // First request should transition to half-open and be allowed
-       allowed1 := p.isRequestAllowed(node)
+       allowed1 := m.IsRequestAllowed(node)
        assert.True(t, allowed1, "first request after cooldown should be 
allowed and transition to half-open")
 
        // Verify state is half-open with probe in flight
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateHalfOpen, cb.state, "circuit should be in 
half-open state")
        assert.True(t, cb.halfOpenProbeInFlight, "probe should be marked as in 
flight")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // Second request should be denied while probe is in flight
-       allowed2 := p.isRequestAllowed(node)
+       allowed2 := m.IsRequestAllowed(node)
        assert.False(t, allowed2, "second request should be denied while probe 
is in flight")
 
        // Third request should also be denied
-       allowed3 := p.isRequestAllowed(node)
+       allowed3 := m.IsRequestAllowed(node)
        assert.False(t, allowed3, "third request should also be denied while 
probe is in flight")
 }
 
 func TestCircuitBreakerSingleProbeSuccess(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Setup half-open state with probe in flight
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
-       p.cbMu.Lock()
-       cb := p.cbStates[node]
+       m.cbMu.Lock()
+       cb := m.cbStates[node]
        cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
-       p.cbMu.Unlock()
+       m.cbMu.Unlock()
 
        // Transition to half-open
-       allowed := p.isRequestAllowed(node)
+       allowed := m.IsRequestAllowed(node)
        assert.True(t, allowed, "first request should be allowed")
 
        // Verify probe is in flight
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.True(t, cb.halfOpenProbeInFlight, "probe should be in flight")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // Record success (probe succeeds)
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
 
        // Verify circuit is closed and probe token is cleared
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateClosed, cb.state, "circuit should be closed after 
successful probe")
        assert.False(t, cb.halfOpenProbeInFlight, "probe token should be 
cleared")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // Subsequent requests should be allowed in closed state
-       assert.True(t, p.isRequestAllowed(node), "requests should be allowed in 
closed state")
+       assert.True(t, m.IsRequestAllowed(node), "requests should be allowed in 
closed state")
 }
 
 func TestCircuitBreakerSingleProbeFailure(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Setup half-open state with probe in flight
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
-       p.cbMu.Lock()
-       cb := p.cbStates[node]
+       m.cbMu.Lock()
+       cb := m.cbStates[node]
        cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
-       p.cbMu.Unlock()
+       m.cbMu.Unlock()
 
        // Transition to half-open
-       allowed := p.isRequestAllowed(node)
+       allowed := m.IsRequestAllowed(node)
        assert.True(t, allowed, "first request should be allowed")
 
        // Verify probe is in flight
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.True(t, cb.halfOpenProbeInFlight, "probe should be in flight")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // Record failure (probe fails)
-       p.recordFailure(node, errTransient)
+       m.RecordFailure(node, errTransient)
 
        // Verify circuit is back to open and probe token is cleared
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateOpen, cb.state, "circuit should be back to open 
after failed probe")
        assert.False(t, cb.halfOpenProbeInFlight, "probe token should be 
cleared")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 
        // Subsequent requests should be denied in open state
-       assert.False(t, p.isRequestAllowed(node), "requests should be denied in 
open state")
+       assert.False(t, m.IsRequestAllowed(node), "requests should be denied in 
open state")
 }
 
 func TestCircuitBreakerConcurrentProbeAttempts(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Setup open state ready for half-open transition
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
-       p.cbMu.Lock()
-       cb := p.cbStates[node]
+       m.cbMu.Lock()
+       cb := m.cbStates[node]
        cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
-       p.cbMu.Unlock()
+       m.cbMu.Unlock()
 
        const numGoroutines = 10
        var wg sync.WaitGroup
@@ -524,7 +501,7 @@ func TestCircuitBreakerConcurrentProbeAttempts(t 
*testing.T) {
                wg.Add(1)
                go func() {
                        defer wg.Done()
-                       allowed := p.isRequestAllowed(node)
+                       allowed := m.IsRequestAllowed(node)
                        results <- allowed
                }()
        }
@@ -544,56 +521,48 @@ func TestCircuitBreakerConcurrentProbeAttempts(t 
*testing.T) {
        assert.Equal(t, 1, allowedCount, "exactly one request should be allowed 
in half-open state")
 
        // Verify circuit is in half-open with probe in flight
-       p.cbMu.RLock()
+       m.cbMu.RLock()
        assert.Equal(t, StateHalfOpen, cb.state, "circuit should be in 
half-open state")
        assert.True(t, cb.halfOpenProbeInFlight, "probe should be marked as in 
flight")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 }
 
 func TestCircuitBreakerProbeTokenInitialization(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Test that new circuit breaker states have probe token cleared
-       p.recordSuccess(node)
+       m.RecordSuccess(node)
 
-       p.cbMu.RLock()
-       cb := p.cbStates[node]
+       m.cbMu.RLock()
+       cb := m.cbStates[node]
        assert.False(t, cb.halfOpenProbeInFlight, "new circuit breaker should 
have probe token cleared")
-       p.cbMu.RUnlock()
+       m.cbMu.RUnlock()
 }
 
 func TestCircuitBreakerErrorFiltering(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Verify initial state is closed
-       assert.True(t, p.isRequestAllowed(node))
-       p.cbMu.Lock()
-       _, exists := p.cbStates[node]
-       p.cbMu.Unlock()
+       assert.True(t, m.IsRequestAllowed(node))
+       m.cbMu.Lock()
+       _, exists := m.cbStates[node]
+       m.cbMu.Unlock()
        assert.False(t, exists)
 
        // Try to trip circuit with non-retryable error
        for i := 0; i < defaultCBThreshold*2; i++ {
-               p.recordFailure(node, errNonRetry)
+               m.RecordFailure(node, errNonRetry)
        }
 
        // Circuit should still be closed (non-retryable errors don't count)
-       assert.True(t, p.isRequestAllowed(node))
-       p.cbMu.Lock()
-       cb, exists := p.cbStates[node]
-       p.cbMu.Unlock()
+       assert.True(t, m.IsRequestAllowed(node))
+       m.cbMu.Lock()
+       cb, exists := m.cbStates[node]
+       m.cbMu.Unlock()
        // State should not exist or be closed with 0 failures
        if exists {
                assert.Equal(t, StateClosed, cb.state)
@@ -602,81 +571,77 @@ func TestCircuitBreakerErrorFiltering(t *testing.T) {
 
        // Now try with transient errors - these should count
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errTransient)
+               m.RecordFailure(node, errTransient)
        }
 
        // Circuit should now be open
-       assert.False(t, p.isRequestAllowed(node))
-       p.cbMu.Lock()
-       cb, exists = p.cbStates[node]
-       p.cbMu.Unlock()
+       assert.False(t, m.IsRequestAllowed(node))
+       m.cbMu.Lock()
+       cb, exists = m.cbStates[node]
+       m.cbMu.Unlock()
        assert.True(t, exists)
        assert.Equal(t, StateOpen, cb.state)
        assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
 
        // Reset for internal error test
-       p.recordSuccess(node)
-       assert.True(t, p.isRequestAllowed(node))
+       m.RecordSuccess(node)
+       assert.True(t, m.IsRequestAllowed(node))
 
        // Try with internal errors - these should also count
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, errInternal)
+               m.RecordFailure(node, errInternal)
        }
 
        // Circuit should be open again
-       assert.False(t, p.isRequestAllowed(node))
-       p.cbMu.Lock()
-       cb, exists = p.cbStates[node]
-       p.cbMu.Unlock()
+       assert.False(t, m.IsRequestAllowed(node))
+       m.cbMu.Lock()
+       cb, exists = m.cbStates[node]
+       m.cbMu.Unlock()
        assert.True(t, exists)
        assert.Equal(t, StateOpen, cb.state)
        assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
 }
 
 func TestCircuitBreakerFailoverErrors(t *testing.T) {
-       p := &pub{
-               cbStates: make(map[string]*circuitState),
-               cbMu:     sync.RWMutex{},
-               log:      logger.GetLogger("test"),
-       }
-
+       m := newTestConnManager()
+       defer m.GracefulStop()
        node := testNodeName
 
        // Test failover errors from Recv() - using codes.Unavailable which is 
a failover error
        failoverErr := status.Error(codes.Unavailable, "service unavailable")
 
        // Verify initial state is closed
-       assert.True(t, p.isRequestAllowed(node))
+       assert.True(t, m.IsRequestAllowed(node))
 
        // Simulate failover errors that should increment circuit breaker
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, failoverErr)
+               m.RecordFailure(node, failoverErr)
        }
 
        // Circuit should be open after failover errors
-       assert.False(t, p.isRequestAllowed(node))
-       p.cbMu.Lock()
-       cb, exists := p.cbStates[node]
-       p.cbMu.Unlock()
+       assert.False(t, m.IsRequestAllowed(node))
+       m.cbMu.Lock()
+       cb, exists := m.cbStates[node]
+       m.cbMu.Unlock()
        assert.True(t, exists)
        assert.Equal(t, StateOpen, cb.state)
        assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
 
        // Reset and test with common.Error (simulating Close() method errors)
-       p.recordSuccess(node)
-       assert.True(t, p.isRequestAllowed(node))
+       m.RecordSuccess(node)
+       assert.True(t, m.IsRequestAllowed(node))
 
        // Test with common.Error types that would come from Close() method
        internalErr := 
common.NewErrorWithStatus(modelv1.Status_STATUS_INTERNAL_ERROR, "internal 
error")
        for i := 0; i < defaultCBThreshold; i++ {
-               p.recordFailure(node, internalErr)
+               m.RecordFailure(node, internalErr)
        }
 
        // Circuit should be open again
-       assert.False(t, p.isRequestAllowed(node))
-       p.cbMu.Lock()
-       cb, exists = p.cbStates[node]
-       p.cbMu.Unlock()
+       assert.False(t, m.IsRequestAllowed(node))
+       m.cbMu.Lock()
+       cb, exists = m.cbStates[node]
+       m.cbMu.Unlock()
        assert.True(t, exists)
        assert.Equal(t, StateOpen, cb.state)
        assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
diff --git a/pkg/grpchelper/connmanager.go b/pkg/grpchelper/connmanager.go
new file mode 100644
index 000000000..a480d2ee1
--- /dev/null
+++ b/pkg/grpchelper/connmanager.go
@@ -0,0 +1,526 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "sync"
+       "time"
+
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/health/grpc_health_v1"
+
+       databasev1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/run"
+)
+
var (
	// ErrCircuitBreakerOpen is returned when the circuit breaker is open for a node.
	// Execute wraps it with the node name; match it with errors.Is.
	ErrCircuitBreakerOpen = errors.New("circuit breaker open")
	// ErrClientNotFound is returned when no active client exists for a node.
	// Execute wraps it with the node name; match it with errors.Is.
	ErrClientNotFound = errors.New("client not found")
)
+
// Client is the minimal interface for a managed gRPC client.
// The manager closes the underlying *grpc.ClientConn itself when a node is
// torn down; it does not call Close on the client on the paths visible here.
type Client interface {
	Close() error
}

// ConnectionHandler provides upper-layer callbacks for connection lifecycle.
//
// NOTE(review): OnActive and OnInactive are invoked while the manager's
// internal mutex is held; implementations must not call back into the
// ConnManager or they will deadlock (sync.RWMutex is not reentrant).
type ConnectionHandler[C Client] interface {
	// AddressOf extracts the gRPC address from a node.
	AddressOf(node *databasev1.Node) string
	// GetDialOptions returns gRPC dial options (e.g. transport credentials)
	// used whenever the manager dials a node.
	GetDialOptions() ([]grpc.DialOption, error)
	// NewClient creates a client from a gRPC connection and node.
	NewClient(conn *grpc.ClientConn, node *databasev1.Node) (C, error)
	// OnActive is called when a node transitions to active.
	OnActive(name string, client C)
	// OnInactive is called when a node leaves active. It can fire for a node
	// that never became active (initial health-check failure) and more than
	// once for the same node, so implementations should be idempotent.
	OnInactive(name string, client C)
}

// ConnManagerConfig holds configuration for ConnManager.
type ConnManagerConfig[C Client] struct {
	// Handler supplies addresses, credentials, clients, and lifecycle hooks.
	Handler        ConnectionHandler[C]
	// Logger is used unconditionally by the manager; must be non-nil.
	Logger         *logger.Logger
	// RetryPolicy, when non-empty, becomes the gRPC default service config.
	RetryPolicy    string
	// ExtraDialOpts are appended after the options derived from this config.
	ExtraDialOpts  []grpc.DialOption
	// MaxRecvMsgSize, when positive, caps the maximum receive message size.
	MaxRecvMsgSize int
}
+
// ConnManager manages gRPC connections with health checking, circuit breaking, and eviction.
//
// A node lives in at most one of two queues: active (healthy, usable
// connection) or evictable (unhealthy, background reconnect in progress).
// registered tracks every known node regardless of queue. mu guards
// registered/active/evictable; cbMu independently guards cbStates
// (see circuitbreaker.go).
type ConnManager[C Client] struct {
	handler    ConnectionHandler[C]
	log        *logger.Logger
	registered map[string]*databasev1.Node
	active     map[string]*managedNode[C]
	evictable  map[string]evictNode
	cbStates   map[string]*circuitState
	closer     *run.Closer
	dialOpts   []grpc.DialOption
	mu         sync.RWMutex
	cbMu       sync.RWMutex
}

// managedNode bundles an active node with its client and underlying connection.
type managedNode[C Client] struct {
	client C
	conn   *grpc.ClientConn
	node   *databasev1.Node
}

// evictNode is an evict-queue entry: the node plus a channel that is closed
// to cancel its background reconnect goroutine.
type evictNode struct {
	n *databasev1.Node
	c chan struct{}
}
+
+// NewConnManager creates a new ConnManager.
+func NewConnManager[C Client](cfg ConnManagerConfig[C]) *ConnManager[C] {
+       var dialOpts []grpc.DialOption
+       if cfg.RetryPolicy != "" {
+               dialOpts = append(dialOpts, 
grpc.WithDefaultServiceConfig(cfg.RetryPolicy))
+       }
+       if cfg.MaxRecvMsgSize > 0 {
+               dialOpts = append(dialOpts, 
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.MaxRecvMsgSize)))
+       }
+       dialOpts = append(dialOpts, cfg.ExtraDialOpts...)
+       m := &ConnManager[C]{
+               handler:    cfg.Handler,
+               log:        cfg.Logger,
+               registered: make(map[string]*databasev1.Node),
+               active:     make(map[string]*managedNode[C]),
+               evictable:  make(map[string]evictNode),
+               cbStates:   make(map[string]*circuitState),
+               closer:     run.NewCloser(1),
+               dialOpts:   dialOpts,
+       }
+       return m
+}
+
// OnAddOrUpdate registers or updates a node and manages its connection.
//
// The node is recorded in registered; if it is not already active or
// evictable, a new connection is dialed, a client is built via the handler,
// and the node is health-checked before being admitted to the active queue.
// NOTE(review): the health check (2s budget) runs while m.mu is held, so a
// slow or dead endpoint briefly blocks other manager operations.
func (m *ConnManager[C]) OnAddOrUpdate(node *databasev1.Node) {
	address := m.handler.AddressOf(node)
	if address == "" {
		m.log.Warn().Stringer("node", node).Msg("grpc address is empty")
		return
	}
	name := node.Metadata.GetName()
	if name == "" {
		m.log.Warn().Stringer("node", node).Msg("node name is empty")
		return
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	// registerNode tears down any existing connection if the address changed.
	m.registerNode(node)

	// Already connected, or already being reconnected: nothing more to do.
	if _, ok := m.active[name]; ok {
		return
	}
	if _, ok := m.evictable[name]; ok {
		return
	}
	credOpts, dialErr := m.handler.GetDialOptions()
	if dialErr != nil {
		m.log.Error().Err(dialErr).Msg("failed to load client TLS credentials")
		return
	}
	// Handler credential options first, then the shared manager options.
	allOpts := make([]grpc.DialOption, 0, len(credOpts)+len(m.dialOpts))
	allOpts = append(allOpts, credOpts...)
	allOpts = append(allOpts, m.dialOpts...)
	conn, connErr := grpc.NewClient(address, allOpts...)
	if connErr != nil {
		m.log.Error().Err(connErr).Msg("failed to connect to grpc server")
		return
	}

	client, clientErr := m.handler.NewClient(conn, node)
	if clientErr != nil {
		m.log.Error().Err(clientErr).Msg("failed to create client")
		_ = conn.Close()
		return
	}
	// On failure this enqueues the node as evictable and schedules reconnects.
	if !m.checkHealthAndReconnect(conn, node, client) {
		m.log.Info().Str("status", m.dump()).Stringer("node", node).Msg("node is unhealthy in the register flow, move it to evict queue")
		return
	}
	m.active[name] = &managedNode[C]{conn: conn, client: client, node: node}
	m.handler.OnActive(name, client)
	// Initialize or reset circuit breaker state to closed
	m.RecordSuccess(name)
	m.log.Info().Str("status", m.dump()).Stringer("node", node).Msg("new node is healthy, add it to active queue")
}
+
// OnDelete removes a node and its connection.
//
// An evictable node has its reconnect loop cancelled immediately. An active
// node is only removed right away if it is already unhealthy; a healthy node
// keeps serving, and a background goroutine polls with jittered backoff,
// removing it once it turns unhealthy — unless the node is re-registered
// first, in which case the goroutine gives up.
func (m *ConnManager[C]) OnDelete(node *databasev1.Node) {
	name := node.Metadata.GetName()
	if name == "" {
		m.log.Warn().Stringer("node", node).Msg("node name is empty")
		return
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.registered, name)
	if en, ok := m.evictable[name]; ok {
		close(en.c) // cancels the background reconnect goroutine
		delete(m.evictable, name)
		m.log.Info().Str("status", m.dump()).Stringer("node", node).Msg("node is removed from evict queue by delete event")
		return
	}
	if mn, ok := m.active[name]; ok {
		if m.removeNodeIfUnhealthy(name, mn) {
			m.log.Info().Str("status", m.dump()).Stringer("node", node).Msg("remove node from active queue by delete event")
			return
		}
		if !m.closer.AddRunning() {
			// Manager is shutting down; skip the background polling.
			return
		}
		go func() {
			defer m.closer.Done()
			var elapsed time.Duration
			attempt := 0
			for {
				backoff := JitteredBackoff(InitBackoff, MaxBackoff, attempt, DefaultJitterFactor)
				select {
				case <-time.After(backoff):
					// The closure returns true when polling should stop:
					// the node was re-registered, or it was finally removed.
					if func() bool {
						elapsed += backoff
						m.mu.Lock()
						defer m.mu.Unlock()
						if _, ok := m.registered[name]; ok {
							return true
						}
						if m.removeNodeIfUnhealthy(name, mn) {
							m.log.Info().Str("status", m.dump()).Stringer("node", node).Dur("after", elapsed).Msg("remove node from active queue by delete event")
							return true
						}
						return false
					}() {
						return
					}
				case <-m.closer.CloseNotify():
					return
				}
				attempt++
			}
		}()
	}
}
+
+// GetClient returns the client for the given node name.
+func (m *ConnManager[C]) GetClient(name string) (C, bool) {
+       m.mu.RLock()
+       defer m.mu.RUnlock()
+       mn, ok := m.active[name]
+       if !ok {
+               var zero C
+               return zero, false
+       }
+       return mn.client, true
+}
+
+// Execute checks the circuit breaker, gets the client, calls fn, and records 
success or failure.
+func (m *ConnManager[C]) Execute(node string, fn func(C) error) error {
+       if !m.IsRequestAllowed(node) {
+               return fmt.Errorf("%w for node %s", ErrCircuitBreakerOpen, node)
+       }
+       c, ok := m.GetClient(node)
+       if !ok {
+               return fmt.Errorf("%w for node %s", ErrClientNotFound, node)
+       }
+       if callbackErr := fn(c); callbackErr != nil {
+               m.RecordFailure(node, callbackErr)
+               return callbackErr
+       }
+       m.RecordSuccess(node)
+       return nil
+}
+
+// ActiveNames returns the names of all active nodes.
+func (m *ConnManager[C]) ActiveNames() []string {
+       m.mu.RLock()
+       defer m.mu.RUnlock()
+       names := make([]string, 0, len(m.active))
+       for name := range m.active {
+               names = append(names, name)
+       }
+       return names
+}
+
+// ActiveCount returns the number of active nodes.
+func (m *ConnManager[C]) ActiveCount() int {
+       m.mu.RLock()
+       defer m.mu.RUnlock()
+       return len(m.active)
+}
+
+// EvictableCount returns the number of evictable nodes.
+func (m *ConnManager[C]) EvictableCount() int {
+       m.mu.RLock()
+       defer m.mu.RUnlock()
+       return len(m.evictable)
+}
+
+// ActiveRegisteredNodes returns the registered node info for all active nodes.
+func (m *ConnManager[C]) ActiveRegisteredNodes() []*databasev1.Node {
+       m.mu.RLock()
+       defer m.mu.RUnlock()
+       var nodes []*databasev1.Node
+       for k := range m.active {
+               if n := m.registered[k]; n != nil {
+                       nodes = append(nodes, n)
+               }
+       }
+       return nodes
+}
+
+// GetRouteTable returns a snapshot of registered, active, and evictable node 
info.
+func (m *ConnManager[C]) GetRouteTable() *databasev1.RouteTable {
+       m.mu.RLock()
+       defer m.mu.RUnlock()
+       registered := make([]*databasev1.Node, 0, len(m.registered))
+       for _, node := range m.registered {
+               if node != nil {
+                       registered = append(registered, node)
+               }
+       }
+       activeNames := make([]string, 0, len(m.active))
+       for nodeID := range m.active {
+               if node := m.registered[nodeID]; node != nil && node.Metadata 
!= nil {
+                       activeNames = append(activeNames, node.Metadata.Name)
+               }
+       }
+       evictableNames := make([]string, 0, len(m.evictable))
+       for nodeID := range m.evictable {
+               if node := m.registered[nodeID]; node != nil && node.Metadata 
!= nil {
+                       evictableNames = append(evictableNames, 
node.Metadata.Name)
+               }
+       }
+       return &databasev1.RouteTable{
+               Registered: registered,
+               Active:     activeNames,
+               Evictable:  evictableNames,
+       }
+}
+
// FailoverNode checks health for a node and moves it to evictable if unhealthy.
func (m *ConnManager[C]) FailoverNode(node string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if en, ok := m.evictable[node]; ok {
		// Already evicted: if the node has since been unregistered, cancel
		// the background reconnect and drop the entry; otherwise leave it be.
		if _, registered := m.registered[node]; !registered {
			close(en.c)
			delete(m.evictable, node)
			m.log.Info().Str("node", node).Str("status", m.dump()).Msg("node is removed from evict queue by wire event")
		}
		return
	}
	// checkHealthAndReconnect (on failure) already closes the conn, enqueues
	// the node as evictable, and calls OnInactive; the statements below close
	// and notify again. NOTE(review): OnInactive therefore fires twice on this
	// path — handlers must tolerate that; confirm it matches the pre-refactor
	// pub.Client behavior.
	if mn, ok := m.active[node]; ok && !m.checkHealthAndReconnect(mn.conn, mn.node, mn.client) {
		_ = mn.conn.Close()
		delete(m.active, node)
		m.handler.OnInactive(node, mn.client)
		m.log.Info().Str("status", m.dump()).Str("node", node).Msg("node is unhealthy in the failover flow, move it to evict queue")
	}
}
+
+// ReconnectAll closes all active and evictable connections and re-registers 
all nodes.
+func (m *ConnManager[C]) ReconnectAll() {
+       m.mu.Lock()
+       nodesToReconnect := make([]*databasev1.Node, 0, len(m.registered))
+       for name, node := range m.registered {
+               if en, ok := m.evictable[name]; ok {
+                       close(en.c)
+                       delete(m.evictable, name)
+               }
+               if mn, ok := m.active[name]; ok {
+                       _ = mn.conn.Close()
+                       delete(m.active, name)
+                       m.handler.OnInactive(name, mn.client)
+               }
+               nodesToReconnect = append(nodesToReconnect, node)
+       }
+       m.mu.Unlock()
+       for _, node := range nodesToReconnect {
+               m.OnAddOrUpdate(node)
+       }
+}
+
// GracefulStop closes all connections and stops background goroutines.
func (m *ConnManager[C]) GracefulStop() {
	m.mu.Lock()
	// Cancel every background reconnect loop first.
	for idx := range m.evictable {
		close(m.evictable[idx].c)
	}
	m.evictable = nil
	m.mu.Unlock()
	// Release the initial closer token (run.NewCloser(1)), then wait for all
	// background goroutines (reconnect/delete loops) to drain.
	m.closer.Done()
	m.closer.CloseThenWait()
	m.mu.Lock()
	for _, mn := range m.active {
		_ = mn.conn.Close()
	}
	m.active = nil
	m.mu.Unlock()
}
+
// registerNode records node in the registered map, first tearing down any
// existing connection when the node's gRPC address has changed so the caller
// re-dials the new address. Caller must hold m.mu.
func (m *ConnManager[C]) registerNode(node *databasev1.Node) {
	name := node.Metadata.GetName()
	// Always (re-)record the latest node descriptor on every return path.
	defer func() {
		m.registered[name] = node
	}()

	n, ok := m.registered[name]
	if !ok {
		// First sighting of this node.
		return
	}
	if m.handler.AddressOf(n) == m.handler.AddressOf(node) {
		// Address unchanged: keep any existing connection as-is.
		return
	}
	if en, ok := m.evictable[name]; ok {
		close(en.c)
		delete(m.evictable, name)
		m.log.Info().Str("node", name).Str("status", m.dump()).Msg("node is removed from evict queue by the new gRPC address updated event")
	}
	if mn, ok := m.active[name]; ok {
		_ = mn.conn.Close()
		delete(m.active, name)
		m.handler.OnInactive(name, mn.client)
		m.log.Info().Str("status", m.dump()).Str("node", name).Msg("node is removed from active queue by the new gRPC address updated event")
	}
}
+
+func (m *ConnManager[C]) removeNodeIfUnhealthy(name string, mn 
*managedNode[C]) bool {
+       if m.healthCheck(mn.node.String(), mn.conn) {
+               return false
+       }
+       _ = mn.conn.Close()
+       delete(m.active, name)
+       m.handler.OnInactive(name, mn.client)
+       return true
+}
+
+// checkHealthAndReconnect checks if a node is healthy. If not, closes the 
conn,
+// adds to evictable, calls OnInactive, and starts a retry goroutine.
+// Returns true if healthy.
+func (m *ConnManager[C]) checkHealthAndReconnect(conn *grpc.ClientConn, node 
*databasev1.Node, client C) bool {
+       if m.healthCheck(node.String(), conn) {
+               return true
+       }
+       _ = conn.Close()
+       if !m.closer.AddRunning() {
+               return false
+       }
+       name := node.Metadata.Name
+       m.evictable[name] = evictNode{n: node, c: make(chan struct{})}
+       m.handler.OnInactive(name, client)
+       go func(name string, en evictNode) {
+               defer m.closer.Done()
+               attempt := 0
+               for {
+                       backoff := JitteredBackoff(InitBackoff, MaxBackoff, 
attempt, DefaultJitterFactor)
+                       select {
+                       case <-time.After(backoff):
+                               address := m.handler.AddressOf(en.n)
+                               credOpts, errEvict := m.handler.GetDialOptions()
+                               if errEvict != nil {
+                                       m.log.Error().Err(errEvict).Msg("failed 
to load client TLS credentials (evict)")
+                                       return
+                               }
+                               allOpts := make([]grpc.DialOption, 0, 
len(credOpts)+len(m.dialOpts))
+                               allOpts = append(allOpts, credOpts...)
+                               allOpts = append(allOpts, m.dialOpts...)
+                               connEvict, errEvict := grpc.NewClient(address, 
allOpts...)
+                               if errEvict == nil && 
m.healthCheck(en.n.String(), connEvict) {
+                                       func() {
+                                               m.mu.Lock()
+                                               defer m.mu.Unlock()
+                                               if _, ok := m.evictable[name]; 
!ok {
+                                                       // The client has been 
removed from evict clients map, just return
+                                                       return
+                                               }
+                                               newClient, clientErr := 
m.handler.NewClient(connEvict, en.n)
+                                               if clientErr != nil {
+                                                       
m.log.Error().Err(clientErr).Msg("failed to create client during reconnect")
+                                                       _ = connEvict.Close()
+                                                       return
+                                               }
+                                               m.active[name] = 
&managedNode[C]{conn: connEvict, client: newClient, node: en.n}
+                                               m.handler.OnActive(name, 
newClient)
+                                               delete(m.evictable, name)
+                                               m.RecordSuccess(name)
+                                               m.log.Info().Str("status", 
m.dump()).Stringer("node", en.n).Msg("node is healthy, move it back to active 
queue")
+                                       }()
+                                       return
+                               }
+                               if connEvict != nil {
+                                       _ = connEvict.Close()
+                               }
+                               if func() bool {
+                                       m.mu.RLock()
+                                       defer m.mu.RUnlock()
+                                       _, ok := m.registered[name]
+                                       return !ok
+                               }() {
+                                       return
+                               }
+                               m.log.Error().Err(errEvict).Msgf("failed to 
re-connect to grpc server %s after waiting for %s", address, backoff)
+                       case <-en.c:
+                               return
+                       case <-m.closer.CloseNotify():
+                               return
+                       }
+                       attempt++
+               }
+       }(name, m.evictable[name])
+       return false
+}
+
+func (m *ConnManager[C]) healthCheck(node string, conn *grpc.ClientConn) bool {
+       var resp *grpc_health_v1.HealthCheckResponse
+       if requestErr := Request(context.Background(), 2*time.Second, 
func(rpcCtx context.Context) (err error) {
+               resp, err = grpc_health_v1.NewHealthClient(conn).Check(rpcCtx,
+                       &grpc_health_v1.HealthCheckRequest{
+                               Service: "",
+                       })
+               return err
+       }); requestErr != nil {
+               if e := m.log.Debug(); e.Enabled() {
+                       e.Err(requestErr).Str("node", node).Msg("service 
unhealthy")
+               }
+               return false
+       }
+       return resp.GetStatus() == grpc_health_v1.HealthCheckResponse_SERVING
+}
+
+func (m *ConnManager[C]) dump() string {
+       keysRegistered := make([]string, 0, len(m.registered))
+       for k := range m.registered {
+               keysRegistered = append(keysRegistered, k)
+       }
+       keysActive := make([]string, 0, len(m.active))
+       for k := range m.active {
+               keysActive = append(keysActive, k)
+       }
+       keysEvictable := make([]string, 0, len(m.evictable))
+       for k := range m.evictable {
+               keysEvictable = append(keysEvictable, k)
+       }
+       return fmt.Sprintf("registered: %v, active :%v, evictable :%v", 
keysRegistered, keysActive, keysEvictable)
+}
diff --git a/pkg/grpchelper/connmanager_test.go 
b/pkg/grpchelper/connmanager_test.go
new file mode 100644
index 000000000..f85303fbf
--- /dev/null
+++ b/pkg/grpchelper/connmanager_test.go
@@ -0,0 +1,115 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
+)
+
+func addTestClient(m *ConnManager[*mockClient], node string) {
+       m.mu.Lock()
+       m.active[node] = &managedNode[*mockClient]{client: &mockClient{}}
+       m.mu.Unlock()
+}
+
+func removeTestClient(m *ConnManager[*mockClient], node string) {
+       m.mu.Lock()
+       delete(m.active, node)
+       m.mu.Unlock()
+}
+
+func TestExecute_Success(t *testing.T) {
+       m := newTestConnManager()
+       node := "exec-success"
+       addTestClient(m, node)
+       defer func() {
+               removeTestClient(m, node)
+               m.GracefulStop()
+       }()
+       m.RecordSuccess(node)
+
+       called := false
+       execErr := m.Execute(node, func(_ *mockClient) error {
+               called = true
+               return nil
+       })
+       require.NoError(t, execErr)
+       assert.True(t, called)
+       assert.True(t, m.IsRequestAllowed(node), "circuit breaker should still 
allow requests after success")
+}
+
+func TestExecute_CallbackFailure(t *testing.T) {
+       m := newTestConnManager()
+       node := "exec-fail"
+       addTestClient(m, node)
+       defer func() {
+               removeTestClient(m, node)
+               m.GracefulStop()
+       }()
+       m.RecordSuccess(node)
+
+       cbErr := status.Error(codes.Unavailable, "callback error")
+       for i := 0; i < defaultCBThreshold; i++ {
+               execErr := m.Execute(node, func(_ *mockClient) error {
+                       return cbErr
+               })
+               require.Error(t, execErr)
+       }
+       assert.False(t, m.IsRequestAllowed(node), "circuit breaker should open 
after repeated failures")
+}
+
+func TestExecute_CircuitBreakerOpen(t *testing.T) {
+       m := newTestConnManager()
+       node := "exec-cb-open"
+       addTestClient(m, node)
+       defer func() {
+               removeTestClient(m, node)
+               m.GracefulStop()
+       }()
+
+       for i := 0; i < defaultCBThreshold; i++ {
+               m.RecordFailure(node, errTransient)
+       }
+
+       execErr := m.Execute(node, func(_ *mockClient) error {
+               t.Fatal("callback should not be called when circuit breaker is 
open")
+               return nil
+       })
+       require.Error(t, execErr)
+       assert.ErrorIs(t, execErr, ErrCircuitBreakerOpen)
+       assert.Contains(t, execErr.Error(), node)
+}
+
+func TestExecute_ClientNotFound(t *testing.T) {
+       m := newTestConnManager()
+       defer m.GracefulStop()
+       node := "exec-no-client"
+
+       execErr := m.Execute(node, func(_ *mockClient) error {
+               t.Fatal("callback should not be called when client is not 
found")
+               return nil
+       })
+       require.Error(t, execErr)
+       assert.ErrorIs(t, execErr, ErrClientNotFound)
+       assert.Contains(t, execErr.Error(), node)
+}
diff --git a/pkg/grpchelper/helpers_test.go b/pkg/grpchelper/helpers_test.go
new file mode 100644
index 000000000..1c6c9815c
--- /dev/null
+++ b/pkg/grpchelper/helpers_test.go
@@ -0,0 +1,67 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+       "google.golang.org/grpc"
+
+       databasev1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+)
+
// mockClient implements the Client interface for tests.
type mockClient struct {
	closed bool // set once Close has been invoked
}

// Close marks the client closed and always succeeds.
func (m *mockClient) Close() error {
	m.closed = true
	return nil
}
+
+// mockHandler implements ConnectionHandler[*mockClient] for tests.
+type mockHandler struct {
+       activeNodes   map[string]*mockClient
+       inactiveNodes map[string]*mockClient
+}
+
+func (h *mockHandler) AddressOf(node *databasev1.Node) string {
+       return node.GetGrpcAddress()
+}
+
+func (h *mockHandler) GetDialOptions() ([]grpc.DialOption, error) {
+       return nil, nil
+}
+
+func (h *mockHandler) NewClient(_ *grpc.ClientConn, _ *databasev1.Node) 
(*mockClient, error) {
+       return &mockClient{}, nil
+}
+
+func (h *mockHandler) OnActive(name string, client *mockClient) {
+       if h.activeNodes == nil {
+               h.activeNodes = make(map[string]*mockClient)
+       }
+       h.activeNodes[name] = client
+}
+
+func (h *mockHandler) OnInactive(name string, client *mockClient) {
+       if h.inactiveNodes == nil {
+               h.inactiveNodes = make(map[string]*mockClient)
+       }
+       h.inactiveNodes[name] = client
+       delete(h.activeNodes, name)
+}
diff --git a/banyand/queue/pub/retry.go b/pkg/grpchelper/retry.go
similarity index 72%
rename from banyand/queue/pub/retry.go
rename to pkg/grpchelper/retry.go
index 1adfa0e2b..2478a4c3d 100644
--- a/banyand/queue/pub/retry.go
+++ b/pkg/grpchelper/retry.go
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-package pub
+package grpchelper
 
 import (
        "crypto/rand"
@@ -30,18 +30,14 @@ import (
        modelv1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
 )
 
-const (
-       defaultJitterFactor      = 0.2
-       defaultMaxRetries        = 3
-       defaultPerRequestTimeout = 2 * time.Second
-       defaultBackoffBase       = 500 * time.Millisecond
-       defaultBackoffMax        = 30 * time.Second
-)
+// DefaultJitterFactor is the default jitter factor for backoff calculations.
+const DefaultJitterFactor = 0.2
 
 var (
-       // Retry policy for health check.
-       initBackoff = time.Second
-       maxBackoff  = 20 * time.Second
+       // InitBackoff is the initial backoff duration for health check retries.
+       InitBackoff = time.Second
+       // MaxBackoff is the maximum backoff duration for health check retries.
+       MaxBackoff = 20 * time.Second
 
        // Retryable gRPC status codes for streaming send retries.
        retryableCodes = map[codes.Code]bool{
@@ -77,9 +73,9 @@ func secureRandFloat64() float64 {
        return float64(n.Uint64()) / float64(1<<53)
 }
 
-// jitteredBackoff calculates backoff duration with jitter to avoid thundering 
herds.
+// JitteredBackoff calculates backoff duration with jitter to avoid thundering 
herds.
 // Uses bounded symmetric jitter: backoff * (1 + jitter * (rand() - 0.5) * 2).
-func jitteredBackoff(baseBackoff, maxBackoff time.Duration, attempt int, 
jitterFactor float64) time.Duration {
+func JitteredBackoff(baseBackoff, maxBackoff time.Duration, attempt int, 
jitterFactor float64) time.Duration {
        if jitterFactor < 0 {
                jitterFactor = 0
        }
@@ -114,15 +110,29 @@ func jitteredBackoff(baseBackoff, maxBackoff 
time.Duration, attempt int, jitterF
        return jitteredDuration
 }
 
-// isTransientError checks if the error is considered transient and retryable.
-func isTransientError(err error) bool {
+// statusCodeFromError extracts the gRPC status code from an error,
+// supporting both direct gRPC status errors and wrapped errors.
+func statusCodeFromError(err error) (codes.Code, bool) {
+       if s, ok := status.FromError(err); ok {
+               return s.Code(), true
+       }
+       type grpcStatusProvider interface{ GRPCStatus() *status.Status }
+       var se grpcStatusProvider
+       if errors.As(err, &se) {
+               return se.GRPCStatus().Code(), true
+       }
+       return codes.OK, false
+}
+
+// IsTransientError checks if the error is considered transient and retryable.
+func IsTransientError(err error) bool {
        if err == nil {
                return false
        }
 
        // Handle gRPC status errors
-       if s, ok := status.FromError(err); ok {
-               return retryableCodes[s.Code()]
+       if code, ok := statusCodeFromError(err); ok {
+               return retryableCodes[code]
        }
 
        // Handle common.Error types
@@ -140,15 +150,24 @@ func isTransientError(err error) bool {
        return false
 }
 
-// isInternalError checks if the error is an internal server error.
-func isInternalError(err error) bool {
+// IsFailoverError checks if the error indicates the node should be failed 
over.
+func IsFailoverError(err error) bool {
+       code, ok := statusCodeFromError(err)
+       if !ok {
+               return false
+       }
+       return code == codes.Unavailable || code == codes.DeadlineExceeded
+}
+
+// IsInternalError checks if the error is an internal server error.
+func IsInternalError(err error) bool {
        if err == nil {
                return false
        }
 
        // Handle gRPC status errors
-       if s, ok := status.FromError(err); ok {
-               return s.Code() == codes.Internal
+       if code, ok := statusCodeFromError(err); ok {
+               return code == codes.Internal
        }
 
        // Handle common.Error types
diff --git a/pkg/grpchelper/retry_test.go b/pkg/grpchelper/retry_test.go
new file mode 100644
index 000000000..d13962640
--- /dev/null
+++ b/pkg/grpchelper/retry_test.go
@@ -0,0 +1,500 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+       "errors"
+       "fmt"
+       "testing"
+       "time"
+
+       pkgerrors "github.com/pkg/errors"
+       "github.com/stretchr/testify/assert"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
+)
+
+// customUnwrapError wraps a single error via Unwrap() error.
+type customUnwrapError struct {
+       cause error
+       msg   string
+}
+
+func (e *customUnwrapError) Error() string { return e.msg + ": " + 
e.cause.Error() }
+func (e *customUnwrapError) Unwrap() error { return e.cause }
+
+// multiUnwrapError wraps multiple errors via Unwrap() []error.
+type multiUnwrapError struct {
+       msg    string
+       causes []error
+}
+
+func (e *multiUnwrapError) Error() string   { return e.msg }
+func (e *multiUnwrapError) Unwrap() []error { return e.causes }
+
// TestJitteredBackoff verifies that JitteredBackoff stays inside the
// expected [min, max] window for several (attempt, jitterFactor)
// combinations, sampling each case repeatedly to cover the randomness.
func TestJitteredBackoff(t *testing.T) {
	tests := []struct {
		name         string
		baseBackoff  time.Duration
		maxBackoff   time.Duration
		attempt      int
		jitterFactor float64
		expectedMin  time.Duration
		expectedMax  time.Duration
	}{
		{
			name:         "zero_attempt",
			baseBackoff:  100 * time.Millisecond,
			maxBackoff:   10 * time.Second,
			attempt:      0,
			jitterFactor: 0.2,
			expectedMin:  80 * time.Millisecond,
			expectedMax:  120 * time.Millisecond,
		},
		{
			name:         "first_attempt",
			baseBackoff:  100 * time.Millisecond,
			maxBackoff:   10 * time.Second,
			attempt:      1,
			jitterFactor: 0.2,
			expectedMin:  160 * time.Millisecond,
			expectedMax:  240 * time.Millisecond,
		},
		{
			// Exponential growth capped by maxBackoff; jitter may only
			// shrink the result below the cap, never exceed it.
			name:         "max_backoff_reached",
			baseBackoff:  100 * time.Millisecond,
			maxBackoff:   300 * time.Millisecond,
			attempt:      10,
			jitterFactor: 0.2,
			expectedMin:  240 * time.Millisecond,
			expectedMax:  300 * time.Millisecond,
		},
		{
			name:         "no_jitter",
			baseBackoff:  100 * time.Millisecond,
			maxBackoff:   10 * time.Second,
			attempt:      0,
			jitterFactor: 0.0,
			expectedMin:  100 * time.Millisecond,
			expectedMax:  100 * time.Millisecond,
		},
		{
			name:         "max_jitter",
			baseBackoff:  100 * time.Millisecond,
			maxBackoff:   10 * time.Second,
			attempt:      0,
			jitterFactor: 1.0,
			expectedMin:  0,
			expectedMax:  200 * time.Millisecond,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Test multiple times to account for randomness
			for i := 0; i < 100; i++ {
				result := JitteredBackoff(tt.baseBackoff, tt.maxBackoff, tt.attempt, tt.jitterFactor)

				assert.GreaterOrEqual(t, result, tt.expectedMin,
					"result %v should be >= expected min %v", result, tt.expectedMin)
				assert.LessOrEqual(t, result, tt.expectedMax,
					"result %v should be <= expected max %v", result, tt.expectedMax)
			}
		})
	}
}
+
+func TestJitteredBackoffEdgeCases(t *testing.T) {
+       tests := []struct {
+               name         string
+               baseBackoff  time.Duration
+               maxBackoff   time.Duration
+               attempt      int
+               jitterFactor float64
+       }{
+               {
+                       name:         "negative_jitter_factor",
+                       baseBackoff:  100 * time.Millisecond,
+                       maxBackoff:   10 * time.Second,
+                       attempt:      0,
+                       jitterFactor: -0.5,
+               },
+               {
+                       name:         "jitter_factor_greater_than_one",
+                       baseBackoff:  100 * time.Millisecond,
+                       maxBackoff:   10 * time.Second,
+                       attempt:      0,
+                       jitterFactor: 1.5,
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       // Should not panic and should return a reasonable value
+                       result := JitteredBackoff(tt.baseBackoff, 
tt.maxBackoff, tt.attempt, tt.jitterFactor)
+                       assert.GreaterOrEqual(t, result, time.Duration(0), 
"result should be non-negative")
+                       assert.LessOrEqual(t, result, tt.maxBackoff, "result 
should not exceed max backoff")
+               })
+       }
+}
+
// TestIsTransientError exhaustively classifies plain gRPC status codes:
// only Unavailable, DeadlineExceeded, ResourceExhausted, and Internal are
// treated as transient (retryable); every other code — and any non-status
// error — is not.
func TestIsTransientError(t *testing.T) {
	tests := []struct {
		err         error
		name        string
		expectRetry bool
	}{
		{
			name:        "nil_error",
			err:         nil,
			expectRetry: false,
		},
		{
			name:        "unavailable_error",
			err:         status.Error(codes.Unavailable, "service unavailable"),
			expectRetry: true,
		},
		{
			name:        "deadline_exceeded_error",
			err:         status.Error(codes.DeadlineExceeded, "deadline exceeded"),
			expectRetry: true,
		},
		{
			name:        "resource_exhausted_error",
			err:         status.Error(codes.ResourceExhausted, "rate limited"),
			expectRetry: true,
		},
		{
			name:        "internal_error",
			err:         status.Error(codes.Internal, "internal"),
			expectRetry: true,
		},
		{
			name:        "not_found_error",
			err:         status.Error(codes.NotFound, "not found"),
			expectRetry: false,
		},
		{
			name:        "permission_denied_error",
			err:         status.Error(codes.PermissionDenied, "permission denied"),
			expectRetry: false,
		},
		{
			name:        "invalid_argument_error",
			err:         status.Error(codes.InvalidArgument, "invalid argument"),
			expectRetry: false,
		},
		{
			name:        "ok_error",
			err:         status.Error(codes.OK, "ok"),
			expectRetry: false,
		},
		{
			name:        "canceled_error",
			err:         status.Error(codes.Canceled, "canceled"),
			expectRetry: false,
		},
		{
			name:        "unknown_error",
			err:         status.Error(codes.Unknown, "unknown"),
			expectRetry: false,
		},
		{
			name:        "already_exists_error",
			err:         status.Error(codes.AlreadyExists, "exists"),
			expectRetry: false,
		},
		{
			name:        "failed_precondition_error",
			err:         status.Error(codes.FailedPrecondition, "failed"),
			expectRetry: false,
		},
		{
			name:        "aborted_error",
			err:         status.Error(codes.Aborted, "aborted"),
			expectRetry: false,
		},
		{
			name:        "out_of_range_error",
			err:         status.Error(codes.OutOfRange, "range"),
			expectRetry: false,
		},
		{
			name:        "unimplemented_error",
			err:         status.Error(codes.Unimplemented, "unimpl"),
			expectRetry: false,
		},
		{
			name:        "data_loss_error",
			err:         status.Error(codes.DataLoss, "loss"),
			expectRetry: false,
		},
		{
			name:        "unauthenticated_error",
			err:         status.Error(codes.Unauthenticated, "unauth"),
			expectRetry: false,
		},
		{
			// Non-status errors carry no retry signal at all.
			name:        "generic_error",
			err:         fmt.Errorf("generic error"),
			expectRetry: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := IsTransientError(tt.err)
			assert.Equal(t, tt.expectRetry, result, "transient error classification mismatch")
		})
	}
}
+
+func TestIsFailoverError(t *testing.T) {
+       tests := []struct {
+               err            error
+               name           string
+               expectFailover bool
+       }{
+               {name: "nil_error", err: nil, expectFailover: false},
+               {name: "unavailable_error", err: 
status.Error(codes.Unavailable, "unavailable"), expectFailover: true},
+               {name: "deadline_exceeded_error", err: 
status.Error(codes.DeadlineExceeded, "timeout"), expectFailover: true},
+               {name: "internal_error", err: status.Error(codes.Internal, 
"internal"), expectFailover: false},
+               {name: "invalid_argument_error", err: 
status.Error(codes.InvalidArgument, "bad"), expectFailover: false},
+               {name: "ok_error", err: status.Error(codes.OK, "ok"), 
expectFailover: false},
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       result := IsFailoverError(tt.err)
+                       assert.Equal(t, tt.expectFailover, result, "failover 
error classification mismatch")
+               })
+       }
+}
+
// TestIsTransientErrorWrapped verifies that transient-error detection sees
// through every common wrapping style — fmt.Errorf %w, errors.Join, custom
// Unwrap() error / Unwrap() []error types, and pkg/errors — while a status
// message embedded only as text (no status in the chain) is never matched.
func TestIsTransientErrorWrapped(t *testing.T) {
	grpcUnavailable := status.Error(codes.Unavailable, "unavailable")
	grpcInternal := status.Error(codes.Internal, "internal")
	grpcInvalidArg := status.Error(codes.InvalidArgument, "bad arg")
	grpcDeadline := status.Error(codes.DeadlineExceeded, "timeout")

	tests := []struct {
		err         error
		name        string
		expectRetry bool
	}{
		// direct gRPC status (baseline)
		{name: "direct_unavailable", err: grpcUnavailable, expectRetry: true},
		{name: "direct_internal", err: grpcInternal, expectRetry: true},
		{name: "direct_invalid_argument", err: grpcInvalidArg, expectRetry: false},
		{name: "direct_deadline_exceeded", err: grpcDeadline, expectRetry: true},
		// fmt.Errorf %w
		{name: "fmt_errorf_unavailable", err: fmt.Errorf("wrap: %w", grpcUnavailable), expectRetry: true},
		{name: "fmt_errorf_internal", err: fmt.Errorf("wrap: %w", grpcInternal), expectRetry: true},
		{name: "fmt_errorf_invalid_argument", err: fmt.Errorf("wrap: %w", grpcInvalidArg), expectRetry: false},
		{
			name: "fmt_errorf_double_wrapped",
			err:  fmt.Errorf("outer: %w", fmt.Errorf("inner: %w", grpcDeadline)), expectRetry: true,
		},
		// errors.New with message only — no gRPC status in chain
		{name: "errors_new_with_grpc_message_text", err: errors.New(grpcUnavailable.Error()), expectRetry: false},
		// errors.Join
		{name: "errors_join_unavailable", err: errors.Join(errors.New("extra"), grpcUnavailable), expectRetry: true},
		{name: "errors_join_invalid_argument", err: errors.Join(errors.New("extra"), grpcInvalidArg), expectRetry: false},
		{
			// NOTE(review): with two status errors joined, the first one
			// found (InvalidArgument) appears to win — hence no retry.
			name: "errors_join_multiple_grpc_errors",
			err:  errors.Join(grpcInvalidArg, grpcUnavailable), expectRetry: false,
		},
		// custom Unwrap() error
		{
			name: "custom_unwrap_internal",
			err:  &customUnwrapError{msg: "custom", cause: grpcInternal}, expectRetry: true,
		},
		{
			name: "custom_unwrap_invalid_argument",
			err:  &customUnwrapError{msg: "custom", cause: grpcInvalidArg}, expectRetry: false,
		},
		{
			name: "custom_unwrap_nested_in_fmt_errorf",
			err:  fmt.Errorf("outer: %w", &customUnwrapError{msg: "inner", cause: grpcUnavailable}), expectRetry: true,
		},
		// custom Unwrap() []error
		{
			name: "multi_unwrap_internal",
			err:  &multiUnwrapError{msg: "multi", causes: []error{errors.New("other"), grpcInternal}}, expectRetry: true,
		},
		{
			name: "multi_unwrap_all_non_retryable",
			err:  &multiUnwrapError{msg: "multi", causes: []error{errors.New("a"), grpcInvalidArg}}, expectRetry: false,
		},
		// pkg/errors Wrap and Wrapf
		{name: "pkgerrors_wrap_unavailable", err: pkgerrors.Wrap(grpcUnavailable, "wrap"), expectRetry: true},
		{name: "pkgerrors_wrapf_internal", err: pkgerrors.Wrapf(grpcInternal, "wrap %s", "ctx"), expectRetry: true},
		{name: "pkgerrors_wrap_invalid_argument", err: pkgerrors.Wrap(grpcInvalidArg, "wrap"), expectRetry: false},
		{
			name: "pkgerrors_wrap_double_wrapped",
			err:  pkgerrors.Wrap(pkgerrors.Wrap(grpcDeadline, "inner"), "outer"), expectRetry: true,
		},
		// plain errors — no gRPC status
		{name: "plain_errors_new", err: errors.New("plain error"), expectRetry: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := IsTransientError(tt.err)
			assert.Equal(t, tt.expectRetry, result, "transient error classification mismatch")
		})
	}
}
+
// TestIsFailoverErrorWrapped verifies that failover detection (Unavailable
// or DeadlineExceeded) sees through all common wrapping styles, and that a
// status message embedded only as plain text never triggers failover.
func TestIsFailoverErrorWrapped(t *testing.T) {
	grpcUnavailable := status.Error(codes.Unavailable, "unavailable")
	grpcDeadline := status.Error(codes.DeadlineExceeded, "timeout")
	grpcInternal := status.Error(codes.Internal, "internal")

	tests := []struct {
		err            error
		name           string
		expectFailover bool
	}{
		// direct gRPC status (baseline)
		{name: "direct_unavailable", err: grpcUnavailable, expectFailover: true},
		{name: "direct_deadline_exceeded", err: grpcDeadline, expectFailover: true},
		{name: "direct_internal", err: grpcInternal, expectFailover: false},
		// fmt.Errorf %w
		{name: "fmt_errorf_unavailable", err: fmt.Errorf("wrap: %w", grpcUnavailable), expectFailover: true},
		{name: "fmt_errorf_deadline_exceeded", err: fmt.Errorf("wrap: %w", grpcDeadline), expectFailover: true},
		{name: "fmt_errorf_internal", err: fmt.Errorf("wrap: %w", grpcInternal), expectFailover: false},
		{
			name: "fmt_errorf_double_wrapped",
			err:  fmt.Errorf("outer: %w", fmt.Errorf("inner: %w", grpcUnavailable)), expectFailover: true,
		},
		// errors.New with message only — no gRPC status in chain
		{name: "errors_new_with_grpc_message_text", err: errors.New(grpcUnavailable.Error()), expectFailover: false},
		// errors.Join
		{name: "errors_join_unavailable", err: errors.Join(errors.New("extra"), grpcUnavailable), expectFailover: true},
		{name: "errors_join_internal", err: errors.Join(errors.New("extra"), grpcInternal), expectFailover: false},
		// custom Unwrap() error
		{
			name: "custom_unwrap_deadline_exceeded",
			err:  &customUnwrapError{msg: "custom", cause: grpcDeadline}, expectFailover: true,
		},
		{
			name: "custom_unwrap_internal",
			err:  &customUnwrapError{msg: "custom", cause: grpcInternal}, expectFailover: false,
		},
		// custom Unwrap() []error
		{
			name: "multi_unwrap_unavailable",
			err:  &multiUnwrapError{msg: "multi", causes: []error{errors.New("other"), grpcUnavailable}}, expectFailover: true,
		},
		// pkg/errors Wrap and Wrapf
		{name: "pkgerrors_wrap_unavailable", err: pkgerrors.Wrap(grpcUnavailable, "wrap"), expectFailover: true},
		{name: "pkgerrors_wrapf_deadline_exceeded", err: pkgerrors.Wrapf(grpcDeadline, "wrap %s", "ctx"), expectFailover: true},
		{name: "pkgerrors_wrap_internal", err: pkgerrors.Wrap(grpcInternal, "wrap"), expectFailover: false},
		// plain errors — no gRPC status
		{name: "plain_errors_new", err: errors.New("plain error"), expectFailover: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := IsFailoverError(tt.err)
			assert.Equal(t, tt.expectFailover, result, "failover error classification mismatch")
		})
	}
}
+
// TestIsInternalErrorWrapped verifies that internal-error detection (code
// Internal only) sees through all common wrapping styles, and that a status
// message embedded only as plain text is never matched.
func TestIsInternalErrorWrapped(t *testing.T) {
	grpcInternal := status.Error(codes.Internal, "internal")
	grpcUnavailable := status.Error(codes.Unavailable, "unavailable")

	tests := []struct {
		err            error
		name           string
		expectInternal bool
	}{
		// direct gRPC status (baseline)
		{name: "direct_internal", err: grpcInternal, expectInternal: true},
		{name: "direct_unavailable", err: grpcUnavailable, expectInternal: false},
		// fmt.Errorf %w
		{name: "fmt_errorf_internal", err: fmt.Errorf("wrap: %w", grpcInternal), expectInternal: true},
		{name: "fmt_errorf_unavailable", err: fmt.Errorf("wrap: %w", grpcUnavailable), expectInternal: false},
		{
			name: "fmt_errorf_double_wrapped",
			err:  fmt.Errorf("outer: %w", fmt.Errorf("inner: %w", grpcInternal)), expectInternal: true,
		},
		// errors.New with message only — no gRPC status in chain
		{name: "errors_new_with_grpc_message_text", err: errors.New(grpcInternal.Error()), expectInternal: false},
		// errors.Join
		{name: "errors_join_internal", err: errors.Join(errors.New("extra"), grpcInternal), expectInternal: true},
		{name: "errors_join_unavailable", err: errors.Join(errors.New("extra"), grpcUnavailable), expectInternal: false},
		// custom Unwrap() error
		{
			name: "custom_unwrap_internal",
			err:  &customUnwrapError{msg: "custom", cause: grpcInternal}, expectInternal: true,
		},
		{
			name: "custom_unwrap_unavailable",
			err:  &customUnwrapError{msg: "custom", cause: grpcUnavailable}, expectInternal: false,
		},
		// custom Unwrap() []error
		{
			name: "multi_unwrap_internal",
			err:  &multiUnwrapError{msg: "multi", causes: []error{errors.New("other"), grpcInternal}}, expectInternal: true,
		},
		{
			name: "multi_unwrap_all_non_internal",
			err:  &multiUnwrapError{msg: "multi", causes: []error{errors.New("a"), grpcUnavailable}}, expectInternal: false,
		},
		// pkg/errors Wrap and Wrapf
		{name: "pkgerrors_wrap_internal", err: pkgerrors.Wrap(grpcInternal, "wrap"), expectInternal: true},
		{name: "pkgerrors_wrapf_internal", err: pkgerrors.Wrapf(grpcInternal, "wrap %s", "ctx"), expectInternal: true},
		{name: "pkgerrors_wrap_unavailable", err: pkgerrors.Wrap(grpcUnavailable, "wrap"), expectInternal: false},
		// plain errors — no gRPC status
		{name: "plain_errors_new", err: errors.New("plain error"), expectInternal: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := IsInternalError(tt.err)
			assert.Equal(t, tt.expectInternal, result, "internal error classification mismatch")
		})
	}
}
+
+func TestBackoffDistribution(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping distribution test in short mode")
+       }
+
+       const samples = 100
+       baseBackoff := 100 * time.Millisecond
+       maxBackoff := 10 * time.Second
+       jitter := 0.2
+
+       var totalBackoff time.Duration
+       for i := 0; i < samples; i++ {
+               b := JitteredBackoff(baseBackoff, maxBackoff, 0, jitter)
+               totalBackoff += b
+               assert.GreaterOrEqual(t, b, 80*time.Millisecond, "backoff 
should be >= 80ms with 0.2 jitter")
+               assert.LessOrEqual(t, b, 120*time.Millisecond, "backoff should 
be <= 120ms with 0.2 jitter")
+       }
+
+       avgBackoff := totalBackoff / time.Duration(samples)
+       assert.InDelta(t, float64(100*time.Millisecond), float64(avgBackoff), 
float64(20*time.Millisecond),
+               "average backoff should be near base duration")
+}

Reply via email to