This is an automated email from the ASF dual-hosted git repository.
hanahmily pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
The following commit(s) were added to refs/heads/main by this push:
new 74334e027 Refactor part of the pub.Client as common library (#972)
74334e027 is described below
commit 74334e027b7fad414c9773f803163b5c97a9655b
Author: mrproliu <[email protected]>
AuthorDate: Fri Feb 13 07:40:40 2026 +0800
Refactor part of the pub.Client as common library (#972)
* Refactor part of the pub.Client as common library
---
banyand/internal/storage/segment.go | 8 +-
banyand/internal/storage/tsdb.go | 2 +-
banyand/queue/pub/batch.go | 38 +-
banyand/queue/pub/client.go | 317 +------------
banyand/queue/pub/client_test.go | 10 +-
banyand/queue/pub/pub.go | 246 ++++------
banyand/queue/pub/pub_suite_test.go | 26 +-
banyand/queue/pub/pub_test.go | 26 +-
banyand/queue/pub/pub_tls_test.go | 4 +-
banyand/queue/pub/retry_test.go | 202 --------
.../queue/pub => pkg/grpchelper}/circuitbreaker.go | 46 +-
.../pub => pkg/grpchelper}/circuitbreaker_test.go | 459 +++++++++---------
pkg/grpchelper/connmanager.go | 526 +++++++++++++++++++++
pkg/grpchelper/connmanager_test.go | 115 +++++
pkg/grpchelper/helpers_test.go | 67 +++
{banyand/queue/pub => pkg/grpchelper}/retry.go | 61 ++-
pkg/grpchelper/retry_test.go | 500 ++++++++++++++++++++
17 files changed, 1655 insertions(+), 998 deletions(-)
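Summary of the change: the connection-management machinery that previously lived inside the pub type (the registered/active/evictable client maps, health checking and eviction, the circuit breaker, and the jittered-backoff retry helpers) is extracted into a reusable, generic ConnManager in pkg/grpchelper, and pub now delegates to it. As orientation before the diff, the sketch below mirrors the wiring visible in this patch; pp, node, and maxReceiveMessageSize are the surrounding pub identifiers, and the ConnManager semantics themselves live in pkg/grpchelper/connmanager.go, which is added by this commit but not reproduced here.

    // Sketch of the post-refactor wiring (mirrors PreRun/initConnMgr in this diff).
    pp.connMgr = grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
        Handler:        pp,                    // supplies dial options, builds *client values, receives callbacks
        Logger:         pp.log,
        RetryPolicy:    pp.retryPolicy,        // gRPC service-config JSON applied to each connection
        MaxRecvMsgSize: maxReceiveMessageSize,
    })

    // Node membership, health checks and eviction are now delegated to the manager:
    pp.connMgr.OnAddOrUpdate(node) // register and health-check a node (was inlined in pub before)
    pp.connMgr.OnDelete(node)      // drop a node
    defer pp.connMgr.GracefulStop()
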
diff --git a/banyand/internal/storage/segment.go b/banyand/internal/storage/segment.go
index 74dff1518..2953ce3b5 100644
--- a/banyand/internal/storage/segment.go
+++ b/banyand/internal/storage/segment.go
@@ -74,7 +74,7 @@ type segment[T TSTable, O any] struct {
suffix string
location string
lastAccessed atomic.Int64
- mu sync.Mutex
+ mu sync.RWMutex
refCount int32
mustBeDeleted uint32
id segmentID
@@ -194,6 +194,12 @@ func (s *segment[T, O]) initialize(ctx context.Context) error {
return nil
}
+func (s *segment[T, O]) collectMetrics() {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ s.index.store.CollectMetrics(s.index.p.SegLabelValues()...)
+}
+
func (s *segment[T, O]) DecRef() {
shouldCleanup := false
diff --git a/banyand/internal/storage/tsdb.go b/banyand/internal/storage/tsdb.go
index f341b5aec..a0ad00e28 100644
--- a/banyand/internal/storage/tsdb.go
+++ b/banyand/internal/storage/tsdb.go
@@ -365,7 +365,7 @@ func (d *database[T, O]) collect() {
for _, t := range tables {
t.Collect(d.segmentController.metrics)
}
- s.index.store.CollectMetrics(s.index.p.SegLabelValues()...)
+ s.collectMetrics()
s.DecRef()
refCount += atomic.LoadInt32(&s.refCount)
}
diff --git a/banyand/queue/pub/batch.go b/banyand/queue/pub/batch.go
index 127bec113..b755d300a 100644
--- a/banyand/queue/pub/batch.go
+++ b/banyand/queue/pub/batch.go
@@ -31,9 +31,17 @@ import (
modelv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
"github.com/apache/skywalking-banyandb/banyand/queue"
"github.com/apache/skywalking-banyandb/pkg/bus"
+ "github.com/apache/skywalking-banyandb/pkg/grpchelper"
"github.com/apache/skywalking-banyandb/pkg/logger"
)
+const (
+ defaultMaxRetries = 3
+ defaultPerRequestTimeout = 2 * time.Second
+ defaultBackoffBase = 500 * time.Millisecond
+ defaultBackoffMax = 30 * time.Second
+)
+
type writeStream struct {
client clusterv1.Service_SendClient
ctxDoneCh <-chan struct{}
@@ -72,7 +80,7 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic bus.Topic, messages
node := m.Node()
// Check circuit breaker before attempting send
- if !bp.pub.isRequestAllowed(node) {
+ if !bp.pub.connMgr.IsRequestAllowed(node) {
err = multierr.Append(err, fmt.Errorf("circuit breaker
open for node %s", node))
continue
}
@@ -95,11 +103,11 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic bus.Topic, messages
if errSend != nil {
err = multierr.Append(err, fmt.Errorf("failed to send message to node %s: %w", node, errSend))
// Record failure for circuit breaker (only for transient/internal errors)
- bp.pub.recordFailure(node, errSend)
+ bp.pub.connMgr.RecordFailure(node, errSend)
return false
}
// Record success for circuit breaker
- bp.pub.recordSuccess(node)
+ bp.pub.connMgr.RecordSuccess(node)
return true
}
return false
@@ -119,14 +127,12 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic bus.Topic, messages
}
continue
}
- var client *client
+ var nodeClient *client
// nolint: contextcheck
if func() bool {
- bp.pub.mu.RLock()
- defer bp.pub.mu.RUnlock()
- var ok bool
- client, ok = bp.pub.active[node]
- if !ok {
+ var clientOK bool
+ nodeClient, clientOK = bp.pub.connMgr.GetClient(node)
+ if !clientOK {
err = multierr.Append(err, fmt.Errorf("failed
to get client for node %s", node))
return true
}
@@ -147,7 +153,7 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic bus.Topic, messages
streamCtx, cancel := context.WithTimeout(ctx, bp.timeout)
// this assignment is for getting around the go vet lint
deferFn := cancel
- stream, errCreateStream := client.client.Send(streamCtx)
+ stream, errCreateStream := nodeClient.client.Send(streamCtx)
if errCreateStream != nil {
err = multierr.Append(err, fmt.Errorf("failed to get
stream for node %s: %w", node, errCreateStream))
continue
@@ -171,9 +177,9 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic bus.Topic, messages
}
resp, errRecv := s.Recv()
if errRecv != nil {
- if isFailoverError(errRecv) {
+ if grpchelper.IsFailoverError(errRecv) {
// Record circuit breaker failure before creating failover event
- bp.pub.recordFailure(curNode, errRecv)
+ bp.pub.connMgr.RecordFailure(curNode, errRecv)
bc <- batchEvent{n: curNode, e: common.NewErrorWithStatus(modelv1.Status_STATUS_INTERNAL_ERROR, errRecv.Error())}
}
return
@@ -187,7 +193,7 @@ func (bp *batchPublisher) Publish(ctx context.Context, topic bus.Topic, messages
if isFailoverStatus(resp.Status) {
ce := common.NewErrorWithStatus(resp.Status, resp.Error)
// Record circuit breaker failure before creating failover event
- bp.pub.recordFailure(curNode, ce)
+ bp.pub.connMgr.RecordFailure(curNode, ce)
bc <- batchEvent{n: curNode, e: ce}
}
}(stream, deferFn, bp.f.events[len(bp.f.events)-1], nodeName)
@@ -211,7 +217,7 @@ func (bp *batchPublisher) Close() (cee map[string]*common.Error, err error) {
defer bp.pub.closer.Done()
for n, e := range batchEvents {
// Record circuit breaker failure before failover
- bp.pub.recordFailure(n, e.e)
+ bp.pub.connMgr.RecordFailure(n, e.e)
if bp.topic == nil {
bp.pub.failover(n, e.e, data.TopicCommon)
continue
@@ -311,7 +317,7 @@ func (bp *batchPublisher) retrySend(ctx context.Context, stream clusterv1.Servic
lastErr = sendErr
// Check if error is retryable
- if !isTransientError(sendErr) {
+ if !grpchelper.IsTransientError(sendErr) {
// Non-transient error, don't retry
return sendErr
}
@@ -322,7 +328,7 @@ func (bp *batchPublisher) retrySend(ctx context.Context, stream clusterv1.Servic
}
// Calculate backoff with jitter
- backoff := jitteredBackoff(defaultBackoffBase, defaultBackoffMax, attempt, defaultJitterFactor)
+ backoff := grpchelper.JitteredBackoff(defaultBackoffBase, defaultBackoffMax, attempt, grpchelper.DefaultJitterFactor)
// Sleep with backoff, but respect context cancellation
select {
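
With this change batch.go leans on the shared helpers for its retry loop: grpchelper.IsTransientError decides whether a failed send is worth retrying, and grpchelper.JitteredBackoff spaces the attempts. A condensed sketch of that shape, using the defaultMaxRetries, defaultBackoffBase, and defaultBackoffMax constants defined above in batch.go; send is a hypothetical stand-in for the stream.Send call that retrySend actually wraps.

    // Retry loop shape after the refactor; context, time, and grpchelper
    // imports are the same ones batch.go already uses.
    func retryWithBackoff(ctx context.Context, send func() error) error {
        var lastErr error
        for attempt := 0; attempt < defaultMaxRetries; attempt++ {
            lastErr = send()
            if lastErr == nil {
                return nil
            }
            if !grpchelper.IsTransientError(lastErr) {
                return lastErr // non-transient errors are not retried
            }
            backoff := grpchelper.JitteredBackoff(defaultBackoffBase, defaultBackoffMax, attempt, grpchelper.DefaultJitterFactor)
            select {
            case <-time.After(backoff): // wait with jitter before the next attempt
            case <-ctx.Done():
                return ctx.Err()
            }
        }
        return lastErr
    }
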
diff --git a/banyand/queue/pub/client.go b/banyand/queue/pub/client.go
index 488b9f98d..ad4ab9567 100644
--- a/banyand/queue/pub/client.go
+++ b/banyand/queue/pub/client.go
@@ -19,11 +19,9 @@ package pub
import (
"context"
- "fmt"
"time"
"google.golang.org/grpc"
- "google.golang.org/grpc/health/grpc_health_v1"
"github.com/apache/skywalking-banyandb/api/common"
clusterv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/cluster/v1"
@@ -32,7 +30,6 @@ import (
"github.com/apache/skywalking-banyandb/banyand/metadata/schema"
"github.com/apache/skywalking-banyandb/pkg/bus"
"github.com/apache/skywalking-banyandb/pkg/grpchelper"
- "github.com/apache/skywalking-banyandb/pkg/logger"
)
const (
@@ -118,6 +115,11 @@ type client struct {
md schema.Metadata
}
+// Close implements grpchelper.Client.
+func (*client) Close() error {
+ return nil
+}
+
func (p *pub) OnAddOrUpdate(md schema.Metadata) {
if md.Kind != schema.KindNode {
return
@@ -142,78 +144,7 @@ func (p *pub) OnAddOrUpdate(md schema.Metadata) {
if !okRole {
return
}
-
- address := node.GrpcAddress
- if address == "" {
- p.log.Warn().Stringer("node", node).Msg("grpc address is empty")
- return
- }
- name := node.Metadata.GetName()
- if name == "" {
- p.log.Warn().Stringer("node", node).Msg("node name is empty")
- return
- }
- p.mu.Lock()
- defer p.mu.Unlock()
-
- p.registerNode(node)
-
- if _, ok := p.active[name]; ok {
- return
- }
- if _, ok := p.evictable[name]; ok {
- return
- }
- credOpts, err := p.getClientTransportCredentials()
- if err != nil {
- p.log.Error().Err(err).Msg("failed to load client TLS
credentials")
- return
- }
- conn, err := grpc.NewClient(address, append(credOpts,
- grpc.WithDefaultServiceConfig(p.retryPolicy),
-
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxReceiveMessageSize)))...)
- if err != nil {
- p.log.Error().Err(err).Msg("failed to connect to grpc server")
- return
- }
-
- if !p.checkClientHealthAndReconnect(conn, md) {
- p.log.Info().Str("status", p.dump()).Stringer("node",
node).Msg("node is unhealthy in the register flow, move it to evict queue")
- return
- }
-
- c := clusterv1.NewServiceClient(conn)
- p.active[name] = &client{conn: conn, client: c, md: md}
- p.addClient(md)
- // Initialize or reset circuit breaker state to closed
- p.recordSuccess(name)
- p.log.Info().Str("status", p.dump()).Stringer("node", node).Msg("new
node is healthy, add it to active queue")
-}
-
-func (p *pub) registerNode(node *databasev1.Node) {
- name := node.Metadata.GetName()
- defer func() {
- p.registered[name] = node
- }()
-
- n, ok := p.registered[name]
- if !ok {
- return
- }
- if n.GrpcAddress == node.GrpcAddress {
- return
- }
- if en, ok := p.evictable[name]; ok {
- close(en.c)
- delete(p.evictable, name)
- p.log.Info().Str("node", name).Str("status",
p.dump()).Msg("node is removed from evict queue by the new gRPC address updated
event")
- }
- if client, ok := p.active[name]; ok {
- _ = client.conn.Close()
- delete(p.active, name)
- p.deleteClient(client.md)
- p.log.Info().Str("status", p.dump()).Str("node",
name).Msg("node is removed from active queue by the new gRPC address updated
event")
- }
+ p.connMgr.OnAddOrUpdate(node)
}
func (p *pub) OnDelete(md schema.Metadata) {
@@ -225,163 +156,14 @@ func (p *pub) OnDelete(md schema.Metadata) {
p.log.Warn().Msg("failed to cast node spec")
return
}
- name := node.Metadata.GetName()
- if name == "" {
- p.log.Warn().Stringer("node", node).Msg("node name is empty")
- return
- }
- p.mu.Lock()
- defer p.mu.Unlock()
- delete(p.registered, name)
- if en, ok := p.evictable[name]; ok {
- close(en.c)
- delete(p.evictable, name)
- p.log.Info().Str("status", p.dump()).Stringer("node",
node).Msg("node is removed from evict queue by delete event")
- return
- }
-
- if client, ok := p.active[name]; ok {
- if p.removeNodeIfUnhealthy(md, node, client) {
- p.log.Info().Str("status", p.dump()).Stringer("node",
node).Msg("remove node from active queue by delete event")
- return
- }
- if !p.closer.AddRunning() {
- return
- }
- go func() {
- defer p.closer.Done()
- var elapsed time.Duration
- attempt := 0
- for {
- backoff := jitteredBackoff(initBackoff,
maxBackoff, attempt, defaultJitterFactor)
- select {
- case <-time.After(backoff):
- if func() bool {
- elapsed += backoff
- p.mu.Lock()
- defer p.mu.Unlock()
- if _, ok := p.registered[name];
ok {
- // The client has been
added back to registered clients map, just return
- return true
- }
- if p.removeNodeIfUnhealthy(md,
node, client) {
-
p.log.Info().Str("status", p.dump()).Stringer("node", node).Dur("after",
elapsed).Msg("remove node from active queue by delete event")
- return true
- }
- return false
- }() {
- return
- }
- case <-p.closer.CloseNotify():
- return
- }
- attempt++
- }
- }()
- }
-}
-
-func (p *pub) removeNodeIfUnhealthy(md schema.Metadata, node *databasev1.Node,
client *client) bool {
- if p.healthCheck(node.String(), client.conn) {
- return false
- }
- _ = client.conn.Close()
- name := node.Metadata.GetName()
- delete(p.active, name)
- p.deleteClient(md)
- return true
-}
-
-func (p *pub) checkClientHealthAndReconnect(conn *grpc.ClientConn, md
schema.Metadata) bool {
- node, ok := md.Spec.(*databasev1.Node)
- if !ok {
- logger.Panicf("failed to cast node spec")
- return false
- }
- if p.healthCheck(node.String(), conn) {
- return true
- }
- _ = conn.Close()
- if !p.closer.AddRunning() {
- return false
- }
- name := node.Metadata.Name
- p.evictable[name] = evictNode{n: node, c: make(chan struct{})}
- p.deleteClient(md)
- go func(p *pub, name string, en evictNode, md schema.Metadata) {
- defer p.closer.Done()
- attempt := 0
- for {
- backoff := jitteredBackoff(initBackoff, maxBackoff,
attempt, defaultJitterFactor)
- select {
- case <-time.After(backoff):
- credOpts, errEvict :=
p.getClientTransportCredentials()
- if errEvict != nil {
- p.log.Error().Err(errEvict).Msg("failed
to load client TLS credentials (evict)")
- return
- }
- connEvict, errEvict :=
grpc.NewClient(node.GrpcAddress, append(credOpts,
-
grpc.WithDefaultServiceConfig(p.retryPolicy),
-
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxReceiveMessageSize)))...)
- if errEvict == nil &&
p.healthCheck(en.n.String(), connEvict) {
- func() {
- p.mu.Lock()
- defer p.mu.Unlock()
- if _, ok := p.evictable[name];
!ok {
- // The client has been
removed from evict clients map, just return
- return
- }
- c :=
clusterv1.NewServiceClient(connEvict)
- p.active[name] = &client{conn:
connEvict, client: c, md: md}
- p.addClient(md)
- delete(p.evictable, name)
- // Reset circuit breaker state
to closed
- p.recordSuccess(name)
- p.log.Info().Str("status",
p.dump()).Stringer("node", en.n).Msg("node is healthy, move it back to active
queue")
- }()
- return
- }
- _ = connEvict.Close()
- if _, ok := p.registered[name]; !ok {
- return
- }
- p.log.Error().Err(errEvict).Msgf("failed to
re-connect to grpc server %s after waiting for %s", node.GrpcAddress, backoff)
- case <-en.c:
- return
- case <-p.closer.CloseNotify():
- return
- }
- attempt++
- }
- }(p, name, p.evictable[name], md)
- return false
-}
-
-func (p *pub) healthCheck(node string, conn *grpc.ClientConn) bool {
- var resp *grpc_health_v1.HealthCheckResponse
- if err := grpchelper.Request(context.Background(), rpcTimeout,
func(rpcCtx context.Context) (err error) {
- resp, err = grpc_health_v1.NewHealthClient(conn).Check(rpcCtx,
- &grpc_health_v1.HealthCheckRequest{
- Service: "",
- })
- return err
- }); err != nil {
- if e := p.log.Debug(); e.Enabled() {
- e.Err(err).Str("node", node).Msg("service unhealthy")
- }
- return false
- }
- if resp.GetStatus() == grpc_health_v1.HealthCheckResponse_SERVING {
- return true
- }
- return false
+ p.connMgr.OnDelete(node)
}
func (p *pub) checkServiceHealth(svc string, conn *grpc.ClientConn) *common.Error {
- client := clusterv1.NewServiceClient(conn)
+ serviceClient := clusterv1.NewServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), rpcTimeout)
defer cancel()
- resp, err := client.HealthCheck(ctx, &clusterv1.HealthCheckRequest{
+ resp, err := serviceClient.HealthCheck(ctx, &clusterv1.HealthCheckRequest{
ServiceName: svc,
})
if err != nil {
@@ -394,27 +176,11 @@ func (p *pub) checkServiceHealth(svc string, conn *grpc.ClientConn) *common.Erro
}
func (p *pub) failover(node string, ce *common.Error, topic bus.Topic) {
- p.mu.Lock()
- defer p.mu.Unlock()
if ce.Status() != modelv1.Status_STATUS_INTERNAL_ERROR {
_, _ = p.checkWritable(node, topic)
return
}
- if en, evictable := p.evictable[node]; evictable {
- if _, registered := p.registered[node]; !registered {
- close(en.c)
- delete(p.evictable, node)
- p.log.Info().Str("node", node).Str("status",
p.dump()).Msg("node is removed from evict queue by wire event")
- }
- return
- }
-
- if client, ok := p.active[node]; ok &&
!p.checkClientHealthAndReconnect(client.conn, client.md) {
- _ = client.conn.Close()
- delete(p.active, node)
- p.deleteClient(client.md)
- p.log.Info().Str("status", p.dump()).Str("node",
node).Msg("node is unhealthy in the failover flow, move it to evict queue")
- }
+ p.connMgr.FailoverNode(node)
}
func (p *pub) checkWritable(n string, topic bus.Topic) (bool, *common.Error) {
@@ -422,16 +188,16 @@ func (p *pub) checkWritable(n string, topic bus.Topic)
(bool, *common.Error) {
if !ok {
return false, nil
}
- node, ok := p.active[n]
+ c, ok := p.connMgr.GetClient(n)
if !ok {
return false, nil
}
topicStr := topic.String()
- err := p.checkServiceHealth(topicStr, node.conn)
+ err := p.checkServiceHealth(topicStr, c.conn)
if err == nil {
return true, nil
}
- h.OnDelete(node.md)
+ h.OnDelete(c.md)
if !p.closer.AddRunning() {
return false, err
}
@@ -461,28 +227,18 @@ func (p *pub) checkWritable(n string, topic bus.Topic)
(bool, *common.Error) {
}()
attempt := 0
for {
- backoff := jitteredBackoff(initBackoff, maxBackoff, attempt, defaultJitterFactor)
+ backoff := grpchelper.JitteredBackoff(grpchelper.InitBackoff, grpchelper.MaxBackoff, attempt, grpchelper.DefaultJitterFactor)
select {
case <-time.After(backoff):
- p.mu.RLock()
- nodeCur, okCur := p.active[nodeName]
- p.mu.RUnlock()
+ nodeCur, okCur := p.connMgr.GetClient(nodeName)
if !okCur {
return
}
errInternal := p.checkServiceHealth(t,
nodeCur.conn)
if errInternal == nil {
- func() {
- p.mu.Lock()
- defer p.mu.Unlock()
- nodeCur, okCur :=
p.active[nodeName]
- if !okCur {
- return
- }
- // Record success for circuit
breaker
- p.recordSuccess(nodeName)
- h.OnAddOrUpdate(nodeCur.md)
- }()
+ // Record success for circuit breaker
+ p.connMgr.RecordSuccess(nodeName)
+ h.OnAddOrUpdate(nodeCur.md)
return
}
p.log.Warn().Str("topic",
t).Err(errInternal).Str("node", nodeName).Dur("backoff", backoff).Msg("data
node can not ingest data")
@@ -494,40 +250,3 @@ func (p *pub) checkWritable(n string, topic bus.Topic)
(bool, *common.Error) {
}(n, topicStr)
return false, err
}
-
-func (p *pub) deleteClient(md schema.Metadata) {
- if len(p.handlers) > 0 {
- for _, h := range p.handlers {
- h.OnDelete(md)
- }
- }
-}
-
-func (p *pub) addClient(md schema.Metadata) {
- if len(p.handlers) > 0 {
- for _, h := range p.handlers {
- h.OnAddOrUpdate(md)
- }
- }
-}
-
-func (p *pub) dump() string {
- keysRegistered := make([]string, 0, len(p.registered))
- for k := range p.registered {
- keysRegistered = append(keysRegistered, k)
- }
- keysActive := make([]string, 0, len(p.active))
- for k := range p.active {
- keysActive = append(keysActive, k)
- }
- keysEvictable := make([]string, 0, len(p.evictable))
- for k := range p.evictable {
- keysEvictable = append(keysEvictable, k)
- }
- return fmt.Sprintf("registered: %v, active :%v, evictable :%v",
keysRegistered, keysActive, keysEvictable)
-}
-
-type evictNode struct {
- n *databasev1.Node
- c chan struct{}
-}
diff --git a/banyand/queue/pub/client_test.go b/banyand/queue/pub/client_test.go
index 0b797cb18..84807010c 100644
--- a/banyand/queue/pub/client_test.go
+++ b/banyand/queue/pub/client_test.go
@@ -81,9 +81,7 @@ var _ = ginkgo.Describe("publish clients
register/unregister", func() {
closeFn := setup(addr1, codes.OK, 200*time.Millisecond)
defer closeFn()
gomega.Eventually(func() int {
- p.mu.RLock()
- defer p.mu.RUnlock()
- return len(p.active)
+ return p.connMgr.ActiveCount()
}, flags.EventuallyTimeout).Should(gomega.Equal(1))
verifyClients(p, 1, 0, 1, 1)
})
@@ -171,10 +169,8 @@ func verifyClients(p *pub, active, evict, onAdd, onDelete
int) {
}
func verifyClientsWithGomega(g gomega.Gomega, p *pub, topic bus.Topic, active, evict, onAdd, onDelete int) {
- p.mu.RLock()
- defer p.mu.RUnlock()
- g.Expect(len(p.active)).Should(gomega.Equal(active))
- g.Expect(len(p.evictable)).Should(gomega.Equal(evict))
+ g.Expect(p.connMgr.ActiveCount()).Should(gomega.Equal(active))
+ g.Expect(p.connMgr.EvictableCount()).Should(gomega.Equal(evict))
for t, eh := range p.handlers {
if topic != data.TopicCommon && t != topic {
continue
diff --git a/banyand/queue/pub/pub.go b/banyand/queue/pub/pub.go
index 2ce624ab5..05043ac0b 100644
--- a/banyand/queue/pub/pub.go
+++ b/banyand/queue/pub/pub.go
@@ -20,18 +20,17 @@ package pub
import (
"context"
+ "errors"
"fmt"
"io"
"strings"
"sync"
"time"
- "github.com/pkg/errors"
+ pkgerrors "github.com/pkg/errors"
"go.uber.org/multierr"
"google.golang.org/grpc"
- "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
- "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/apache/skywalking-banyandb/api/common"
@@ -63,6 +62,8 @@ var (
_ run.PreRunner = (*pub)(nil)
_ run.Service = (*pub)(nil)
_ run.Config = (*pub)(nil)
+
+ _ grpchelper.ConnectionHandler[*client] = (*pub)(nil)
)
type pub struct {
@@ -70,23 +71,61 @@ type pub struct {
metadata metadata.Repo
handlers map[bus.Topic]schema.EventHandler
log *logger.Logger
- registered map[string]*databasev1.Node
- active map[string]*client
- evictable map[string]evictNode
+ connMgr *grpchelper.ConnManager[*client]
closer *run.Closer
writableProbe map[string]map[string]struct{}
- cbStates map[string]*circuitState
caCertPath string
caCertReloader *pkgtls.Reloader
prefix string
retryPolicy string
allowedRoles []databasev1.Role
- mu sync.RWMutex
- cbMu sync.RWMutex
writableProbeMu sync.Mutex
tlsEnabled bool
}
+// AddressOf implements grpchelper.ConnectionHandler.
+func (p *pub) AddressOf(node *databasev1.Node) string {
+ return node.GrpcAddress
+}
+
+// GetDialOptions implements grpchelper.ConnectionHandler.
+func (p *pub) GetDialOptions() ([]grpc.DialOption, error) {
+ return p.getClientTransportCredentials()
+}
+
+// NewClient implements grpchelper.ConnectionHandler.
+func (p *pub) NewClient(conn *grpc.ClientConn, node *databasev1.Node) (*client, error) {
+ md := schema.Metadata{
+ TypeMeta: schema.TypeMeta{
+ Name: node.Metadata.GetName(),
+ Kind: schema.KindNode,
+ },
+ Spec: node,
+ }
+ return &client{
+ client: clusterv1.NewServiceClient(conn),
+ conn: conn,
+ md: md,
+ }, nil
+}
+
+// OnActive implements grpchelper.ConnectionHandler.
+func (p *pub) OnActive(_ string, c *client) {
+ for _, h := range p.handlers {
+ h.OnAddOrUpdate(c.md)
+ }
+}
+
+// OnInactive implements grpchelper.ConnectionHandler.
+func (p *pub) OnInactive(name string, c *client) {
+ for _, h := range p.handlers {
+ h.OnDelete(c.md)
+ }
+ p.writableProbeMu.Lock()
+ delete(p.writableProbe, name)
+ p.writableProbeMu.Unlock()
+}
+
func (p *pub) FlagSet() *run.FlagSet {
prefixFlag := func(name string) string {
if p.prefix == "" {
@@ -113,23 +152,12 @@ func (p *pub) Register(topic bus.Topic, handler schema.EventHandler) {
}
func (p *pub) GracefulStop() {
- // Stop CA certificate reloader if enabled
if p.caCertReloader != nil {
p.caCertReloader.Stop()
}
-
- p.mu.Lock()
- defer p.mu.Unlock()
- for i := range p.evictable {
- close(p.evictable[i].c)
- }
- p.evictable = nil
p.closer.Done()
p.closer.CloseThenWait()
- for _, c := range p.active {
- _ = c.conn.Close()
- }
- p.active = nil
+ p.connMgr.GracefulStop()
}
// Serve implements run.Service.
@@ -153,7 +181,7 @@ func (p *pub) Serve() run.StopNotify {
select {
case <-certUpdateCh:
p.log.Info().Msg("CA
certificate updated, reconnecting clients")
- p.reconnectAllClients()
+ p.connMgr.ReconnectAll()
case <-stopCh:
return
}
@@ -171,14 +199,7 @@ var bypassMatches = []MatchFunc{bypassMatch}
func bypassMatch(_ map[string]string) bool { return true }
func (p *pub) Broadcast(timeout time.Duration, topic bus.Topic, messages bus.Message) ([]bus.Future, error) {
- var nodes []*databasev1.Node
- p.mu.RLock()
- for k := range p.active {
- if n := p.registered[k]; n != nil {
- nodes = append(nodes, n)
- }
- }
- p.mu.RUnlock()
+ nodes := p.connMgr.ActiveRegisteredNodes()
if len(nodes) == 0 {
return nil, errors.New("no active nodes")
}
@@ -219,7 +240,6 @@ func (p *pub) Broadcast(timeout time.Duration, topic
bus.Topic, messages bus.Mes
if len(names) == 0 {
return nil, fmt.Errorf("no nodes match the selector %v",
messages.NodeSelectors())
}
-
futureCh := make(chan publishResult, len(names))
var wg sync.WaitGroup
for n := range names {
@@ -238,8 +258,8 @@ func (p *pub) Broadcast(timeout time.Duration, topic
bus.Topic, messages bus.Mes
var errs error
for f := range futureCh {
if f.e != nil {
- errs = multierr.Append(errs, errors.Wrapf(f.e, "failed
to publish message to %s", f.n))
- if isFailoverError(f.e) {
+ errs = multierr.Append(errs, pkgerrors.Wrapf(f.e,
"failed to publish message to %s", f.n))
+ if grpchelper.IsFailoverError(f.e) {
if p.closer.AddRunning() {
go func() {
defer p.closer.Done()
@@ -275,37 +295,25 @@ func (p *pub) publish(timeout time.Duration, topic bus.Topic, messages ...bus.Me
return multierr.Append(err, fmt.Errorf("failed to marshal message[%d]: %w", m.ID(), errSend))
}
node := m.Node()
-
- // Check circuit breaker before attempting send
- if !p.isRequestAllowed(node) {
- return multierr.Append(err, fmt.Errorf("circuit breaker
open for node %s", node))
- }
-
- p.mu.RLock()
- client, ok := p.active[node]
- p.mu.RUnlock()
- if !ok {
- return multierr.Append(err, fmt.Errorf("failed to get
client for node %s", node))
- }
- ctx, cancel := context.WithTimeout(context.Background(),
timeout)
- f.cancelFn = append(f.cancelFn, cancel)
- stream, errCreateStream := client.client.Send(ctx)
- if errCreateStream != nil {
- // Record failure for circuit breaker (only for
transient/internal errors)
- p.recordFailure(node, errCreateStream)
- return multierr.Append(err, fmt.Errorf("failed to get
stream for node %s: %w", node, errCreateStream))
- }
- errSend = stream.Send(r)
- if errSend != nil {
- // Record failure for circuit breaker (only for
transient/internal errors)
- p.recordFailure(node, errSend)
- return multierr.Append(err, fmt.Errorf("failed to send
message to node %s: %w", node, errSend))
+ execErr := p.connMgr.Execute(node, func(c *client) error {
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ f.cancelFn = append(f.cancelFn, cancel)
+ stream, errCreateStream := c.client.Send(ctx)
+ if errCreateStream != nil {
+ // Record failure for circuit breaker (only for transient/internal errors)
+ return fmt.Errorf("failed to get stream for node %s: %w", node, errCreateStream)
+ }
+ if sendErr := stream.Send(r); sendErr != nil {
+ return fmt.Errorf("failed to send message to
node %s: %w", node, sendErr)
+ }
+ f.clients = append(f.clients, stream)
+ f.topics = append(f.topics, topic)
+ f.nodes = append(f.nodes, node)
+ return nil
+ })
+ if execErr != nil {
+ err = multierr.Append(err, execErr)
}
- // Record success for circuit breaker
- p.recordSuccess(node)
- f.clients = append(f.clients, stream)
- f.topics = append(f.topics, topic)
- f.nodes = append(f.nodes, node)
return err
}
for _, m := range messages {
@@ -322,35 +330,7 @@ func (p *pub) Publish(_ context.Context, topic bus.Topic,
messages ...bus.Messag
// GetRouteTable implements RouteTableProvider interface.
// Returns a RouteTable with all registered nodes and their health states.
func (p *pub) GetRouteTable() *databasev1.RouteTable {
- p.mu.RLock()
- defer p.mu.RUnlock()
-
- registered := make([]*databasev1.Node, 0, len(p.registered))
- for _, node := range p.registered {
- if node != nil {
- registered = append(registered, node)
- }
- }
-
- active := make([]string, 0, len(p.active))
- for nodeID := range p.active {
- if node := p.registered[nodeID]; node != nil && node.Metadata
!= nil {
- active = append(active, node.Metadata.Name)
- }
- }
-
- evictable := make([]string, 0, len(p.evictable))
- for nodeID := range p.evictable {
- if node := p.registered[nodeID]; node != nil && node.Metadata
!= nil {
- evictable = append(evictable, node.Metadata.Name)
- }
- }
-
- return &databasev1.RouteTable{
- Registered: registered,
- Active: active,
- Evictable: evictable,
- }
+ return p.connMgr.GetRouteTable()
}
// New returns a new queue client targeting the given node roles.
@@ -372,15 +352,11 @@ func New(metadata metadata.Repo, roles
...databasev1.Role) queue.Client {
}
p := &pub{
metadata: metadata,
- active: make(map[string]*client),
- evictable: make(map[string]evictNode),
- registered: make(map[string]*databasev1.Node),
handlers: make(map[bus.Topic]schema.EventHandler),
closer: run.NewCloser(1),
allowedRoles: roles,
prefix: strBuilder.String(),
writableProbe: make(map[string]map[string]struct{}),
- cbStates: make(map[string]*circuitState),
retryPolicy: retryPolicy,
}
return p
@@ -389,7 +365,14 @@ func New(metadata metadata.Repo, roles ...databasev1.Role)
queue.Client {
// NewWithoutMetadata returns a new queue client without metadata, defaulting
to data nodes.
func NewWithoutMetadata() queue.Client {
p := New(nil, databasev1.Role_ROLE_DATA)
- p.(*pub).log = logger.GetLogger("queue-client")
+ pp := p.(*pub)
+ pp.log = logger.GetLogger("queue-client")
+ pp.connMgr = grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
+ Handler: pp,
+ Logger: pp.log,
+ RetryPolicy: pp.retryPolicy,
+ MaxRecvMsgSize: maxReceiveMessageSize,
+ })
return p
}
@@ -404,12 +387,20 @@ func (p *pub) PreRun(context.Context) error {
p.log = logger.GetLogger("server-queue-pub-" + p.prefix)
+ // Initialize connection manager with the pub as the handler
+ p.connMgr = grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
+ Handler: p,
+ Logger: p.log,
+ RetryPolicy: p.retryPolicy,
+ MaxRecvMsgSize: maxReceiveMessageSize,
+ })
+
// Initialize CA certificate reloader if TLS is enabled and CA cert
path is provided
if p.tlsEnabled && p.caCertPath != "" {
var err error
p.caCertReloader, err =
pkgtls.NewClientCertReloader(p.caCertPath, p.log)
if err != nil {
- return errors.Wrapf(err, "failed to initialize CA
certificate reloader for %s", p.prefix)
+ return pkgerrors.Wrapf(err, "failed to initialize CA
certificate reloader for %s", p.prefix)
}
p.log.Info().Str("caCertPath", p.caCertPath).Msg("Initialized
CA certificate reloader")
}
@@ -514,14 +505,6 @@ func (l *future) GetAll() ([]bus.Message, error) {
}
}
-func isFailoverError(err error) bool {
- s, ok := status.FromError(err)
- if !ok {
- return false
- }
- return s.Code() == codes.Unavailable || s.Code() ==
codes.DeadlineExceeded
-}
-
func (p *pub) getClientTransportCredentials() ([]grpc.DialOption, error) {
if !p.tlsEnabled {
return grpchelper.SecureOptions(nil, false, false, "")
@@ -547,39 +530,6 @@ func (p *pub) getClientTransportCredentials()
([]grpc.DialOption, error) {
return opts, nil
}
-// reconnectAllClients reconnects all active clients when CA certificate is
updated.
-func (p *pub) reconnectAllClients() {
- // Collect nodes and close connections
- p.mu.Lock()
- nodesToReconnect := make([]schema.Metadata, 0, len(p.registered))
- for name, node := range p.registered {
- // Handle evictable nodes: close channel and remove from
evictable
- if en, ok := p.evictable[name]; ok {
- close(en.c)
- delete(p.evictable, name)
- }
- // Handle active nodes: close connection and remove from active
- if client, ok := p.active[name]; ok {
- _ = client.conn.Close()
- delete(p.active, name)
- p.deleteClient(client.md)
- }
- md := schema.Metadata{
- TypeMeta: schema.TypeMeta{
- Kind: schema.KindNode,
- },
- Spec: node,
- }
- nodesToReconnect = append(nodesToReconnect, md)
- }
- p.mu.Unlock()
-
- // Reconnect with new credentials
- for _, md := range nodesToReconnect {
- p.OnAddOrUpdate(md)
- }
-}
-
// NewChunkedSyncClient implements queue.Client.
func (p *pub) NewChunkedSyncClient(node string, chunkSize uint32)
(queue.ChunkedSyncClient, error) {
return p.NewChunkedSyncClientWithConfig(node, &ChunkedSyncClientConfig{
@@ -592,9 +542,7 @@ func (p *pub) NewChunkedSyncClient(node string, chunkSize
uint32) (queue.Chunked
// NewChunkedSyncClientWithConfig creates a chunked sync client with advanced
configuration.
func (p *pub) NewChunkedSyncClientWithConfig(node string, config
*ChunkedSyncClientConfig) (queue.ChunkedSyncClient, error) {
- p.mu.RLock()
- client, ok := p.active[node]
- p.mu.RUnlock()
+ c, ok := p.connMgr.GetClient(node)
if !ok {
return nil, fmt.Errorf("no active client for node %s", node)
}
@@ -602,10 +550,9 @@ func (p *pub) NewChunkedSyncClientWithConfig(node string,
config *ChunkedSyncCli
if config.ChunkSize == 0 {
config.ChunkSize = defaultChunkSize
}
-
return &chunkedSyncClient{
- client: client.client,
- conn: client.conn,
+ client: c.client,
+ conn: c.conn,
node: node,
log: p.log,
chunkSize: config.ChunkSize,
@@ -615,12 +562,5 @@ func (p *pub) NewChunkedSyncClientWithConfig(node string,
config *ChunkedSyncCli
// HealthyNodes returns a list of node names that are currently healthy and
connected.
func (p *pub) HealthyNodes() []string {
- p.mu.RLock()
- defer p.mu.RUnlock()
-
- nodes := make([]string, 0, len(p.active))
- for name := range p.active {
- nodes = append(nodes, name)
- }
- return nodes
+ return p.connMgr.ActiveNames()
}
diff --git a/banyand/queue/pub/pub_suite_test.go b/banyand/queue/pub/pub_suite_test.go
index fab13eb4d..91479eb20 100644
--- a/banyand/queue/pub/pub_suite_test.go
+++ b/banyand/queue/pub/pub_suite_test.go
@@ -42,6 +42,7 @@ import (
modelv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
"github.com/apache/skywalking-banyandb/banyand/metadata/schema"
"github.com/apache/skywalking-banyandb/pkg/bus"
+ "github.com/apache/skywalking-banyandb/pkg/grpchelper"
"github.com/apache/skywalking-banyandb/pkg/logger"
"github.com/apache/skywalking-banyandb/pkg/test"
"github.com/apache/skywalking-banyandb/pkg/test/flags"
@@ -214,18 +215,30 @@ func (m *mockHandler) OnDelete(_ schema.Metadata) {
m.deleteCount++
}
+func initConnMgr(pp *pub) {
+ pp.connMgr = grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*client]{
+ Handler: pp,
+ Logger: pp.log,
+ RetryPolicy: pp.retryPolicy,
+ MaxRecvMsgSize: maxReceiveMessageSize,
+ })
+}
+
func newPub(roles ...databasev1.Role) *pub {
p := New(nil, roles...)
- p.(*pub).log = logger.GetLogger("queue-client")
+ pp := p.(*pub)
+ pp.log = logger.GetLogger("queue-client")
+ initConnMgr(pp)
p.Register(data.TopicStreamWrite, &mockHandler{})
p.Register(data.TopicMeasureWrite, &mockHandler{})
- return p.(*pub)
+ return pp
}
// newPubWithNoRetry creates a pub with a retry policy that doesn't retry
Unavailable errors.
func newPubWithNoRetry(roles ...databasev1.Role) *pub {
p := New(nil, roles...)
- p.(*pub).log = logger.GetLogger("queue-client")
+ pp := p.(*pub)
+ pp.log = logger.GetLogger("queue-client")
p.Register(data.TopicStreamWrite, &mockHandler{})
p.Register(data.TopicMeasureWrite, &mockHandler{})
@@ -244,10 +257,9 @@ func newPubWithNoRetry(roles ...databasev1.Role) *pub {
}
}
]}`
-
- // Store the original retry policy and replace it
- p.(*pub).retryPolicy = noRetryPolicy
- return p.(*pub)
+ pp.retryPolicy = noRetryPolicy
+ initConnMgr(pp)
+ return pp
}
func getDataNode(name string, address string) schema.Metadata {
diff --git a/banyand/queue/pub/pub_test.go b/banyand/queue/pub/pub_test.go
index d6c8237df..680f3eee6 100644
--- a/banyand/queue/pub/pub_test.go
+++ b/banyand/queue/pub/pub_test.go
@@ -112,17 +112,13 @@ var _ = ginkgo.Describe("Publish and Broadcast", func() {
gomega.Expect(cee).Should(gomega.HaveLen(1))
gomega.Expect(cee).Should(gomega.HaveKey("node2"))
gomega.Eventually(func() int {
- p.mu.RLock()
- defer p.mu.RUnlock()
- return len(p.active)
+ return p.connMgr.ActiveCount()
}, flags.EventuallyTimeout).Should(gomega.Equal(1))
- func() {
- p.mu.RLock()
- defer p.mu.RUnlock()
- gomega.Expect(p.evictable).Should(gomega.HaveLen(1))
- gomega.Expect(p.evictable).Should(gomega.HaveKey("node2"))
- gomega.Expect(p.active).Should(gomega.HaveKey("node1"))
- }()
+ gomega.Expect(p.connMgr.EvictableCount()).Should(gomega.Equal(1))
+ _, node1Active := p.connMgr.GetClient("node1")
+ gomega.Expect(node1Active).Should(gomega.BeTrue())
+ _, node2Active := p.connMgr.GetClient("node2")
+ gomega.Expect(node2Active).Should(gomega.BeFalse())
})
ginkgo.It("should go to evict queue when node's disk is full",
func() {
@@ -155,10 +151,8 @@ var _ = ginkgo.Describe("Publish and Broadcast", func() {
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
gomega.Expect(cee).Should(gomega.BeNil())
gomega.Consistently(func(g gomega.Gomega) {
- p.mu.RLock()
- defer p.mu.RUnlock()
- g.Expect(p.active).Should(gomega.HaveLen(2))
- g.Expect(p.evictable).Should(gomega.HaveLen(0))
+ g.Expect(p.connMgr.ActiveCount()).Should(gomega.Equal(2))
+ g.Expect(p.connMgr.EvictableCount()).Should(gomega.Equal(0))
}).Should(gomega.Succeed())
})
@@ -192,9 +186,7 @@ var _ = ginkgo.Describe("Publish and Broadcast", func() {
gomega.Expect(cee).Should(gomega.HaveLen(1))
gomega.Expect(cee).Should(gomega.HaveKey("node2"))
gomega.Consistently(func() int {
- p.mu.RLock()
- defer p.mu.RUnlock()
- return len(p.active)
+ return p.connMgr.ActiveCount()
}, "1s").Should(gomega.Equal(2))
})
})
diff --git a/banyand/queue/pub/pub_tls_test.go b/banyand/queue/pub/pub_tls_test.go
index 0a708a9ba..a11264a1e 100644
--- a/banyand/queue/pub/pub_tls_test.go
+++ b/banyand/queue/pub/pub_tls_test.go
@@ -162,9 +162,7 @@ var _ = ginkgo.Describe("Broadcast over one-way TLS",
func() {
p.OnAddOrUpdate(node)
gomega.Eventually(func() int {
- p.mu.RLock()
- defer p.mu.RUnlock()
- return len(p.active)
+ return p.connMgr.ActiveCount()
}, flags.EventuallyTimeout).Should(gomega.Equal(1))
futures, err := p.Broadcast(
diff --git a/banyand/queue/pub/retry_test.go b/banyand/queue/pub/retry_test.go
index f7a6ca820..b073c3470 100644
--- a/banyand/queue/pub/retry_test.go
+++ b/banyand/queue/pub/retry_test.go
@@ -20,7 +20,6 @@ package pub
import (
"context"
"fmt"
- "math"
"sync"
"testing"
"time"
@@ -98,168 +97,6 @@ func (m *MockSendClient) SetCloseSendFunc(f func() error) {
m.closeSendFunc = f
}
-func TestJitteredBackoff(t *testing.T) {
- tests := []struct {
- name string
- baseBackoff time.Duration
- maxBackoff time.Duration
- attempt int
- jitterFactor float64
- expectedMin time.Duration
- expectedMax time.Duration
- }{
- {
- name: "zero_attempt",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 10 * time.Second,
- attempt: 0,
- jitterFactor: 0.2,
- expectedMin: 80 * time.Millisecond,
- expectedMax: 120 * time.Millisecond,
- },
- {
- name: "first_attempt",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 10 * time.Second,
- attempt: 1,
- jitterFactor: 0.2,
- expectedMin: 160 * time.Millisecond,
- expectedMax: 240 * time.Millisecond,
- },
- {
- name: "max_backoff_reached",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 300 * time.Millisecond,
- attempt: 10,
- jitterFactor: 0.2,
- expectedMin: 240 * time.Millisecond,
- expectedMax: 360 * time.Millisecond,
- },
- {
- name: "no_jitter",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 10 * time.Second,
- attempt: 0,
- jitterFactor: 0.0,
- expectedMin: 100 * time.Millisecond,
- expectedMax: 100 * time.Millisecond,
- },
- {
- name: "max_jitter",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 10 * time.Second,
- attempt: 0,
- jitterFactor: 1.0,
- expectedMin: 0,
- expectedMax: 200 * time.Millisecond,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Test multiple times to account for randomness
- for i := 0; i < 100; i++ {
- result := jitteredBackoff(tt.baseBackoff,
tt.maxBackoff, tt.attempt, tt.jitterFactor)
-
- assert.GreaterOrEqual(t, result, tt.expectedMin,
- "result %v should be >= expected min
%v", result, tt.expectedMin)
- assert.LessOrEqual(t, result, tt.expectedMax,
- "result %v should be <= expected max
%v", result, tt.expectedMax)
- }
- })
- }
-}
-
-func TestJitteredBackoffEdgeCases(t *testing.T) {
- tests := []struct {
- name string
- baseBackoff time.Duration
- maxBackoff time.Duration
- attempt int
- jitterFactor float64
- }{
- {
- name: "negative_jitter_factor",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 10 * time.Second,
- attempt: 0,
- jitterFactor: -0.5,
- },
- {
- name: "jitter_factor_greater_than_one",
- baseBackoff: 100 * time.Millisecond,
- maxBackoff: 10 * time.Second,
- attempt: 0,
- jitterFactor: 1.5,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Should not panic and should return a reasonable value
- result := jitteredBackoff(tt.baseBackoff,
tt.maxBackoff, tt.attempt, tt.jitterFactor)
- assert.Greater(t, result, time.Duration(0), "result
should be positive")
- assert.LessOrEqual(t, result, tt.maxBackoff, "result
should not exceed max backoff")
- })
- }
-}
-
-func TestIsTransientError(t *testing.T) {
- tests := []struct {
- err error
- name string
- expectRetry bool
- }{
- {
- name: "nil_error",
- err: nil,
- expectRetry: false,
- },
- {
- name: "unavailable_error",
- err: status.Error(codes.Unavailable, "service
unavailable"),
- expectRetry: true,
- },
- {
- name: "deadline_exceeded_error",
- err: status.Error(codes.DeadlineExceeded,
"deadline exceeded"),
- expectRetry: true,
- },
- {
- name: "resource_exhausted_error",
- err: status.Error(codes.ResourceExhausted,
"rate limited"),
- expectRetry: true,
- },
- {
- name: "not_found_error",
- err: status.Error(codes.NotFound, "not found"),
- expectRetry: false,
- },
- {
- name: "permission_denied_error",
- err: status.Error(codes.PermissionDenied,
"permission denied"),
- expectRetry: false,
- },
- {
- name: "invalid_argument_error",
- err: status.Error(codes.InvalidArgument,
"invalid argument"),
- expectRetry: false,
- },
- {
- name: "generic_error",
- err: fmt.Errorf("generic error"),
- expectRetry: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := isTransientError(tt.err)
- assert.Equal(t, tt.expectRetry, result, "transient
error classification mismatch")
- })
- }
-}
-
func TestRetrySendSuccess(t *testing.T) {
ctx := context.Background()
mockStream := NewMockSendClient(ctx)
@@ -485,42 +322,3 @@ func TestRetrySendConcurrency(t *testing.T) {
assert.Equal(t, numGoroutines, successCount, "all goroutines should
succeed eventually (got %d successes, %d failures)", successCount, failureCount)
}
-
-func TestBackoffDistribution(t *testing.T) {
- // Test that jitter produces a reasonable distribution
- const numSamples = 1000
- baseBackoff := 100 * time.Millisecond
- maxBackoff := 10 * time.Second
- attempt := 0
- jitterFactor := 0.2
-
- durations := make([]time.Duration, numSamples)
- for i := 0; i < numSamples; i++ {
- durations[i] = jitteredBackoff(baseBackoff, maxBackoff,
attempt, jitterFactor)
- }
-
- // Calculate mean and standard deviation
- var sum time.Duration
- for _, d := range durations {
- sum += d
- }
- mean := sum / time.Duration(numSamples)
-
- var sumSquaredDiffs float64
- for _, d := range durations {
- diff := float64(d - mean)
- sumSquaredDiffs += diff * diff
- }
- variance := sumSquaredDiffs / float64(numSamples)
- stdDev := time.Duration(math.Sqrt(variance))
-
- // Expected mean should be around baseBackoff (100ms)
- expectedMean := baseBackoff
- assert.InDelta(t, float64(expectedMean), float64(mean),
float64(baseBackoff)/10,
- "mean should be close to base backoff")
-
- // Standard deviation should be reasonable (not too low, indicating
proper jitter)
- minStdDev := time.Duration(float64(baseBackoff) * jitterFactor / 4)
- assert.Greater(t, stdDev, minStdDev,
- "standard deviation should indicate proper jitter is applied")
-}
diff --git a/banyand/queue/pub/circuitbreaker.go b/pkg/grpchelper/circuitbreaker.go
similarity index 80%
rename from banyand/queue/pub/circuitbreaker.go
rename to pkg/grpchelper/circuitbreaker.go
index 7bb64e296..3393a9eb5 100644
--- a/banyand/queue/pub/circuitbreaker.go
+++ b/pkg/grpchelper/circuitbreaker.go
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-package pub
+package grpchelper
import (
"time"
@@ -45,13 +45,13 @@ type circuitState struct {
halfOpenProbeInFlight bool
}
-// isRequestAllowed checks if a request to the given node is allowed based on circuit breaker state.
+// IsRequestAllowed checks if a request to the given node is allowed based on circuit breaker state.
// It also handles state transitions from Open to Half-Open when cooldown expires.
-func (p *pub) isRequestAllowed(node string) bool {
- p.cbMu.Lock()
- defer p.cbMu.Unlock()
+func (m *ConnManager[C]) IsRequestAllowed(node string) bool {
+ m.cbMu.Lock()
+ defer m.cbMu.Unlock()
- cb, exists := p.cbStates[node]
+ cb, exists := m.cbStates[node]
if !exists {
return true // No circuit breaker state, allow request
}
@@ -65,7 +65,7 @@ func (p *pub) isRequestAllowed(node string) bool {
// Transition to Half-Open to allow a single probe
request
cb.state = StateHalfOpen
cb.halfOpenProbeInFlight = true // Set token for the
probe
- p.log.Info().Str("node", node).Msg("circuit breaker
transitioned to half-open")
+ m.log.Info().Str("node", node).Msg("circuit breaker
transitioned to half-open")
return true
}
return false // Still in cooldown period
@@ -83,23 +83,21 @@ func (p *pub) isRequestAllowed(node string) bool {
}
}
-// recordSuccess resets the circuit breaker state to Closed on successful operation.
+// RecordSuccess resets the circuit breaker state to Closed on successful operation.
// This handles Half-Open -> Closed transitions.
-func (p *pub) recordSuccess(node string) {
- p.cbMu.Lock()
- defer p.cbMu.Unlock()
-
- cb, exists := p.cbStates[node]
+func (m *ConnManager[C]) RecordSuccess(node string) {
+ m.cbMu.Lock()
+ defer m.cbMu.Unlock()
+ cb, exists := m.cbStates[node]
if !exists {
// Initialize circuit breaker state
- p.cbStates[node] = &circuitState{
+ m.cbStates[node] = &circuitState{
state: StateClosed,
consecutiveFailures: 0,
}
return
}
- // Reset to closed state
cb.state = StateClosed
cb.consecutiveFailures = 0
cb.lastFailureTime = time.Time{}
@@ -107,17 +105,17 @@ func (p *pub) recordSuccess(node string) {
cb.halfOpenProbeInFlight = false // Clear probe token
}
-// recordFailure updates the circuit breaker state on failed operation.
+// RecordFailure updates the circuit breaker state on failed operation.
// Only records failures for transient/internal errors that should count toward opening the circuit.
-func (p *pub) recordFailure(node string, err error) {
+func (m *ConnManager[C]) RecordFailure(node string, err error) {
// Only record failure if the error is transient or internal
- if !isTransientError(err) && !isInternalError(err) {
+ if !IsTransientError(err) && !IsInternalError(err) {
return
}
- p.cbMu.Lock()
- defer p.cbMu.Unlock()
+ m.cbMu.Lock()
+ defer m.cbMu.Unlock()
- cb, exists := p.cbStates[node]
+ cb, exists := m.cbStates[node]
if !exists {
// Initialize circuit breaker state
cb = &circuitState{
@@ -125,7 +123,7 @@ func (p *pub) recordFailure(node string, err error) {
consecutiveFailures: 1,
lastFailureTime: time.Now(),
}
- p.cbStates[node] = cb
+ m.cbStates[node] = cb
} else {
cb.consecutiveFailures++
cb.lastFailureTime = time.Now()
@@ -136,12 +134,12 @@ func (p *pub) recordFailure(node string, err error) {
if cb.consecutiveFailures >= threshold && cb.state == StateClosed {
cb.state = StateOpen
cb.openTime = time.Now()
- p.log.Warn().Str("node", node).Int("failures",
cb.consecutiveFailures).Msg("circuit breaker opened")
+ m.log.Warn().Str("node", node).Int("failures",
cb.consecutiveFailures).Msg("circuit breaker opened")
} else if cb.state == StateHalfOpen {
// Failed during half-open, go back to open
cb.state = StateOpen
cb.openTime = time.Now()
cb.halfOpenProbeInFlight = false // Clear probe token
- p.log.Warn().Str("node", node).Msg("circuit breaker reopened
after half-open failure")
+ m.log.Warn().Str("node", node).Msg("circuit breaker reopened
after half-open failure")
}
}
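
The circuit breaker keeps its closed/open/half-open model; it simply hangs off ConnManager now and is exported. The calling pattern, as batch.go and pub.go use it above, is to gate each request with IsRequestAllowed and then report the outcome so the per-node state can trip or recover. A minimal caller-side sketch follows; the ConnManager value, node name, and send callback are placeholders, written from the pub package's point of view.

    // Per-node circuit breaker usage as seen in this change: gate, send, report.
    func sendGuarded(m *grpchelper.ConnManager[*client], node string, send func() error) error {
        if !m.IsRequestAllowed(node) { // breaker open: fail fast without dialing
            return fmt.Errorf("circuit breaker open for node %s", node)
        }
        if err := send(); err != nil {
            // Only transient/internal errors count toward opening the breaker;
            // RecordFailure filters the rest via IsTransientError/IsInternalError.
            m.RecordFailure(node, err)
            return err
        }
        m.RecordSuccess(node) // resets the breaker (half-open -> closed)
        return nil
    }
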
diff --git a/banyand/queue/pub/circuitbreaker_test.go b/pkg/grpchelper/circuitbreaker_test.go
similarity index 64%
rename from banyand/queue/pub/circuitbreaker_test.go
rename to pkg/grpchelper/circuitbreaker_test.go
index 8bc902230..dda620cdc 100644
--- a/banyand/queue/pub/circuitbreaker_test.go
+++ b/pkg/grpchelper/circuitbreaker_test.go
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-package pub
+package grpchelper
import (
"sync"
@@ -41,96 +41,110 @@ var (
errNonRetry = status.Error(codes.InvalidArgument, "invalid argument")
)
+func init() {
+ _ = logger.Init(logger.Logging{
+ Env: "dev",
+ Level: "warn",
+ })
+}
+
+func newTestConnManager() *ConnManager[*mockClient] {
+ return NewConnManager(ConnManagerConfig[*mockClient]{
+ Handler: &mockHandler{},
+ Logger: logger.GetLogger("test-cb"),
+ })
+}
+
func TestCircuitBreakerStateTransitions(t *testing.T) {
tests := []struct {
name string
- setup func(*pub)
- actions []func(*pub, string)
+ setup func(*ConnManager[*mockClient])
+ actions []func(*ConnManager[*mockClient], string)
expectedState CircuitState
allowsRequests bool
}{
{
name: "closed_to_open_after_failures",
- setup: func(p *pub) {
- p.recordSuccess(testNodeName)
+ setup: func(m *ConnManager[*mockClient]) {
+ m.RecordSuccess(testNodeName)
},
- actions: []func(*pub, string){
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
+ actions: []func(*ConnManager[*mockClient], string){
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
},
expectedState: StateOpen,
allowsRequests: false,
},
{
name: "closed_remains_closed_below_threshold",
- setup: func(p *pub) {
- p.recordSuccess(testNodeName)
+ setup: func(m *ConnManager[*mockClient]) {
+ m.RecordSuccess(testNodeName)
},
- actions: []func(*pub, string){
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
+ actions: []func(*ConnManager[*mockClient], string){
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
},
expectedState: StateClosed,
allowsRequests: true,
},
{
name: "open_to_half_open_after_cooldown",
- setup: func(p *pub) {
- p.recordSuccess(testNodeName)
+ setup: func(m *ConnManager[*mockClient]) {
+ m.RecordSuccess(testNodeName)
// Trip the circuit breaker
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(testNodeName,
errTransient)
+ m.RecordFailure(testNodeName,
errTransient)
}
// Simulate cooldown period has passed
- p.cbMu.Lock()
- cb := p.cbStates[testNodeName]
+ m.cbMu.Lock()
+ cb := m.cbStates[testNodeName]
cb.openTime =
time.Now().Add(-defaultCBResetTimeout - time.Second)
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
},
- actions: []func(*pub, string){},
- expectedState: StateOpen, // Will transition to
half-open in isRequestAllowed
+ actions: []func(*ConnManager[*mockClient],
string){},
+ expectedState: StateOpen, // Will transition to
half-open in IsRequestAllowed
allowsRequests: true, // Should allow requests
after cooldown
},
{
name: "half_open_to_closed_on_success",
- setup: func(p *pub) {
- p.recordSuccess(testNodeName)
+ setup: func(m *ConnManager[*mockClient]) {
+ m.RecordSuccess(testNodeName)
// Trip the circuit breaker
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(testNodeName,
errTransient)
+ m.RecordFailure(testNodeName,
errTransient)
}
// Set to half-open state
- p.cbMu.Lock()
- cb := p.cbStates[testNodeName]
+ m.cbMu.Lock()
+ cb := m.cbStates[testNodeName]
cb.state = StateHalfOpen
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
},
- actions: []func(*pub, string){
- func(p *pub, node string) {
p.recordSuccess(node) },
+ actions: []func(*ConnManager[*mockClient], string){
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordSuccess(node) },
},
expectedState: StateClosed,
allowsRequests: true,
},
{
name: "half_open_to_open_on_failure",
- setup: func(p *pub) {
- p.recordSuccess(testNodeName)
+ setup: func(m *ConnManager[*mockClient]) {
+ m.RecordSuccess(testNodeName)
// Trip the circuit breaker
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(testNodeName,
errTransient)
+ m.RecordFailure(testNodeName,
errTransient)
}
// Set to half-open state
- p.cbMu.Lock()
- cb := p.cbStates[testNodeName]
+ m.cbMu.Lock()
+ cb := m.cbStates[testNodeName]
cb.state = StateHalfOpen
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
},
- actions: []func(*pub, string){
- func(p *pub, node string) {
p.recordFailure(node, errTransient) },
+ actions: []func(*ConnManager[*mockClient], string){
+ func(m *ConnManager[*mockClient], node string)
{ m.RecordFailure(node, errTransient) },
},
expectedState: StateOpen,
allowsRequests: false,
@@ -139,43 +153,37 @@ func TestCircuitBreakerStateTransitions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
+ m := newTestConnManager()
+ defer m.GracefulStop()
// Setup
if tt.setup != nil {
- tt.setup(p)
+ tt.setup(m)
}
// Execute actions
for _, action := range tt.actions {
- action(p, testNodeName)
+ action(m, testNodeName)
}
// Check final state
- p.cbMu.RLock()
- cb, exists := p.cbStates[testNodeName]
- p.cbMu.RUnlock()
+ m.cbMu.RLock()
+ cb, exists := m.cbStates[testNodeName]
+ m.cbMu.RUnlock()
require.True(t, exists, "circuit breaker state should
exist")
assert.Equal(t, tt.expectedState, cb.state, "circuit
breaker state mismatch")
// Check if requests are allowed
- allowed := p.isRequestAllowed(testNodeName)
+ allowed := m.IsRequestAllowed(testNodeName)
assert.Equal(t, tt.allowsRequests, allowed, "request
allowance mismatch")
})
}
}
func TestCircuitBreakerConcurrency(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
+ m := newTestConnManager()
+ defer m.GracefulStop()
const numGoroutines = 100
const numOperations = 50
@@ -189,8 +197,8 @@ func TestCircuitBreakerConcurrency(t *testing.T) {
go func() {
defer wg.Done()
for j := 0; j < numOperations; j++ {
- p.recordSuccess(node)
- p.isRequestAllowed(node)
+ m.RecordSuccess(node)
+ m.IsRequestAllowed(node)
}
}()
}
@@ -200,8 +208,8 @@ func TestCircuitBreakerConcurrency(t *testing.T) {
go func() {
defer wg.Done()
for j := 0; j < numOperations; j++ {
- p.recordFailure(node, errTransient)
- p.isRequestAllowed(node)
+ m.RecordFailure(node, errTransient)
+ m.IsRequestAllowed(node)
}
}()
}
@@ -209,126 +217,115 @@ func TestCircuitBreakerConcurrency(t *testing.T) {
wg.Wait()
// Verify circuit breaker state exists and is in a valid state
- p.cbMu.RLock()
- cb, exists := p.cbStates[node]
- p.cbMu.RUnlock()
+ m.cbMu.RLock()
+ cb, exists := m.cbStates[node]
+ m.cbMu.RUnlock()
require.True(t, exists, "circuit breaker state should exist")
assert.Contains(t, []CircuitState{StateClosed, StateOpen,
StateHalfOpen}, cb.state, "circuit breaker should be in a valid state")
}
func TestCircuitBreakerMultipleNodes(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
+ m := newTestConnManager()
+ defer m.GracefulStop()
nodes := []string{"node1", "node2", "node3"}
// Initialize all nodes
for _, node := range nodes {
- p.recordSuccess(node)
+ m.RecordSuccess(node)
}
// Trip circuit breaker for node1 only
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure("node1", errTransient)
+ m.RecordFailure("node1", errTransient)
}
// Add some failures to node2 but below threshold
for i := 0; i < defaultCBThreshold-1; i++ {
- p.recordFailure("node2", errTransient)
+ m.RecordFailure("node2", errTransient)
}
// Keep node3 healthy
- p.recordSuccess("node3")
+ m.RecordSuccess("node3")
// Verify states
- assert.False(t, p.isRequestAllowed("node1"), "node1 should have circuit
breaker open")
- assert.True(t, p.isRequestAllowed("node2"), "node2 should still allow
requests")
- assert.True(t, p.isRequestAllowed("node3"), "node3 should allow
requests")
+ assert.False(t, m.IsRequestAllowed("node1"), "node1 should have circuit
breaker open")
+ assert.True(t, m.IsRequestAllowed("node2"), "node2 should still allow
requests")
+ assert.True(t, m.IsRequestAllowed("node3"), "node3 should allow
requests")
// Check circuit breaker states
- p.cbMu.RLock()
- defer p.cbMu.RUnlock()
+ m.cbMu.RLock()
+ defer m.cbMu.RUnlock()
- cb1, exists1 := p.cbStates["node1"]
+ cb1, exists1 := m.cbStates["node1"]
require.True(t, exists1)
assert.Equal(t, StateOpen, cb1.state)
- cb2, exists2 := p.cbStates["node2"]
+ cb2, exists2 := m.cbStates["node2"]
require.True(t, exists2)
assert.Equal(t, StateClosed, cb2.state)
- cb3, exists3 := p.cbStates["node3"]
+ cb3, exists3 := m.cbStates["node3"]
require.True(t, exists3)
assert.Equal(t, StateClosed, cb3.state)
}
func TestCircuitBreakerRecoveryAfterCooldown(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Initialize node
- p.recordSuccess(node)
+ m.RecordSuccess(node)
// Trip circuit breaker
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
// Verify circuit is open
- assert.False(t, p.isRequestAllowed(node), "circuit should be open")
+ assert.False(t, m.IsRequestAllowed(node), "circuit should be open")
// Simulate cooldown period passage
- p.cbMu.Lock()
- cb := p.cbStates[node]
+ m.cbMu.Lock()
+ cb := m.cbStates[node]
cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
// Check that circuit allows requests (transitions to half-open)
- allowed := p.isRequestAllowed(node)
+ allowed := m.IsRequestAllowed(node)
assert.True(t, allowed, "circuit should allow requests after cooldown")
// Verify state transitioned to half-open
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateHalfOpen, cb.state, "circuit should be in
half-open state")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// Successful request should close the circuit
- p.recordSuccess(node)
+ m.RecordSuccess(node)
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateClosed, cb.state, "circuit should be closed after
success")
assert.Equal(t, 0, cb.consecutiveFailures, "failure count should be
reset")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
}
func TestCircuitBreakerInitialization(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := "new-node"
// First request to non-existent circuit breaker should be allowed
- allowed := p.isRequestAllowed(node)
+ allowed := m.IsRequestAllowed(node)
assert.True(t, allowed, "requests should be allowed for non-existent
circuit breaker")
// Record success should initialize the circuit breaker
- p.recordSuccess(node)
+ m.RecordSuccess(node)
- p.cbMu.RLock()
- cb, exists := p.cbStates[node]
- p.cbMu.RUnlock()
+ m.cbMu.RLock()
+ cb, exists := m.cbStates[node]
+ m.cbMu.RUnlock()
require.True(t, exists, "circuit breaker should be initialized")
assert.Equal(t, StateClosed, cb.state, "new circuit breaker should be
closed")
@@ -336,184 +333,164 @@ func TestCircuitBreakerInitialization(t *testing.T) {
}
func TestCircuitBreakerFailureThresholdEdgeCase(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Initialize node
- p.recordSuccess(node)
+ m.RecordSuccess(node)
// Add failures just below threshold
for i := 0; i < defaultCBThreshold-1; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
// Circuit should still be closed
- p.cbMu.RLock()
- cb := p.cbStates[node]
+ m.cbMu.RLock()
+ cb := m.cbStates[node]
assert.Equal(t, StateClosed, cb.state, "circuit should still be closed")
assert.Equal(t, defaultCBThreshold-1, cb.consecutiveFailures, "failure
count should be at threshold-1")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// One more failure should open the circuit
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateOpen, cb.state, "circuit should be open after
reaching threshold")
assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures, "failure
count should be at threshold")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
}
func TestCircuitBreakerSingleProbeEnforcement(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Initialize and trip circuit breaker
- p.recordSuccess(node)
+ m.RecordSuccess(node)
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
// Verify circuit is open
- assert.False(t, p.isRequestAllowed(node), "circuit should be open")
+ assert.False(t, m.IsRequestAllowed(node), "circuit should be open")
// Simulate cooldown period passage
- p.cbMu.Lock()
- cb := p.cbStates[node]
+ m.cbMu.Lock()
+ cb := m.cbStates[node]
cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
// First request should transition to half-open and be allowed
- allowed1 := p.isRequestAllowed(node)
+ allowed1 := m.IsRequestAllowed(node)
assert.True(t, allowed1, "first request after cooldown should be
allowed and transition to half-open")
// Verify state is half-open with probe in flight
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateHalfOpen, cb.state, "circuit should be in
half-open state")
assert.True(t, cb.halfOpenProbeInFlight, "probe should be marked as in
flight")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// Second request should be denied while probe is in flight
- allowed2 := p.isRequestAllowed(node)
+ allowed2 := m.IsRequestAllowed(node)
assert.False(t, allowed2, "second request should be denied while probe
is in flight")
// Third request should also be denied
- allowed3 := p.isRequestAllowed(node)
+ allowed3 := m.IsRequestAllowed(node)
assert.False(t, allowed3, "third request should also be denied while
probe is in flight")
}
func TestCircuitBreakerSingleProbeSuccess(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Setup half-open state with probe in flight
- p.recordSuccess(node)
+ m.RecordSuccess(node)
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
- p.cbMu.Lock()
- cb := p.cbStates[node]
+ m.cbMu.Lock()
+ cb := m.cbStates[node]
cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
// Transition to half-open
- allowed := p.isRequestAllowed(node)
+ allowed := m.IsRequestAllowed(node)
assert.True(t, allowed, "first request should be allowed")
// Verify probe is in flight
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.True(t, cb.halfOpenProbeInFlight, "probe should be in flight")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// Record success (probe succeeds)
- p.recordSuccess(node)
+ m.RecordSuccess(node)
// Verify circuit is closed and probe token is cleared
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateClosed, cb.state, "circuit should be closed after
successful probe")
assert.False(t, cb.halfOpenProbeInFlight, "probe token should be
cleared")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// Subsequent requests should be allowed in closed state
- assert.True(t, p.isRequestAllowed(node), "requests should be allowed in
closed state")
+ assert.True(t, m.IsRequestAllowed(node), "requests should be allowed in
closed state")
}
func TestCircuitBreakerSingleProbeFailure(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Setup half-open state with probe in flight
- p.recordSuccess(node)
+ m.RecordSuccess(node)
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
- p.cbMu.Lock()
- cb := p.cbStates[node]
+ m.cbMu.Lock()
+ cb := m.cbStates[node]
cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
// Transition to half-open
- allowed := p.isRequestAllowed(node)
+ allowed := m.IsRequestAllowed(node)
assert.True(t, allowed, "first request should be allowed")
// Verify probe is in flight
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.True(t, cb.halfOpenProbeInFlight, "probe should be in flight")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// Record failure (probe fails)
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
// Verify circuit is back to open and probe token is cleared
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateOpen, cb.state, "circuit should be back to open
after failed probe")
assert.False(t, cb.halfOpenProbeInFlight, "probe token should be
cleared")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
// Subsequent requests should be denied in open state
- assert.False(t, p.isRequestAllowed(node), "requests should be denied in
open state")
+ assert.False(t, m.IsRequestAllowed(node), "requests should be denied in
open state")
}
func TestCircuitBreakerConcurrentProbeAttempts(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Setup open state ready for half-open transition
- p.recordSuccess(node)
+ m.RecordSuccess(node)
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
- p.cbMu.Lock()
- cb := p.cbStates[node]
+ m.cbMu.Lock()
+ cb := m.cbStates[node]
cb.openTime = time.Now().Add(-defaultCBResetTimeout - time.Second)
- p.cbMu.Unlock()
+ m.cbMu.Unlock()
const numGoroutines = 10
var wg sync.WaitGroup
@@ -524,7 +501,7 @@ func TestCircuitBreakerConcurrentProbeAttempts(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
- allowed := p.isRequestAllowed(node)
+ allowed := m.IsRequestAllowed(node)
results <- allowed
}()
}
@@ -544,56 +521,48 @@ func TestCircuitBreakerConcurrentProbeAttempts(t *testing.T) {
assert.Equal(t, 1, allowedCount, "exactly one request should be allowed
in half-open state")
// Verify circuit is in half-open with probe in flight
- p.cbMu.RLock()
+ m.cbMu.RLock()
assert.Equal(t, StateHalfOpen, cb.state, "circuit should be in
half-open state")
assert.True(t, cb.halfOpenProbeInFlight, "probe should be marked as in
flight")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
}
func TestCircuitBreakerProbeTokenInitialization(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Test that new circuit breaker states have probe token cleared
- p.recordSuccess(node)
+ m.RecordSuccess(node)
- p.cbMu.RLock()
- cb := p.cbStates[node]
+ m.cbMu.RLock()
+ cb := m.cbStates[node]
assert.False(t, cb.halfOpenProbeInFlight, "new circuit breaker should
have probe token cleared")
- p.cbMu.RUnlock()
+ m.cbMu.RUnlock()
}
func TestCircuitBreakerErrorFiltering(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Verify initial state is closed
- assert.True(t, p.isRequestAllowed(node))
- p.cbMu.Lock()
- _, exists := p.cbStates[node]
- p.cbMu.Unlock()
+ assert.True(t, m.IsRequestAllowed(node))
+ m.cbMu.Lock()
+ _, exists := m.cbStates[node]
+ m.cbMu.Unlock()
assert.False(t, exists)
// Try to trip circuit with non-retryable error
for i := 0; i < defaultCBThreshold*2; i++ {
- p.recordFailure(node, errNonRetry)
+ m.RecordFailure(node, errNonRetry)
}
// Circuit should still be closed (non-retryable errors don't count)
- assert.True(t, p.isRequestAllowed(node))
- p.cbMu.Lock()
- cb, exists := p.cbStates[node]
- p.cbMu.Unlock()
+ assert.True(t, m.IsRequestAllowed(node))
+ m.cbMu.Lock()
+ cb, exists := m.cbStates[node]
+ m.cbMu.Unlock()
// State should not exist or be closed with 0 failures
if exists {
assert.Equal(t, StateClosed, cb.state)
@@ -602,81 +571,77 @@ func TestCircuitBreakerErrorFiltering(t *testing.T) {
// Now try with transient errors - these should count
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errTransient)
+ m.RecordFailure(node, errTransient)
}
// Circuit should now be open
- assert.False(t, p.isRequestAllowed(node))
- p.cbMu.Lock()
- cb, exists = p.cbStates[node]
- p.cbMu.Unlock()
+ assert.False(t, m.IsRequestAllowed(node))
+ m.cbMu.Lock()
+ cb, exists = m.cbStates[node]
+ m.cbMu.Unlock()
assert.True(t, exists)
assert.Equal(t, StateOpen, cb.state)
assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
// Reset for internal error test
- p.recordSuccess(node)
- assert.True(t, p.isRequestAllowed(node))
+ m.RecordSuccess(node)
+ assert.True(t, m.IsRequestAllowed(node))
// Try with internal errors - these should also count
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, errInternal)
+ m.RecordFailure(node, errInternal)
}
// Circuit should be open again
- assert.False(t, p.isRequestAllowed(node))
- p.cbMu.Lock()
- cb, exists = p.cbStates[node]
- p.cbMu.Unlock()
+ assert.False(t, m.IsRequestAllowed(node))
+ m.cbMu.Lock()
+ cb, exists = m.cbStates[node]
+ m.cbMu.Unlock()
assert.True(t, exists)
assert.Equal(t, StateOpen, cb.state)
assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
}
func TestCircuitBreakerFailoverErrors(t *testing.T) {
- p := &pub{
- cbStates: make(map[string]*circuitState),
- cbMu: sync.RWMutex{},
- log: logger.GetLogger("test"),
- }
-
+ m := newTestConnManager()
+ defer m.GracefulStop()
node := testNodeName
// Test failover errors from Recv() - using codes.Unavailable which is
a failover error
failoverErr := status.Error(codes.Unavailable, "service unavailable")
// Verify initial state is closed
- assert.True(t, p.isRequestAllowed(node))
+ assert.True(t, m.IsRequestAllowed(node))
// Simulate failover errors that should increment circuit breaker
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, failoverErr)
+ m.RecordFailure(node, failoverErr)
}
// Circuit should be open after failover errors
- assert.False(t, p.isRequestAllowed(node))
- p.cbMu.Lock()
- cb, exists := p.cbStates[node]
- p.cbMu.Unlock()
+ assert.False(t, m.IsRequestAllowed(node))
+ m.cbMu.Lock()
+ cb, exists := m.cbStates[node]
+ m.cbMu.Unlock()
assert.True(t, exists)
assert.Equal(t, StateOpen, cb.state)
assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
// Reset and test with common.Error (simulating Close() method errors)
- p.recordSuccess(node)
- assert.True(t, p.isRequestAllowed(node))
+ m.RecordSuccess(node)
+ assert.True(t, m.IsRequestAllowed(node))
// Test with common.Error types that would come from Close() method
internalErr :=
common.NewErrorWithStatus(modelv1.Status_STATUS_INTERNAL_ERROR, "internal
error")
for i := 0; i < defaultCBThreshold; i++ {
- p.recordFailure(node, internalErr)
+ m.RecordFailure(node, internalErr)
}
// Circuit should be open again
- assert.False(t, p.isRequestAllowed(node))
- p.cbMu.Lock()
- cb, exists = p.cbStates[node]
- p.cbMu.Unlock()
+ assert.False(t, m.IsRequestAllowed(node))
+ m.cbMu.Lock()
+ cb, exists = m.cbStates[node]
+ m.cbMu.Unlock()
assert.True(t, exists)
assert.Equal(t, StateOpen, cb.state)
assert.Equal(t, defaultCBThreshold, cb.consecutiveFailures)
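The tests above exercise the circuit-breaker surface that moves into grpchelper: RecordFailure and RecordSuccess drive the state machine, IsRequestAllowed gates traffic, and the half-open state admits exactly one probe. A condensed caller-side sketch of that cycle follows; it is illustrative only, and the example package, the fakeClient type, and the literal threshold of 5 are assumptions standing in for the package's unexported defaultCBThreshold and reset timeout.

package example

import (
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/apache/skywalking-banyandb/pkg/grpchelper"
)

// fakeClient satisfies grpchelper.Client for the sketch.
type fakeClient struct{}

func (*fakeClient) Close() error { return nil }

// shedThenProbe walks the closed -> open -> half-open -> closed cycle.
func shedThenProbe(m *grpchelper.ConnManager[*fakeClient], node string) {
	for i := 0; i < 5; i++ { // repeated transient failures trip the breaker
		m.RecordFailure(node, status.Error(codes.Unavailable, "node down"))
	}
	if !m.IsRequestAllowed(node) {
		return // open: traffic to this node is shed until the reset timeout elapses
	}
	// Half-open: exactly one probe was admitted; report its outcome so the
	// breaker either closes again (success) or re-opens (failure).
	m.RecordSuccess(node)
}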
diff --git a/pkg/grpchelper/connmanager.go b/pkg/grpchelper/connmanager.go
new file mode 100644
index 000000000..a480d2ee1
--- /dev/null
+++ b/pkg/grpchelper/connmanager.go
@@ -0,0 +1,526 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/health/grpc_health_v1"
+
+ databasev1
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
+ "github.com/apache/skywalking-banyandb/pkg/run"
+)
+
+var (
+ // ErrCircuitBreakerOpen is returned when the circuit breaker is open
for a node.
+ ErrCircuitBreakerOpen = errors.New("circuit breaker open")
+ // ErrClientNotFound is returned when no active client exists for a
node.
+ ErrClientNotFound = errors.New("client not found")
+)
+
+// Client is the minimal interface for a managed gRPC client.
+type Client interface {
+ Close() error
+}
+
+// ConnectionHandler provides upper-layer callbacks for connection lifecycle.
+type ConnectionHandler[C Client] interface {
+ // AddressOf extracts the gRPC address from a node.
+ AddressOf(node *databasev1.Node) string
+ // GetDialOptions returns gRPC dial options for the given address.
+ GetDialOptions() ([]grpc.DialOption, error)
+ // NewClient creates a client from a gRPC connection and node.
+ NewClient(conn *grpc.ClientConn, node *databasev1.Node) (C, error)
+ // OnActive is called when a node transitions to active.
+ OnActive(name string, client C)
+ // OnInactive is called when a node leaves active.
+ OnInactive(name string, client C)
+}
+
+// ConnManagerConfig holds configuration for ConnManager.
+type ConnManagerConfig[C Client] struct {
+ Handler ConnectionHandler[C]
+ Logger *logger.Logger
+ RetryPolicy string
+ ExtraDialOpts []grpc.DialOption
+ MaxRecvMsgSize int
+}
+
+// ConnManager manages gRPC connections with health checking, circuit
breaking, and eviction.
+type ConnManager[C Client] struct {
+ handler ConnectionHandler[C]
+ log *logger.Logger
+ registered map[string]*databasev1.Node
+ active map[string]*managedNode[C]
+ evictable map[string]evictNode
+ cbStates map[string]*circuitState
+ closer *run.Closer
+ dialOpts []grpc.DialOption
+ mu sync.RWMutex
+ cbMu sync.RWMutex
+}
+
+type managedNode[C Client] struct {
+ client C
+ conn *grpc.ClientConn
+ node *databasev1.Node
+}
+
+type evictNode struct {
+ n *databasev1.Node
+ c chan struct{}
+}
+
+// NewConnManager creates a new ConnManager.
+func NewConnManager[C Client](cfg ConnManagerConfig[C]) *ConnManager[C] {
+ var dialOpts []grpc.DialOption
+ if cfg.RetryPolicy != "" {
+ dialOpts = append(dialOpts,
grpc.WithDefaultServiceConfig(cfg.RetryPolicy))
+ }
+ if cfg.MaxRecvMsgSize > 0 {
+ dialOpts = append(dialOpts,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(cfg.MaxRecvMsgSize)))
+ }
+ dialOpts = append(dialOpts, cfg.ExtraDialOpts...)
+ m := &ConnManager[C]{
+ handler: cfg.Handler,
+ log: cfg.Logger,
+ registered: make(map[string]*databasev1.Node),
+ active: make(map[string]*managedNode[C]),
+ evictable: make(map[string]evictNode),
+ cbStates: make(map[string]*circuitState),
+ closer: run.NewCloser(1),
+ dialOpts: dialOpts,
+ }
+ return m
+}
+
+// OnAddOrUpdate registers or updates a node and manages its connection.
+func (m *ConnManager[C]) OnAddOrUpdate(node *databasev1.Node) {
+ address := m.handler.AddressOf(node)
+ if address == "" {
+ m.log.Warn().Stringer("node", node).Msg("grpc address is empty")
+ return
+ }
+ name := node.Metadata.GetName()
+ if name == "" {
+ m.log.Warn().Stringer("node", node).Msg("node name is empty")
+ return
+ }
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.registerNode(node)
+
+ if _, ok := m.active[name]; ok {
+ return
+ }
+ if _, ok := m.evictable[name]; ok {
+ return
+ }
+ credOpts, dialErr := m.handler.GetDialOptions()
+ if dialErr != nil {
+ m.log.Error().Err(dialErr).Msg("failed to load client TLS
credentials")
+ return
+ }
+ allOpts := make([]grpc.DialOption, 0, len(credOpts)+len(m.dialOpts))
+ allOpts = append(allOpts, credOpts...)
+ allOpts = append(allOpts, m.dialOpts...)
+ conn, connErr := grpc.NewClient(address, allOpts...)
+ if connErr != nil {
+ m.log.Error().Err(connErr).Msg("failed to connect to grpc
server")
+ return
+ }
+
+ client, clientErr := m.handler.NewClient(conn, node)
+ if clientErr != nil {
+ m.log.Error().Err(clientErr).Msg("failed to create client")
+ _ = conn.Close()
+ return
+ }
+ if !m.checkHealthAndReconnect(conn, node, client) {
+ m.log.Info().Str("status", m.dump()).Stringer("node",
node).Msg("node is unhealthy in the register flow, move it to evict queue")
+ return
+ }
+ m.active[name] = &managedNode[C]{conn: conn, client: client, node: node}
+ m.handler.OnActive(name, client)
+ // Initialize or reset circuit breaker state to closed
+ m.RecordSuccess(name)
+ m.log.Info().Str("status", m.dump()).Stringer("node", node).Msg("new
node is healthy, add it to active queue")
+}
+
+// OnDelete removes a node and its connection.
+func (m *ConnManager[C]) OnDelete(node *databasev1.Node) {
+ name := node.Metadata.GetName()
+ if name == "" {
+ m.log.Warn().Stringer("node", node).Msg("node name is empty")
+ return
+ }
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ delete(m.registered, name)
+ if en, ok := m.evictable[name]; ok {
+ close(en.c)
+ delete(m.evictable, name)
+ m.log.Info().Str("status", m.dump()).Stringer("node",
node).Msg("node is removed from evict queue by delete event")
+ return
+ }
+ if mn, ok := m.active[name]; ok {
+ if m.removeNodeIfUnhealthy(name, mn) {
+ m.log.Info().Str("status", m.dump()).Stringer("node",
node).Msg("remove node from active queue by delete event")
+ return
+ }
+ if !m.closer.AddRunning() {
+ return
+ }
+ go func() {
+ defer m.closer.Done()
+ var elapsed time.Duration
+ attempt := 0
+ for {
+ backoff := JitteredBackoff(InitBackoff,
MaxBackoff, attempt, DefaultJitterFactor)
+ select {
+ case <-time.After(backoff):
+ if func() bool {
+ elapsed += backoff
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if _, ok := m.registered[name];
ok {
+ return true
+ }
+ if
m.removeNodeIfUnhealthy(name, mn) {
+
m.log.Info().Str("status", m.dump()).Stringer("node", node).Dur("after",
elapsed).Msg("remove node from active queue by delete event")
+ return true
+ }
+ return false
+ }() {
+ return
+ }
+ case <-m.closer.CloseNotify():
+ return
+ }
+ attempt++
+ }
+ }()
+ }
+}
+
+// GetClient returns the client for the given node name.
+func (m *ConnManager[C]) GetClient(name string) (C, bool) {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ mn, ok := m.active[name]
+ if !ok {
+ var zero C
+ return zero, false
+ }
+ return mn.client, true
+}
+
+// Execute checks the circuit breaker, gets the client, calls fn, and records
success or failure.
+func (m *ConnManager[C]) Execute(node string, fn func(C) error) error {
+ if !m.IsRequestAllowed(node) {
+ return fmt.Errorf("%w for node %s", ErrCircuitBreakerOpen, node)
+ }
+ c, ok := m.GetClient(node)
+ if !ok {
+ return fmt.Errorf("%w for node %s", ErrClientNotFound, node)
+ }
+ if callbackErr := fn(c); callbackErr != nil {
+ m.RecordFailure(node, callbackErr)
+ return callbackErr
+ }
+ m.RecordSuccess(node)
+ return nil
+}
+
+// ActiveNames returns the names of all active nodes.
+func (m *ConnManager[C]) ActiveNames() []string {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ names := make([]string, 0, len(m.active))
+ for name := range m.active {
+ names = append(names, name)
+ }
+ return names
+}
+
+// ActiveCount returns the number of active nodes.
+func (m *ConnManager[C]) ActiveCount() int {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return len(m.active)
+}
+
+// EvictableCount returns the number of evictable nodes.
+func (m *ConnManager[C]) EvictableCount() int {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ return len(m.evictable)
+}
+
+// ActiveRegisteredNodes returns the registered node info for all active nodes.
+func (m *ConnManager[C]) ActiveRegisteredNodes() []*databasev1.Node {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ var nodes []*databasev1.Node
+ for k := range m.active {
+ if n := m.registered[k]; n != nil {
+ nodes = append(nodes, n)
+ }
+ }
+ return nodes
+}
+
+// GetRouteTable returns a snapshot of registered, active, and evictable node
info.
+func (m *ConnManager[C]) GetRouteTable() *databasev1.RouteTable {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ registered := make([]*databasev1.Node, 0, len(m.registered))
+ for _, node := range m.registered {
+ if node != nil {
+ registered = append(registered, node)
+ }
+ }
+ activeNames := make([]string, 0, len(m.active))
+ for nodeID := range m.active {
+ if node := m.registered[nodeID]; node != nil && node.Metadata
!= nil {
+ activeNames = append(activeNames, node.Metadata.Name)
+ }
+ }
+ evictableNames := make([]string, 0, len(m.evictable))
+ for nodeID := range m.evictable {
+ if node := m.registered[nodeID]; node != nil && node.Metadata
!= nil {
+ evictableNames = append(evictableNames,
node.Metadata.Name)
+ }
+ }
+ return &databasev1.RouteTable{
+ Registered: registered,
+ Active: activeNames,
+ Evictable: evictableNames,
+ }
+}
+
+// FailoverNode checks health for a node and moves it to evictable if
unhealthy.
+func (m *ConnManager[C]) FailoverNode(node string) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if en, ok := m.evictable[node]; ok {
+ if _, registered := m.registered[node]; !registered {
+ close(en.c)
+ delete(m.evictable, node)
+ m.log.Info().Str("node", node).Str("status",
m.dump()).Msg("node is removed from evict queue by wire event")
+ }
+ return
+ }
+ if mn, ok := m.active[node]; ok && !m.checkHealthAndReconnect(mn.conn,
mn.node, mn.client) {
+ _ = mn.conn.Close()
+ delete(m.active, node)
+ m.handler.OnInactive(node, mn.client)
+ m.log.Info().Str("status", m.dump()).Str("node",
node).Msg("node is unhealthy in the failover flow, move it to evict queue")
+ }
+}
+
+// ReconnectAll closes all active and evictable connections and re-registers
all nodes.
+func (m *ConnManager[C]) ReconnectAll() {
+ m.mu.Lock()
+ nodesToReconnect := make([]*databasev1.Node, 0, len(m.registered))
+ for name, node := range m.registered {
+ if en, ok := m.evictable[name]; ok {
+ close(en.c)
+ delete(m.evictable, name)
+ }
+ if mn, ok := m.active[name]; ok {
+ _ = mn.conn.Close()
+ delete(m.active, name)
+ m.handler.OnInactive(name, mn.client)
+ }
+ nodesToReconnect = append(nodesToReconnect, node)
+ }
+ m.mu.Unlock()
+ for _, node := range nodesToReconnect {
+ m.OnAddOrUpdate(node)
+ }
+}
+
+// GracefulStop closes all connections and stops background goroutines.
+func (m *ConnManager[C]) GracefulStop() {
+ m.mu.Lock()
+ for idx := range m.evictable {
+ close(m.evictable[idx].c)
+ }
+ m.evictable = nil
+ m.mu.Unlock()
+ m.closer.Done()
+ m.closer.CloseThenWait()
+ m.mu.Lock()
+ for _, mn := range m.active {
+ _ = mn.conn.Close()
+ }
+ m.active = nil
+ m.mu.Unlock()
+}
+
+func (m *ConnManager[C]) registerNode(node *databasev1.Node) {
+ name := node.Metadata.GetName()
+ defer func() {
+ m.registered[name] = node
+ }()
+
+ n, ok := m.registered[name]
+ if !ok {
+ return
+ }
+ if m.handler.AddressOf(n) == m.handler.AddressOf(node) {
+ return
+ }
+ if en, ok := m.evictable[name]; ok {
+ close(en.c)
+ delete(m.evictable, name)
+ m.log.Info().Str("node", name).Str("status",
m.dump()).Msg("node is removed from evict queue by the new gRPC address updated
event")
+ }
+ if mn, ok := m.active[name]; ok {
+ _ = mn.conn.Close()
+ delete(m.active, name)
+ m.handler.OnInactive(name, mn.client)
+ m.log.Info().Str("status", m.dump()).Str("node",
name).Msg("node is removed from active queue by the new gRPC address updated
event")
+ }
+}
+
+func (m *ConnManager[C]) removeNodeIfUnhealthy(name string, mn
*managedNode[C]) bool {
+ if m.healthCheck(mn.node.String(), mn.conn) {
+ return false
+ }
+ _ = mn.conn.Close()
+ delete(m.active, name)
+ m.handler.OnInactive(name, mn.client)
+ return true
+}
+
+// checkHealthAndReconnect checks if a node is healthy. If not, closes the
conn,
+// adds to evictable, calls OnInactive, and starts a retry goroutine.
+// Returns true if healthy.
+func (m *ConnManager[C]) checkHealthAndReconnect(conn *grpc.ClientConn, node
*databasev1.Node, client C) bool {
+ if m.healthCheck(node.String(), conn) {
+ return true
+ }
+ _ = conn.Close()
+ if !m.closer.AddRunning() {
+ return false
+ }
+ name := node.Metadata.Name
+ m.evictable[name] = evictNode{n: node, c: make(chan struct{})}
+ m.handler.OnInactive(name, client)
+ go func(name string, en evictNode) {
+ defer m.closer.Done()
+ attempt := 0
+ for {
+ backoff := JitteredBackoff(InitBackoff, MaxBackoff,
attempt, DefaultJitterFactor)
+ select {
+ case <-time.After(backoff):
+ address := m.handler.AddressOf(en.n)
+ credOpts, errEvict := m.handler.GetDialOptions()
+ if errEvict != nil {
+ m.log.Error().Err(errEvict).Msg("failed
to load client TLS credentials (evict)")
+ return
+ }
+ allOpts := make([]grpc.DialOption, 0,
len(credOpts)+len(m.dialOpts))
+ allOpts = append(allOpts, credOpts...)
+ allOpts = append(allOpts, m.dialOpts...)
+ connEvict, errEvict := grpc.NewClient(address,
allOpts...)
+ if errEvict == nil &&
m.healthCheck(en.n.String(), connEvict) {
+ func() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if _, ok := m.evictable[name];
!ok {
+ // The client has been
removed from evict clients map, just return
+ return
+ }
+ newClient, clientErr :=
m.handler.NewClient(connEvict, en.n)
+ if clientErr != nil {
+
m.log.Error().Err(clientErr).Msg("failed to create client during reconnect")
+ _ = connEvict.Close()
+ return
+ }
+ m.active[name] =
&managedNode[C]{conn: connEvict, client: newClient, node: en.n}
+ m.handler.OnActive(name,
newClient)
+ delete(m.evictable, name)
+ m.RecordSuccess(name)
+ m.log.Info().Str("status",
m.dump()).Stringer("node", en.n).Msg("node is healthy, move it back to active
queue")
+ }()
+ return
+ }
+ if connEvict != nil {
+ _ = connEvict.Close()
+ }
+ if func() bool {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ _, ok := m.registered[name]
+ return !ok
+ }() {
+ return
+ }
+ m.log.Error().Err(errEvict).Msgf("failed to
re-connect to grpc server %s after waiting for %s", address, backoff)
+ case <-en.c:
+ return
+ case <-m.closer.CloseNotify():
+ return
+ }
+ attempt++
+ }
+ }(name, m.evictable[name])
+ return false
+}
+
+func (m *ConnManager[C]) healthCheck(node string, conn *grpc.ClientConn) bool {
+ var resp *grpc_health_v1.HealthCheckResponse
+ if requestErr := Request(context.Background(), 2*time.Second,
func(rpcCtx context.Context) (err error) {
+ resp, err = grpc_health_v1.NewHealthClient(conn).Check(rpcCtx,
+ &grpc_health_v1.HealthCheckRequest{
+ Service: "",
+ })
+ return err
+ }); requestErr != nil {
+ if e := m.log.Debug(); e.Enabled() {
+ e.Err(requestErr).Str("node", node).Msg("service
unhealthy")
+ }
+ return false
+ }
+ return resp.GetStatus() == grpc_health_v1.HealthCheckResponse_SERVING
+}
+
+func (m *ConnManager[C]) dump() string {
+ keysRegistered := make([]string, 0, len(m.registered))
+ for k := range m.registered {
+ keysRegistered = append(keysRegistered, k)
+ }
+ keysActive := make([]string, 0, len(m.active))
+ for k := range m.active {
+ keysActive = append(keysActive, k)
+ }
+ keysEvictable := make([]string, 0, len(m.evictable))
+ for k := range m.evictable {
+ keysEvictable = append(keysEvictable, k)
+ }
+ return fmt.Sprintf("registered: %v, active :%v, evictable :%v",
keysRegistered, keysActive, keysEvictable)
+}
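connmanager.go keeps the transport-specific pieces behind the ConnectionHandler interface, so the same manager can back different kinds of clients. A minimal wiring sketch follows; it is a hedged illustration only, and queueClient, queueHandler, newQueueManager, and the plaintext credentials are assumptions rather than code from this commit. Once constructed, the caller feeds the manager node add/update/delete events via OnAddOrUpdate and OnDelete and reaches per-node clients through GetClient or Execute.

package example

import (
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
	"github.com/apache/skywalking-banyandb/pkg/grpchelper"
	"github.com/apache/skywalking-banyandb/pkg/logger"
)

// queueClient wraps whatever per-node stub the caller needs; only Close is
// required by grpchelper.Client.
type queueClient struct{ conn *grpc.ClientConn }

func (c *queueClient) Close() error { return c.conn.Close() }

// queueHandler supplies the lifecycle callbacks the manager invokes.
type queueHandler struct{}

func (queueHandler) AddressOf(n *databasev1.Node) string { return n.GetGrpcAddress() }

func (queueHandler) GetDialOptions() ([]grpc.DialOption, error) {
	// Plaintext for the sketch; a real handler would return TLS credentials here.
	return []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil
}

func (queueHandler) NewClient(conn *grpc.ClientConn, _ *databasev1.Node) (*queueClient, error) {
	return &queueClient{conn: conn}, nil
}

func (queueHandler) OnActive(name string, _ *queueClient)   {} // e.g. add the node to routing
func (queueHandler) OnInactive(name string, _ *queueClient) {} // e.g. drop the node from routing

func newQueueManager() *grpchelper.ConnManager[*queueClient] {
	return grpchelper.NewConnManager(grpchelper.ConnManagerConfig[*queueClient]{
		Handler: queueHandler{},
		Logger:  logger.GetLogger("queue-client"),
	})
}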
diff --git a/pkg/grpchelper/connmanager_test.go b/pkg/grpchelper/connmanager_test.go
new file mode 100644
index 000000000..f85303fbf
--- /dev/null
+++ b/pkg/grpchelper/connmanager_test.go
@@ -0,0 +1,115 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+func addTestClient(m *ConnManager[*mockClient], node string) {
+ m.mu.Lock()
+ m.active[node] = &managedNode[*mockClient]{client: &mockClient{}}
+ m.mu.Unlock()
+}
+
+func removeTestClient(m *ConnManager[*mockClient], node string) {
+ m.mu.Lock()
+ delete(m.active, node)
+ m.mu.Unlock()
+}
+
+func TestExecute_Success(t *testing.T) {
+ m := newTestConnManager()
+ node := "exec-success"
+ addTestClient(m, node)
+ defer func() {
+ removeTestClient(m, node)
+ m.GracefulStop()
+ }()
+ m.RecordSuccess(node)
+
+ called := false
+ execErr := m.Execute(node, func(_ *mockClient) error {
+ called = true
+ return nil
+ })
+ require.NoError(t, execErr)
+ assert.True(t, called)
+ assert.True(t, m.IsRequestAllowed(node), "circuit breaker should still
allow requests after success")
+}
+
+func TestExecute_CallbackFailure(t *testing.T) {
+ m := newTestConnManager()
+ node := "exec-fail"
+ addTestClient(m, node)
+ defer func() {
+ removeTestClient(m, node)
+ m.GracefulStop()
+ }()
+ m.RecordSuccess(node)
+
+ cbErr := status.Error(codes.Unavailable, "callback error")
+ for i := 0; i < defaultCBThreshold; i++ {
+ execErr := m.Execute(node, func(_ *mockClient) error {
+ return cbErr
+ })
+ require.Error(t, execErr)
+ }
+ assert.False(t, m.IsRequestAllowed(node), "circuit breaker should open
after repeated failures")
+}
+
+func TestExecute_CircuitBreakerOpen(t *testing.T) {
+ m := newTestConnManager()
+ node := "exec-cb-open"
+ addTestClient(m, node)
+ defer func() {
+ removeTestClient(m, node)
+ m.GracefulStop()
+ }()
+
+ for i := 0; i < defaultCBThreshold; i++ {
+ m.RecordFailure(node, errTransient)
+ }
+
+ execErr := m.Execute(node, func(_ *mockClient) error {
+ t.Fatal("callback should not be called when circuit breaker is
open")
+ return nil
+ })
+ require.Error(t, execErr)
+ assert.ErrorIs(t, execErr, ErrCircuitBreakerOpen)
+ assert.Contains(t, execErr.Error(), node)
+}
+
+func TestExecute_ClientNotFound(t *testing.T) {
+ m := newTestConnManager()
+ defer m.GracefulStop()
+ node := "exec-no-client"
+
+ execErr := m.Execute(node, func(_ *mockClient) error {
+ t.Fatal("callback should not be called when client is not
found")
+ return nil
+ })
+ require.Error(t, execErr)
+ assert.ErrorIs(t, execErr, ErrClientNotFound)
+ assert.Contains(t, execErr.Error(), node)
+}
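The Execute tests pin down the contract a caller can rely on: the callback never runs when the breaker is open or the client is missing, and the two sentinel errors are distinguishable with errors.Is. A hedged sketch of the caller-side branching, continuing the hypothetical queueClient from the sketch above (push and sendOne are assumptions, not part of the commit):

package example

import (
	"errors"

	"github.com/apache/skywalking-banyandb/pkg/grpchelper"
)

// push stands in for whatever RPC the managed client actually exposes.
func push(_ *queueClient, _ []byte) error { return nil }

func sendOne(m *grpchelper.ConnManager[*queueClient], node string, payload []byte) error {
	err := m.Execute(node, func(c *queueClient) error {
		return push(c, payload) // a failure here is recorded against the node's breaker
	})
	switch {
	case errors.Is(err, grpchelper.ErrCircuitBreakerOpen):
		// the breaker is shedding load for this node; try another replica
	case errors.Is(err, grpchelper.ErrClientNotFound):
		// the node is not in the active set; refresh routes or skip it
	}
	return err
}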
diff --git a/pkg/grpchelper/helpers_test.go b/pkg/grpchelper/helpers_test.go
new file mode 100644
index 000000000..1c6c9815c
--- /dev/null
+++ b/pkg/grpchelper/helpers_test.go
@@ -0,0 +1,67 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+ "google.golang.org/grpc"
+
+ databasev1
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+)
+
+// mockClient implements the Client interface for tests.
+type mockClient struct {
+ closed bool
+}
+
+func (m *mockClient) Close() error {
+ m.closed = true
+ return nil
+}
+
+// mockHandler implements ConnectionHandler[*mockClient] for tests.
+type mockHandler struct {
+ activeNodes map[string]*mockClient
+ inactiveNodes map[string]*mockClient
+}
+
+func (h *mockHandler) AddressOf(node *databasev1.Node) string {
+ return node.GetGrpcAddress()
+}
+
+func (h *mockHandler) GetDialOptions() ([]grpc.DialOption, error) {
+ return nil, nil
+}
+
+func (h *mockHandler) NewClient(_ *grpc.ClientConn, _ *databasev1.Node)
(*mockClient, error) {
+ return &mockClient{}, nil
+}
+
+func (h *mockHandler) OnActive(name string, client *mockClient) {
+ if h.activeNodes == nil {
+ h.activeNodes = make(map[string]*mockClient)
+ }
+ h.activeNodes[name] = client
+}
+
+func (h *mockHandler) OnInactive(name string, client *mockClient) {
+ if h.inactiveNodes == nil {
+ h.inactiveNodes = make(map[string]*mockClient)
+ }
+ h.inactiveNodes[name] = client
+ delete(h.activeNodes, name)
+}
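helpers_test.go supplies the mockClient and mockHandler shared by the tests in this package. The newTestConnManager constructor they are wired into lives in another test file of the package and is not shown in full in this diff; assuming it does nothing beyond handing the mock handler to NewConnManager, its shape would be roughly:

func newTestConnManager() *ConnManager[*mockClient] {
	return NewConnManager(ConnManagerConfig[*mockClient]{
		Handler: &mockHandler{},
		Logger:  logger.GetLogger("test"), // pkg/logger, as in the tests above
	})
}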
diff --git a/banyand/queue/pub/retry.go b/pkg/grpchelper/retry.go
similarity index 72%
rename from banyand/queue/pub/retry.go
rename to pkg/grpchelper/retry.go
index 1adfa0e2b..2478a4c3d 100644
--- a/banyand/queue/pub/retry.go
+++ b/pkg/grpchelper/retry.go
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-package pub
+package grpchelper
import (
"crypto/rand"
@@ -30,18 +30,14 @@ import (
modelv1
"github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
)
-const (
- defaultJitterFactor = 0.2
- defaultMaxRetries = 3
- defaultPerRequestTimeout = 2 * time.Second
- defaultBackoffBase = 500 * time.Millisecond
- defaultBackoffMax = 30 * time.Second
-)
+// DefaultJitterFactor is the default jitter factor for backoff calculations.
+const DefaultJitterFactor = 0.2
var (
- // Retry policy for health check.
- initBackoff = time.Second
- maxBackoff = 20 * time.Second
+ // InitBackoff is the initial backoff duration for health check retries.
+ InitBackoff = time.Second
+ // MaxBackoff is the maximum backoff duration for health check retries.
+ MaxBackoff = 20 * time.Second
// Retryable gRPC status codes for streaming send retries.
retryableCodes = map[codes.Code]bool{
@@ -77,9 +73,9 @@ func secureRandFloat64() float64 {
return float64(n.Uint64()) / float64(1<<53)
}
-// jitteredBackoff calculates backoff duration with jitter to avoid thundering
herds.
+// JitteredBackoff calculates backoff duration with jitter to avoid thundering
herds.
// Uses bounded symmetric jitter: backoff * (1 + jitter * (rand() - 0.5) * 2).
-func jitteredBackoff(baseBackoff, maxBackoff time.Duration, attempt int,
jitterFactor float64) time.Duration {
+func JitteredBackoff(baseBackoff, maxBackoff time.Duration, attempt int,
jitterFactor float64) time.Duration {
if jitterFactor < 0 {
jitterFactor = 0
}
@@ -114,15 +110,29 @@ func jitteredBackoff(baseBackoff, maxBackoff time.Duration, attempt int, jitterF
return jitteredDuration
}
-// isTransientError checks if the error is considered transient and retryable.
-func isTransientError(err error) bool {
+// statusCodeFromError extracts the gRPC status code from an error,
+// supporting both direct gRPC status errors and wrapped errors.
+func statusCodeFromError(err error) (codes.Code, bool) {
+ if s, ok := status.FromError(err); ok {
+ return s.Code(), true
+ }
+ type grpcStatusProvider interface{ GRPCStatus() *status.Status }
+ var se grpcStatusProvider
+ if errors.As(err, &se) {
+ return se.GRPCStatus().Code(), true
+ }
+ return codes.OK, false
+}
+
+// IsTransientError checks if the error is considered transient and retryable.
+func IsTransientError(err error) bool {
if err == nil {
return false
}
// Handle gRPC status errors
- if s, ok := status.FromError(err); ok {
- return retryableCodes[s.Code()]
+ if code, ok := statusCodeFromError(err); ok {
+ return retryableCodes[code]
}
// Handle common.Error types
@@ -140,15 +150,24 @@ func isTransientError(err error) bool {
return false
}
-// isInternalError checks if the error is an internal server error.
-func isInternalError(err error) bool {
+// IsFailoverError checks if the error indicates the node should be failed
over.
+func IsFailoverError(err error) bool {
+ code, ok := statusCodeFromError(err)
+ if !ok {
+ return false
+ }
+ return code == codes.Unavailable || code == codes.DeadlineExceeded
+}
+
+// IsInternalError checks if the error is an internal server error.
+func IsInternalError(err error) bool {
if err == nil {
return false
}
// Handle gRPC status errors
- if s, ok := status.FromError(err); ok {
- return s.Code() == codes.Internal
+ if code, ok := statusCodeFromError(err); ok {
+ return code == codes.Internal
}
// Handle common.Error types
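With the former pub helpers now exported (JitteredBackoff, IsTransientError, IsFailoverError, IsInternalError), callers assemble their own retry policy from them. A minimal sketch follows; the retry count, backoff bounds, and the send/onFailover hooks are illustrative assumptions, not values defined in this package.

package example

import (
	"time"

	"github.com/apache/skywalking-banyandb/pkg/grpchelper"
)

// sendWithRetry retries only transient failures and backs off with jitter
// between attempts; failover-class errors are additionally reported so the
// caller can trigger a health re-check of the node.
func sendWithRetry(send func() error, onFailover func(error)) error {
	const maxRetries = 3
	var err error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if err = send(); err == nil {
			return nil
		}
		if grpchelper.IsFailoverError(err) {
			onFailover(err)
		}
		if !grpchelper.IsTransientError(err) {
			return err // non-retryable: fail fast
		}
		time.Sleep(grpchelper.JitteredBackoff(
			500*time.Millisecond, 30*time.Second, attempt, grpchelper.DefaultJitterFactor))
	}
	return err
}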
diff --git a/pkg/grpchelper/retry_test.go b/pkg/grpchelper/retry_test.go
new file mode 100644
index 000000000..d13962640
--- /dev/null
+++ b/pkg/grpchelper/retry_test.go
@@ -0,0 +1,500 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package grpchelper
+
+import (
+ "errors"
+ "fmt"
+ "testing"
+ "time"
+
+ pkgerrors "github.com/pkg/errors"
+ "github.com/stretchr/testify/assert"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
+)
+
+// customUnwrapError wraps a single error via Unwrap() error.
+type customUnwrapError struct {
+ cause error
+ msg string
+}
+
+func (e *customUnwrapError) Error() string { return e.msg + ": " +
e.cause.Error() }
+func (e *customUnwrapError) Unwrap() error { return e.cause }
+
+// multiUnwrapError wraps multiple errors via Unwrap() []error.
+type multiUnwrapError struct {
+ msg string
+ causes []error
+}
+
+func (e *multiUnwrapError) Error() string { return e.msg }
+func (e *multiUnwrapError) Unwrap() []error { return e.causes }
+
+func TestJitteredBackoff(t *testing.T) {
+ tests := []struct {
+ name string
+ baseBackoff time.Duration
+ maxBackoff time.Duration
+ attempt int
+ jitterFactor float64
+ expectedMin time.Duration
+ expectedMax time.Duration
+ }{
+ {
+ name: "zero_attempt",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 10 * time.Second,
+ attempt: 0,
+ jitterFactor: 0.2,
+ expectedMin: 80 * time.Millisecond,
+ expectedMax: 120 * time.Millisecond,
+ },
+ {
+ name: "first_attempt",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 10 * time.Second,
+ attempt: 1,
+ jitterFactor: 0.2,
+ expectedMin: 160 * time.Millisecond,
+ expectedMax: 240 * time.Millisecond,
+ },
+ {
+ name: "max_backoff_reached",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 300 * time.Millisecond,
+ attempt: 10,
+ jitterFactor: 0.2,
+ expectedMin: 240 * time.Millisecond,
+ expectedMax: 300 * time.Millisecond,
+ },
+ {
+ name: "no_jitter",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 10 * time.Second,
+ attempt: 0,
+ jitterFactor: 0.0,
+ expectedMin: 100 * time.Millisecond,
+ expectedMax: 100 * time.Millisecond,
+ },
+ {
+ name: "max_jitter",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 10 * time.Second,
+ attempt: 0,
+ jitterFactor: 1.0,
+ expectedMin: 0,
+ expectedMax: 200 * time.Millisecond,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Test multiple times to account for randomness
+ for i := 0; i < 100; i++ {
+ result := JitteredBackoff(tt.baseBackoff,
tt.maxBackoff, tt.attempt, tt.jitterFactor)
+
+ assert.GreaterOrEqual(t, result, tt.expectedMin,
+ "result %v should be >= expected min
%v", result, tt.expectedMin)
+ assert.LessOrEqual(t, result, tt.expectedMax,
+ "result %v should be <= expected max
%v", result, tt.expectedMax)
+ }
+ })
+ }
+}
+
+func TestJitteredBackoffEdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ baseBackoff time.Duration
+ maxBackoff time.Duration
+ attempt int
+ jitterFactor float64
+ }{
+ {
+ name: "negative_jitter_factor",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 10 * time.Second,
+ attempt: 0,
+ jitterFactor: -0.5,
+ },
+ {
+ name: "jitter_factor_greater_than_one",
+ baseBackoff: 100 * time.Millisecond,
+ maxBackoff: 10 * time.Second,
+ attempt: 0,
+ jitterFactor: 1.5,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Should not panic and should return a reasonable value
+ result := JitteredBackoff(tt.baseBackoff,
tt.maxBackoff, tt.attempt, tt.jitterFactor)
+ assert.GreaterOrEqual(t, result, time.Duration(0),
"result should be non-negative")
+ assert.LessOrEqual(t, result, tt.maxBackoff, "result
should not exceed max backoff")
+ })
+ }
+}
+
+func TestIsTransientError(t *testing.T) {
+ tests := []struct {
+ err error
+ name string
+ expectRetry bool
+ }{
+ {
+ name: "nil_error",
+ err: nil,
+ expectRetry: false,
+ },
+ {
+ name: "unavailable_error",
+ err: status.Error(codes.Unavailable, "service
unavailable"),
+ expectRetry: true,
+ },
+ {
+ name: "deadline_exceeded_error",
+ err: status.Error(codes.DeadlineExceeded,
"deadline exceeded"),
+ expectRetry: true,
+ },
+ {
+ name: "resource_exhausted_error",
+ err: status.Error(codes.ResourceExhausted,
"rate limited"),
+ expectRetry: true,
+ },
+ {
+ name: "internal_error",
+ err: status.Error(codes.Internal, "internal"),
+ expectRetry: true,
+ },
+ {
+ name: "not_found_error",
+ err: status.Error(codes.NotFound, "not found"),
+ expectRetry: false,
+ },
+ {
+ name: "permission_denied_error",
+ err: status.Error(codes.PermissionDenied,
"permission denied"),
+ expectRetry: false,
+ },
+ {
+ name: "invalid_argument_error",
+ err: status.Error(codes.InvalidArgument,
"invalid argument"),
+ expectRetry: false,
+ },
+ {
+ name: "ok_error",
+ err: status.Error(codes.OK, "ok"),
+ expectRetry: false,
+ },
+ {
+ name: "canceled_error",
+ err: status.Error(codes.Canceled, "canceled"),
+ expectRetry: false,
+ },
+ {
+ name: "unknown_error",
+ err: status.Error(codes.Unknown, "unknown"),
+ expectRetry: false,
+ },
+ {
+ name: "already_exists_error",
+ err: status.Error(codes.AlreadyExists,
"exists"),
+ expectRetry: false,
+ },
+ {
+ name: "failed_precondition_error",
+ err: status.Error(codes.FailedPrecondition,
"failed"),
+ expectRetry: false,
+ },
+ {
+ name: "aborted_error",
+ err: status.Error(codes.Aborted, "aborted"),
+ expectRetry: false,
+ },
+ {
+ name: "out_of_range_error",
+ err: status.Error(codes.OutOfRange, "range"),
+ expectRetry: false,
+ },
+ {
+ name: "unimplemented_error",
+ err: status.Error(codes.Unimplemented,
"unimpl"),
+ expectRetry: false,
+ },
+ {
+ name: "data_loss_error",
+ err: status.Error(codes.DataLoss, "loss"),
+ expectRetry: false,
+ },
+ {
+ name: "unauthenticated_error",
+ err: status.Error(codes.Unauthenticated,
"unauth"),
+ expectRetry: false,
+ },
+ {
+ name: "generic_error",
+ err: fmt.Errorf("generic error"),
+ expectRetry: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsTransientError(tt.err)
+ assert.Equal(t, tt.expectRetry, result, "transient
error classification mismatch")
+ })
+ }
+}
+
+func TestIsFailoverError(t *testing.T) {
+ tests := []struct {
+ err error
+ name string
+ expectFailover bool
+ }{
+ {name: "nil_error", err: nil, expectFailover: false},
+ {name: "unavailable_error", err:
status.Error(codes.Unavailable, "unavailable"), expectFailover: true},
+ {name: "deadline_exceeded_error", err:
status.Error(codes.DeadlineExceeded, "timeout"), expectFailover: true},
+ {name: "internal_error", err: status.Error(codes.Internal,
"internal"), expectFailover: false},
+ {name: "invalid_argument_error", err:
status.Error(codes.InvalidArgument, "bad"), expectFailover: false},
+ {name: "ok_error", err: status.Error(codes.OK, "ok"),
expectFailover: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsFailoverError(tt.err)
+ assert.Equal(t, tt.expectFailover, result, "failover
error classification mismatch")
+ })
+ }
+}
+
+func TestIsTransientErrorWrapped(t *testing.T) {
+ grpcUnavailable := status.Error(codes.Unavailable, "unavailable")
+ grpcInternal := status.Error(codes.Internal, "internal")
+ grpcInvalidArg := status.Error(codes.InvalidArgument, "bad arg")
+ grpcDeadline := status.Error(codes.DeadlineExceeded, "timeout")
+
+ tests := []struct {
+ err error
+ name string
+ expectRetry bool
+ }{
+ // direct gRPC status (baseline)
+ {name: "direct_unavailable", err: grpcUnavailable, expectRetry:
true},
+ {name: "direct_internal", err: grpcInternal, expectRetry: true},
+ {name: "direct_invalid_argument", err: grpcInvalidArg,
expectRetry: false},
+ {name: "direct_deadline_exceeded", err: grpcDeadline,
expectRetry: true},
+ // fmt.Errorf %w
+ {name: "fmt_errorf_unavailable", err: fmt.Errorf("wrap: %w",
grpcUnavailable), expectRetry: true},
+ {name: "fmt_errorf_internal", err: fmt.Errorf("wrap: %w",
grpcInternal), expectRetry: true},
+ {name: "fmt_errorf_invalid_argument", err: fmt.Errorf("wrap:
%w", grpcInvalidArg), expectRetry: false},
+ {
+ name: "fmt_errorf_double_wrapped",
+ err: fmt.Errorf("outer: %w", fmt.Errorf("inner: %w",
grpcDeadline)), expectRetry: true,
+ },
+ // errors.New with message only — no gRPC status in chain
+ {name: "errors_new_with_grpc_message_text", err:
errors.New(grpcUnavailable.Error()), expectRetry: false},
+ // errors.Join
+ {name: "errors_join_unavailable", err:
errors.Join(errors.New("extra"), grpcUnavailable), expectRetry: true},
+ {name: "errors_join_invalid_argument", err:
errors.Join(errors.New("extra"), grpcInvalidArg), expectRetry: false},
+ {
+ name: "errors_join_multiple_grpc_errors",
+ err: errors.Join(grpcInvalidArg, grpcUnavailable),
expectRetry: false,
+ },
+ // custom Unwrap() error
+ {
+ name: "custom_unwrap_internal",
+ err: &customUnwrapError{msg: "custom", cause:
grpcInternal}, expectRetry: true,
+ },
+ {
+ name: "custom_unwrap_invalid_argument",
+ err: &customUnwrapError{msg: "custom", cause:
grpcInvalidArg}, expectRetry: false,
+ },
+ {
+ name: "custom_unwrap_nested_in_fmt_errorf",
+ err: fmt.Errorf("outer: %w", &customUnwrapError{msg:
"inner", cause: grpcUnavailable}), expectRetry: true,
+ },
+ // custom Unwrap() []error
+ {
+ name: "multi_unwrap_internal",
+ err: &multiUnwrapError{msg: "multi", causes:
[]error{errors.New("other"), grpcInternal}}, expectRetry: true,
+ },
+ {
+ name: "multi_unwrap_all_non_retryable",
+ err: &multiUnwrapError{msg: "multi", causes:
[]error{errors.New("a"), grpcInvalidArg}}, expectRetry: false,
+ },
+ // pkg/errors Wrap and Wrapf
+ {name: "pkgerrors_wrap_unavailable", err:
pkgerrors.Wrap(grpcUnavailable, "wrap"), expectRetry: true},
+ {name: "pkgerrors_wrapf_internal", err:
pkgerrors.Wrapf(grpcInternal, "wrap %s", "ctx"), expectRetry: true},
+ {name: "pkgerrors_wrap_invalid_argument", err:
pkgerrors.Wrap(grpcInvalidArg, "wrap"), expectRetry: false},
+ {
+ name: "pkgerrors_wrap_double_wrapped",
+ err: pkgerrors.Wrap(pkgerrors.Wrap(grpcDeadline,
"inner"), "outer"), expectRetry: true,
+ },
+ // plain errors — no gRPC status
+ {name: "plain_errors_new", err: errors.New("plain error"),
expectRetry: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsTransientError(tt.err)
+ assert.Equal(t, tt.expectRetry, result, "transient
error classification mismatch")
+ })
+ }
+}
+
+func TestIsFailoverErrorWrapped(t *testing.T) {
+ grpcUnavailable := status.Error(codes.Unavailable, "unavailable")
+ grpcDeadline := status.Error(codes.DeadlineExceeded, "timeout")
+ grpcInternal := status.Error(codes.Internal, "internal")
+
+ tests := []struct {
+ err error
+ name string
+ expectFailover bool
+ }{
+ // direct gRPC status (baseline)
+ {name: "direct_unavailable", err: grpcUnavailable,
expectFailover: true},
+ {name: "direct_deadline_exceeded", err: grpcDeadline,
expectFailover: true},
+ {name: "direct_internal", err: grpcInternal, expectFailover:
false},
+ // fmt.Errorf %w
+ {name: "fmt_errorf_unavailable", err: fmt.Errorf("wrap: %w",
grpcUnavailable), expectFailover: true},
+ {name: "fmt_errorf_deadline_exceeded", err: fmt.Errorf("wrap:
%w", grpcDeadline), expectFailover: true},
+ {name: "fmt_errorf_internal", err: fmt.Errorf("wrap: %w",
grpcInternal), expectFailover: false},
+ {
+ name: "fmt_errorf_double_wrapped",
+ err: fmt.Errorf("outer: %w", fmt.Errorf("inner: %w",
grpcUnavailable)), expectFailover: true,
+ },
+ // errors.New with message only — no gRPC status in chain
+ {name: "errors_new_with_grpc_message_text", err:
errors.New(grpcUnavailable.Error()), expectFailover: false},
+ // errors.Join
+ {name: "errors_join_unavailable", err:
errors.Join(errors.New("extra"), grpcUnavailable), expectFailover: true},
+ {name: "errors_join_internal", err:
errors.Join(errors.New("extra"), grpcInternal), expectFailover: false},
+ // custom Unwrap() error
+ {
+ name: "custom_unwrap_deadline_exceeded",
+ err: &customUnwrapError{msg: "custom", cause:
grpcDeadline}, expectFailover: true,
+ },
+ {
+ name: "custom_unwrap_internal",
+ err: &customUnwrapError{msg: "custom", cause:
grpcInternal}, expectFailover: false,
+ },
+ // custom Unwrap() []error
+ {
+ name: "multi_unwrap_unavailable",
+ err: &multiUnwrapError{msg: "multi", causes:
[]error{errors.New("other"), grpcUnavailable}}, expectFailover: true,
+ },
+ // pkg/errors Wrap and Wrapf
+ {name: "pkgerrors_wrap_unavailable", err:
pkgerrors.Wrap(grpcUnavailable, "wrap"), expectFailover: true},
+ {name: "pkgerrors_wrapf_deadline_exceeded", err:
pkgerrors.Wrapf(grpcDeadline, "wrap %s", "ctx"), expectFailover: true},
+ {name: "pkgerrors_wrap_internal", err:
pkgerrors.Wrap(grpcInternal, "wrap"), expectFailover: false},
+ // plain errors — no gRPC status
+ {name: "plain_errors_new", err: errors.New("plain error"),
expectFailover: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsFailoverError(tt.err)
+ assert.Equal(t, tt.expectFailover, result, "failover
error classification mismatch")
+ })
+ }
+}
+
+func TestIsInternalErrorWrapped(t *testing.T) {
+ grpcInternal := status.Error(codes.Internal, "internal")
+ grpcUnavailable := status.Error(codes.Unavailable, "unavailable")
+
+ tests := []struct {
+ err error
+ name string
+ expectInternal bool
+ }{
+ // direct gRPC status (baseline)
+ {name: "direct_internal", err: grpcInternal, expectInternal:
true},
+ {name: "direct_unavailable", err: grpcUnavailable,
expectInternal: false},
+ // fmt.Errorf %w
+ {name: "fmt_errorf_internal", err: fmt.Errorf("wrap: %w",
grpcInternal), expectInternal: true},
+ {name: "fmt_errorf_unavailable", err: fmt.Errorf("wrap: %w",
grpcUnavailable), expectInternal: false},
+ {
+ name: "fmt_errorf_double_wrapped",
+ err: fmt.Errorf("outer: %w", fmt.Errorf("inner: %w",
grpcInternal)), expectInternal: true,
+ },
+ // errors.New with message only — no gRPC status in chain
+ {name: "errors_new_with_grpc_message_text", err:
errors.New(grpcInternal.Error()), expectInternal: false},
+ // errors.Join
+ {name: "errors_join_internal", err:
errors.Join(errors.New("extra"), grpcInternal), expectInternal: true},
+ {name: "errors_join_unavailable", err:
errors.Join(errors.New("extra"), grpcUnavailable), expectInternal: false},
+ // custom Unwrap() error
+ {
+ name: "custom_unwrap_internal",
+ err: &customUnwrapError{msg: "custom", cause:
grpcInternal}, expectInternal: true,
+ },
+ {
+ name: "custom_unwrap_unavailable",
+ err: &customUnwrapError{msg: "custom", cause:
grpcUnavailable}, expectInternal: false,
+ },
+ // custom Unwrap() []error
+ {
+ name: "multi_unwrap_internal",
+ err: &multiUnwrapError{msg: "multi", causes:
[]error{errors.New("other"), grpcInternal}}, expectInternal: true,
+ },
+ {
+ name: "multi_unwrap_all_non_internal",
+ err: &multiUnwrapError{msg: "multi", causes:
[]error{errors.New("a"), grpcUnavailable}}, expectInternal: false,
+ },
+ // pkg/errors Wrap and Wrapf
+ {name: "pkgerrors_wrap_internal", err:
pkgerrors.Wrap(grpcInternal, "wrap"), expectInternal: true},
+ {name: "pkgerrors_wrapf_internal", err:
pkgerrors.Wrapf(grpcInternal, "wrap %s", "ctx"), expectInternal: true},
+ {name: "pkgerrors_wrap_unavailable", err:
pkgerrors.Wrap(grpcUnavailable, "wrap"), expectInternal: false},
+ // plain errors — no gRPC status
+ {name: "plain_errors_new", err: errors.New("plain error"),
expectInternal: false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := IsInternalError(tt.err)
+ assert.Equal(t, tt.expectInternal, result, "internal
error classification mismatch")
+ })
+ }
+}
+
+func TestBackoffDistribution(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping distribution test in short mode")
+ }
+
+ const samples = 100
+ baseBackoff := 100 * time.Millisecond
+ maxBackoff := 10 * time.Second
+ jitter := 0.2
+
+ var totalBackoff time.Duration
+ for i := 0; i < samples; i++ {
+ b := JitteredBackoff(baseBackoff, maxBackoff, 0, jitter)
+ totalBackoff += b
+ assert.GreaterOrEqual(t, b, 80*time.Millisecond, "backoff
should be >= 80ms with 0.2 jitter")
+ assert.LessOrEqual(t, b, 120*time.Millisecond, "backoff should
be <= 120ms with 0.2 jitter")
+ }
+
+ avgBackoff := totalBackoff / time.Duration(samples)
+ assert.InDelta(t, float64(100*time.Millisecond), float64(avgBackoff),
float64(20*time.Millisecond),
+ "average backoff should be near base duration")
+}
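For reference, the windows asserted in TestJitteredBackoff and TestBackoffDistribution follow directly from the documented formula: the exponential term is baseBackoff * 2^attempt capped at maxBackoff, and symmetric jitter then scales it by a factor in [1 - jitterFactor, 1 + jitterFactor], with the result capped at maxBackoff again (as the max_backoff_reached case expects). With a 100 ms base and 0.2 jitter, attempt 0 gives 100 ms * [0.8, 1.2] = [80 ms, 120 ms], which is why the distribution test bounds every sample to that window and expects a mean near 100 ms; attempt 1 gives 200 ms * [0.8, 1.2] = [160 ms, 240 ms]; and attempt 10 saturates at the 300 ms cap, leaving [240 ms, 300 ms] once the final cap is applied.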