Copilot commented on code in PR #920:
URL:
https://github.com/apache/skywalking-banyandb/pull/920#discussion_r2660519022
##########
banyand/metadata/client.go:
##########
@@ -96,13 +115,54 @@ func (s *clientService) FlagSet() *run.FlagSet {
fs.StringVar(&s.etcdTLSKeyFile, flagEtcdTLSKeyFile, "", "Private key
for the etcd client certificate.")
fs.DurationVar(&s.registryTimeout, "node-registry-timeout",
2*time.Minute, "The timeout for the node registry")
fs.DurationVar(&s.etcdFullSyncInterval, "etcd-full-sync-interval",
30*time.Minute, "The interval for full sync etcd")
+
+ // node discovery configuration
+ fs.StringVar(&s.nodeDiscoveryMode, "node-discovery-mode",
NodeDiscoveryModeEtcd,
+ "Node discovery mode: 'etcd' for etcd-based discovery, 'dns'
for DNS-based discovery")
+ fs.StringSliceVar(&s.dnsSRVAddresses,
"node-discovery-dns-srv-addresses", []string{},
+ "DNS SRV addresses for node discovery (e.g.,
_grpc._tcp.banyandb.svc.cluster.local)")
+ fs.DurationVar(&s.dnsFetchInitInterval,
"node-discovery-dns-fetch-init-interval", 5*time.Second,
+ "DNS query interval during initialization phase")
+ fs.DurationVar(&s.dnsFetchInitDuration,
"node-discovery-dns-fetch-init-duration", 5*time.Minute,
+ "Duration of the initialization phase for DNS discovery")
+ fs.DurationVar(&s.dnsFetchInterval,
"node-discovery-dns-fetch-interval", 15*time.Second,
+ "DNS query interval after initialization phase")
+ fs.DurationVar(&s.grpcTimeout, "node-discovery-grpc-timeout",
5*time.Second,
+ "Timeout for gRPC calls to fetch node metadata")
+ fs.BoolVar(&s.dnsTLSEnabled, "node-discovery-dns-tls", false,
+ "Enable TLS for DNS discovery gRPC connections")
+ fs.StringSliceVar(&s.dnsCACertPaths, "node-discovery-dns-ca-certs",
[]string{},
+ "Comma-separated list of CA certificate files to verify DNS
discovered nodes (one per SRV address, in same order)")
+
return fs
}
func (s *clientService) Validate() error {
- if s.endpoints == nil {
- return errors.New("endpoints is empty")
+ if s.nodeDiscoveryMode != NodeDiscoveryModeEtcd && s.nodeDiscoveryMode
!= NodeDiscoveryModeDNS {
+ return fmt.Errorf("invalid node-discovery-mode: %s, must be
'%s' or '%s'", s.nodeDiscoveryMode, NodeDiscoveryModeEtcd, NodeDiscoveryModeDNS)
}
+
+ // Validate etcd endpoints (required for both modes for schema storage)
Review Comment:
The validation only checks that etcd endpoints are not empty when mode is
etcd, but the comment on line 145 states "required for both modes for schema
storage". This suggests that even in DNS mode, etcd endpoints might be needed
for schema storage. The validation should either require endpoints for both
modes, or the comment should be corrected if DNS mode truly doesn't require
etcd.
```suggestion
// Validate etcd endpoints (required when using etcd-based node
discovery for schema storage)
```
##########
banyand/queue/sub/server.go:
##########
@@ -139,6 +144,27 @@ func (s *server) PreRun(_ context.Context) error {
s.log.Info().Str("certFile", s.certFile).Str("keyFile",
s.keyFile).Msg("Initialized TLS reloader for queue server")
}
+ nodeVal := ctx.Value(common.ContextNodeKey)
+ roleVal := ctx.Value(common.ContextNodeRolesKey)
+ if nodeVal == nil || roleVal == nil {
+ s.log.Warn().Msg("node or role value not found in context")
+ return nil
+ }
+ nodeRoles := roleVal.([]databasev1.Role)
+ node := nodeVal.(common.Node)
Review Comment:
The type assertions on lines 153-154 panic if the context values are present but have unexpected types. The nil check on line 149 only covers missing values. Use the comma-ok form (e.g. `nodeRoles, ok := roleVal.([]databasev1.Role)`) and log or return gracefully on a type mismatch instead of panicking.
##########
banyand/metadata/dns/dns.go:
##########
@@ -0,0 +1,603 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package dns implements DNS-based node discovery for distributed metadata
management.
+package dns
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+
+ databasev1
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+ "github.com/apache/skywalking-banyandb/banyand/metadata/schema"
+ "github.com/apache/skywalking-banyandb/banyand/observability"
+ "github.com/apache/skywalking-banyandb/pkg/grpchelper"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
+ "github.com/apache/skywalking-banyandb/pkg/run"
+ pkgtls "github.com/apache/skywalking-banyandb/pkg/tls"
+)
+
+// Service implements DNS-based node discovery.
+type Service struct {
+ lastQueryTime time.Time
+ resolver Resolver
+ pathToReloader map[string]*pkgtls.Reloader
+ srvIndexToPath map[int]string
+ resolvedAddrToSRVIdx map[string]int
+ nodeCache map[string]*databasev1.Node
+ closer *run.Closer
+ log *logger.Logger
+ metrics *metrics
+ handlers map[string]schema.EventHandler
+ caCertPaths []string
+ srvAddresses []string
+ lastSuccessfulDNS []string
+ pollInterval time.Duration
+ initInterval time.Duration
+ initDuration time.Duration
+ grpcTimeout time.Duration
+ cacheMutex sync.RWMutex
+ handlersMutex sync.RWMutex
+ lastQueryMutex sync.RWMutex
+ resolvedAddrMutex sync.RWMutex
+ tlsEnabled bool
+}
+
+// Config holds configuration for DNS discovery service.
+type Config struct {
+ CACertPaths []string
+ SRVAddresses []string
+ InitInterval time.Duration
+ InitDuration time.Duration
+ PollInterval time.Duration
+ GRPCTimeout time.Duration
+ TLSEnabled bool
+}
+
+// Resolver defines the interface for DNS SRV lookups.
+type Resolver interface {
+ LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error)
+}
+
+// defaultResolver wraps net.DefaultResolver to implement Resolver.
+type defaultResolver struct{}
+
+func (d *defaultResolver) LookupSRV(ctx context.Context, name string) (string,
[]*net.SRV, error) {
+ return net.DefaultResolver.LookupSRV(ctx, "", "", name)
+}
+
+// NewService creates a new DNS discovery service.
+func NewService(cfg Config) (*Service, error) {
+ // validation
+ if len(cfg.SRVAddresses) == 0 {
+ return nil, errors.New("DNS SRV addresses cannot be empty")
+ }
+
+ // validate CA cert paths match SRV addresses when TLS is enabled
+ if cfg.TLSEnabled {
+ if len(cfg.CACertPaths) != len(cfg.SRVAddresses) {
+ return nil, fmt.Errorf("number of CA cert paths (%d)
must match number of SRV addresses (%d)",
+ len(cfg.CACertPaths), len(cfg.SRVAddresses))
+ }
+ }
+
+ svc := &Service{
+ srvAddresses: cfg.SRVAddresses,
+ initInterval: cfg.InitInterval,
+ initDuration: cfg.InitDuration,
+ pollInterval: cfg.PollInterval,
+ grpcTimeout: cfg.GRPCTimeout,
+ tlsEnabled: cfg.TLSEnabled,
+ caCertPaths: cfg.CACertPaths,
+ nodeCache: make(map[string]*databasev1.Node),
+ handlers: make(map[string]schema.EventHandler),
+ lastSuccessfulDNS: []string{},
+ pathToReloader: make(map[string]*pkgtls.Reloader),
+ srvIndexToPath: make(map[int]string),
+ resolvedAddrToSRVIdx: make(map[string]int),
+ closer: run.NewCloser(1),
+ log:
logger.GetLogger("metadata-discovery-dns"),
+ resolver: &defaultResolver{},
+ }
+
+ // create shared reloaders for CA certificates
+ if svc.tlsEnabled {
+ for srvIdx, certPath := range cfg.CACertPaths {
+ // Store the SRV index → cert path mapping
+ svc.srvIndexToPath[srvIdx] = certPath
+
+ // check if we already have a Reloader for this path
+ if _, exists := svc.pathToReloader[certPath]; exists {
+ svc.log.Debug().Str("certPath",
certPath).Int("srvIndex", srvIdx).
+ Msg("Reusing existing CA certificate
reloader")
+ continue
+ }
+
+ // create new Reloader for this unique path
+ reloader, reloaderErr :=
pkgtls.NewClientCertReloader(certPath, svc.log)
+ if reloaderErr != nil {
+ // clean up any already-created reloaders
+ for _, r := range svc.pathToReloader {
+ r.Stop()
+ }
+ return nil, fmt.Errorf("failed to initialize CA
certificate reloader for path %s (SRV index %d): %w",
+ certPath, srvIdx, reloaderErr)
+ }
+
+ svc.pathToReloader[certPath] = reloader
+ svc.log.Info().Str("certPath",
certPath).Int("srvIndex", srvIdx).
+ Str("srvAddress",
cfg.SRVAddresses[srvIdx]).Msg("Initialized DNS CA certificate reloader")
+ }
+ }
+
+ return svc, nil
+}
+
+// newServiceWithResolver creates a service with a custom resolver (for
testing).
+func newServiceWithResolver(cfg Config, resolver Resolver) (*Service, error) {
+ svc, err := NewService(cfg)
+ if err != nil {
+ return nil, err
+ }
+ svc.resolver = resolver
+ return svc, nil
+}
+
+func (s *Service) getTLSDialOptions(address string) ([]grpc.DialOption, error)
{
+ if !s.tlsEnabled {
+ return
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil
+ }
+
+ // look up which Reloader to use for this address
+ if len(s.pathToReloader) > 0 {
+ // Find which SRV address this resolved address came from
+ s.resolvedAddrMutex.RLock()
+ srvIdx, addrExists := s.resolvedAddrToSRVIdx[address]
+ s.resolvedAddrMutex.RUnlock()
+
+ if !addrExists {
+ return nil, fmt.Errorf("no SRV mapping found for
address %s", address)
+ }
+
+ // look up the cert path for this SRV index
+ certPath, pathExists := s.srvIndexToPath[srvIdx]
+ if !pathExists {
+ return nil, fmt.Errorf("no cert path found for SRV
index %d (address %s)", srvIdx, address)
+ }
+
+ // get the Reloader for this cert path
+ reloader, reloaderExists := s.pathToReloader[certPath]
+ if !reloaderExists {
+ return nil, fmt.Errorf("no reloader found for cert path
%s (address %s)", certPath, address)
+ }
+
+ // get fresh TLS config from the Reloader
+ tlsConfig, configErr := reloader.GetClientTLSConfig("")
+ if configErr != nil {
+ return nil, fmt.Errorf("failed to get TLS config from
reloader for address %s: %w", address, configErr)
+ }
+
+ creds := credentials.NewTLS(tlsConfig)
+ return []grpc.DialOption{grpc.WithTransportCredentials(creds)},
nil
+ }
+
+ // fallback to static TLS config (when no reloaders configured)
+ opts, err := grpchelper.SecureOptions(nil, s.tlsEnabled, false, "")
+ if err != nil {
+ return nil, fmt.Errorf("failed to load TLS config: %w", err)
+ }
+ return opts, nil
+}
+
+// Start begins the DNS discovery background process.
+func (s *Service) Start(ctx context.Context) error {
+ s.log.Debug().Msg("Starting DNS-based node discovery service")
+
+ // start all Reloaders
+ if len(s.pathToReloader) > 0 {
+ startedReloaders := make([]*pkgtls.Reloader, 0,
len(s.pathToReloader))
+
+ for certPath, reloader := range s.pathToReloader {
+ if startErr := reloader.Start(); startErr != nil {
+ // stop any already-started reloaders
+ for _, r := range startedReloaders {
+ r.Stop()
+ }
+ return fmt.Errorf("failed to start CA
certificate reloader for path %s: %w", certPath, startErr)
+ }
+ startedReloaders = append(startedReloaders, reloader)
+ s.log.Debug().Str("certPath", certPath).Msg("Started CA
certificate reloader")
+ }
+ }
+
+ go s.discoveryLoop(ctx)
+
+ return nil
+}
+
+func (s *Service) discoveryLoop(ctx context.Context) {
+ // add the init phase finish time
+ initPhaseEnd := time.Now().Add(s.initDuration)
+
+ for {
+ if err := s.queryDNSAndUpdateNodes(ctx); err != nil {
+ s.log.Err(err).Msg("failed to query DNS and update
nodes")
+ }
+
+ // wait for next interval
+ var interval time.Duration
+ if time.Now().Before(initPhaseEnd) {
+ interval = s.initInterval
+ } else {
+ interval = s.pollInterval
+ }
+
+ timer := time.NewTimer(interval)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-s.closer.CloseNotify():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // continue to next iteration
+ }
+ }
+}
+
+func (s *Service) queryDNSAndUpdateNodes(ctx context.Context) error {
+ // Record summary metrics
+ startTime := time.Now()
+ defer func() {
+ if s.metrics != nil {
+ duration := time.Since(startTime)
+ s.metrics.discoveryCount.Inc(1)
+ s.metrics.discoveryDuration.Observe(duration.Seconds())
+ s.metrics.discoveryTotalDuration.Inc(duration.Seconds())
+ }
+ }()
+
+ addresses, queryErr := s.queryAllSRVRecords(ctx)
+
+ if queryErr != nil {
+ s.log.Warn().Err(queryErr).Msg("DNS query failed, using last
successful cache")
+ addresses = s.lastSuccessfulDNS
+ if len(addresses) == 0 {
+ if s.metrics != nil {
+ s.metrics.discoveryFailedCount.Inc(1)
+ }
+ return fmt.Errorf("DNS query failed and no cached
addresses available: %w", queryErr)
+ }
+ } else {
+ s.lastSuccessfulDNS = addresses
+ if s.log.Debug().Enabled() {
+ s.log.Debug().
+ Int("count", len(addresses)).
+ Strs("addresses", addresses).
+ Strs("srv_addresses", s.srvAddresses).
+ Msg("DNS query successful")
+ }
+ }
+
+ // Update node cache based on DNS results
+ updateErr := s.updateNodeCache(ctx, addresses)
+ if updateErr != nil && s.metrics != nil {
+ s.metrics.discoveryFailedCount.Inc(1)
+ }
+ s.lastQueryMutex.Lock()
+ s.lastQueryTime = time.Now()
+ s.lastQueryMutex.Unlock()
+ return updateErr
+}
+
+func (s *Service) queryAllSRVRecords(ctx context.Context) ([]string, error) {
+ startTime := time.Now()
+ defer func() {
+ if s.metrics != nil {
+ duration := time.Since(startTime)
+ s.metrics.dnsQueryCount.Inc(1)
+ s.metrics.dnsQueryDuration.Observe(duration.Seconds())
+ s.metrics.dnsQueryTotalDuration.Inc(duration.Seconds())
+ }
+ }()
+
+ allAddresses := make(map[string]bool)
+ // track which SRV address (by index) each resolved address came from
+ newAddrToSRVIdx := make(map[string]int)
+ var queryErrors []error
+
+ for srvIdx, srvAddr := range s.srvAddresses {
+ _, addrs, lookupErr := s.resolver.LookupSRV(ctx, srvAddr)
+ if lookupErr != nil {
+ queryErrors = append(queryErrors, fmt.Errorf("lookup %s
failed: %w", srvAddr, lookupErr))
+ continue
+ }
+
+ for _, srv := range addrs {
+ address := fmt.Sprintf("%s:%d", srv.Target, srv.Port)
+ allAddresses[address] = true
+
+ // track which SRV address this resolved to (first-wins
strategy)
+ if _, exists := newAddrToSRVIdx[address]; !exists {
+ newAddrToSRVIdx[address] = srvIdx
+ }
+ }
+ }
+
+ // if there have any error occurred,
+ // then just return the query error to ignore the result to make sure
the cache correct
+ if len(queryErrors) > 0 {
+ if s.metrics != nil {
+ s.metrics.dnsQueryFailedCount.Inc(1)
+ }
+ return nil, errors.Join(queryErrors...)
+ }
+
+ // update the resolved address to SRV index mapping
+ s.resolvedAddrMutex.Lock()
+ s.resolvedAddrToSRVIdx = newAddrToSRVIdx
+ s.resolvedAddrMutex.Unlock()
+
+ // convert map to slice
+ result := make([]string, 0, len(allAddresses))
+ for addr := range allAddresses {
+ result = append(result, addr)
+ }
+
+ return result, nil
+}
+
+func (s *Service) updateNodeCache(ctx context.Context, addresses []string)
error {
+ addressSet := make(map[string]bool)
+ for _, addr := range addresses {
+ addressSet[addr] = true
+ }
+
+ var addErrors []error
+
+ for addr := range addressSet {
+ s.cacheMutex.RLock()
+ _, exists := s.nodeCache[addr]
+ s.cacheMutex.RUnlock()
+
+ if !exists {
+ // fetch node metadata from gRPC
+ node, fetchErr := s.fetchNodeMetadata(ctx, addr)
+ if fetchErr != nil {
+ s.log.Warn().
+ Err(fetchErr).
+ Str("address", addr).
+ Msg("Failed to fetch node metadata")
+ addErrors = append(addErrors, fetchErr)
+ continue
+ }
+
+ // update cache and notify handlers
+ s.cacheMutex.Lock()
+ s.nodeCache[addr] = node
+ s.cacheMutex.Unlock()
+
+ s.notifyHandlers(schema.Metadata{
+ TypeMeta: schema.TypeMeta{
+ Kind: schema.KindNode,
+ Name: node.GetMetadata().GetName(),
+ },
+ Spec: node,
+ }, true)
+
+ s.log.Debug().
+ Str("address", addr).
+ Str("name", node.GetMetadata().GetName()).
+ Msg("New node discovered and added to cache")
+ }
+ }
+
+ // collect nodes to delete first
+ s.cacheMutex.Lock()
+ nodesToDelete := make(map[string]*databasev1.Node)
+ for addr, node := range s.nodeCache {
+ if !addressSet[addr] {
+ nodesToDelete[addr] = node
+ }
+ }
+
+ // delete from cache while still holding lock
+ for addr, node := range nodesToDelete {
+ delete(s.nodeCache, addr)
+ s.log.Debug().
+ Str("address", addr).
+ Str("name", node.GetMetadata().GetName()).
+ Msg("Node removed from cache (no longer in DNS)")
+ }
+ cacheSize := len(s.nodeCache)
+ s.cacheMutex.Unlock()
+
+ // Notify handlers after releasing lock
+ for _, node := range nodesToDelete {
+ s.notifyHandlers(schema.Metadata{
+ TypeMeta: schema.TypeMeta{
+ Kind: schema.KindNode,
+ Name: node.GetMetadata().GetName(),
+ },
+ Spec: node,
+ }, false)
+ }
+
+ // update total nodes metric
+ if s.metrics != nil {
+ s.metrics.totalNodesCount.Set(float64(cacheSize))
+ }
+
+ if len(addErrors) > 0 {
+ return errors.Join(addErrors...)
+ }
+
+ return nil
+}
+
+func (s *Service) fetchNodeMetadata(ctx context.Context, address string)
(*databasev1.Node, error) {
+ // record gRPC query metrics
+ startTime := time.Now()
+ var grpcErr error
+ defer func() {
+ if s.metrics != nil {
+ duration := time.Since(startTime)
+ s.metrics.grpcQueryCount.Inc(1)
+ s.metrics.grpcQueryDuration.Observe(duration.Seconds())
+ s.metrics.grpcQueryTotalDuration.Inc(duration.Seconds())
+ if grpcErr != nil {
+ s.metrics.grpcQueryFailedCount.Inc(1)
+ }
+ }
+ }()
+
+ ctxTimeout, cancel := context.WithTimeout(ctx, s.grpcTimeout)
+ defer cancel()
+
+ // for TLS connections with other nodes to getting metadata
+ dialOpts, err := s.getTLSDialOptions(address)
+ if err != nil {
+ grpcErr = fmt.Errorf("failed to get TLS dial options: %w", err)
+ return nil, grpcErr
+ }
+ // nolint:contextcheck
+ conn, connErr := grpchelper.ConnWithAuth(address, s.grpcTimeout, "",
"", dialOpts...)
Review Comment:
The `nolint:contextcheck` directive suppresses the context linter, even though a timeout context (`ctxTimeout`) has already been created on line 477. `ConnWithAuth` receives only its own timeout parameter and never sees `ctxTimeout`, so cancellation of the parent context is not observed while dialing. Either pass `ctxTimeout` to the connection, or, if `ConnWithAuth` requires a separate timeout, cap that timeout by the parent context's remaining deadline. Otherwise a connection attempt could outlive its intended context lifetime.
```suggestion
// ensure the connection timeout does not exceed the remaining context
deadline
connTimeout := s.grpcTimeout
if deadline, ok := ctxTimeout.Deadline(); ok {
remaining := time.Until(deadline)
if remaining <= 0 {
grpcErr = ctxTimeout.Err()
return nil, grpcErr
}
if remaining < connTimeout {
connTimeout = remaining
}
}
conn, connErr := grpchelper.ConnWithAuth(address, connTimeout, "", "",
dialOpts...)
```
##########
banyand/metadata/dns/dns.go:
##########
@@ -0,0 +1,603 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package dns implements DNS-based node discovery for distributed metadata
management.
+package dns
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+
+ databasev1
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+ "github.com/apache/skywalking-banyandb/banyand/metadata/schema"
+ "github.com/apache/skywalking-banyandb/banyand/observability"
+ "github.com/apache/skywalking-banyandb/pkg/grpchelper"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
+ "github.com/apache/skywalking-banyandb/pkg/run"
+ pkgtls "github.com/apache/skywalking-banyandb/pkg/tls"
+)
+
+// Service implements DNS-based node discovery.
+type Service struct {
+ lastQueryTime time.Time
+ resolver Resolver
+ pathToReloader map[string]*pkgtls.Reloader
+ srvIndexToPath map[int]string
+ resolvedAddrToSRVIdx map[string]int
+ nodeCache map[string]*databasev1.Node
+ closer *run.Closer
+ log *logger.Logger
+ metrics *metrics
+ handlers map[string]schema.EventHandler
+ caCertPaths []string
+ srvAddresses []string
+ lastSuccessfulDNS []string
+ pollInterval time.Duration
+ initInterval time.Duration
+ initDuration time.Duration
+ grpcTimeout time.Duration
+ cacheMutex sync.RWMutex
+ handlersMutex sync.RWMutex
+ lastQueryMutex sync.RWMutex
+ resolvedAddrMutex sync.RWMutex
+ tlsEnabled bool
+}
+
+// Config holds configuration for DNS discovery service.
+type Config struct {
+ CACertPaths []string
+ SRVAddresses []string
+ InitInterval time.Duration
+ InitDuration time.Duration
+ PollInterval time.Duration
+ GRPCTimeout time.Duration
+ TLSEnabled bool
+}
+
+// Resolver defines the interface for DNS SRV lookups.
+type Resolver interface {
+ LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error)
+}
+
+// defaultResolver wraps net.DefaultResolver to implement Resolver.
+type defaultResolver struct{}
+
+func (d *defaultResolver) LookupSRV(ctx context.Context, name string) (string,
[]*net.SRV, error) {
+ return net.DefaultResolver.LookupSRV(ctx, "", "", name)
+}
+
+// NewService creates a new DNS discovery service.
+func NewService(cfg Config) (*Service, error) {
+ // validation
+ if len(cfg.SRVAddresses) == 0 {
+ return nil, errors.New("DNS SRV addresses cannot be empty")
+ }
+
+ // validate CA cert paths match SRV addresses when TLS is enabled
+ if cfg.TLSEnabled {
+ if len(cfg.CACertPaths) != len(cfg.SRVAddresses) {
+ return nil, fmt.Errorf("number of CA cert paths (%d)
must match number of SRV addresses (%d)",
+ len(cfg.CACertPaths), len(cfg.SRVAddresses))
+ }
+ }
+
+ svc := &Service{
+ srvAddresses: cfg.SRVAddresses,
+ initInterval: cfg.InitInterval,
+ initDuration: cfg.InitDuration,
+ pollInterval: cfg.PollInterval,
+ grpcTimeout: cfg.GRPCTimeout,
+ tlsEnabled: cfg.TLSEnabled,
+ caCertPaths: cfg.CACertPaths,
+ nodeCache: make(map[string]*databasev1.Node),
+ handlers: make(map[string]schema.EventHandler),
+ lastSuccessfulDNS: []string{},
+ pathToReloader: make(map[string]*pkgtls.Reloader),
+ srvIndexToPath: make(map[int]string),
+ resolvedAddrToSRVIdx: make(map[string]int),
+ closer: run.NewCloser(1),
+ log:
logger.GetLogger("metadata-discovery-dns"),
+ resolver: &defaultResolver{},
+ }
+
+ // create shared reloaders for CA certificates
+ if svc.tlsEnabled {
+ for srvIdx, certPath := range cfg.CACertPaths {
+ // Store the SRV index → cert path mapping
+ svc.srvIndexToPath[srvIdx] = certPath
+
+ // check if we already have a Reloader for this path
+ if _, exists := svc.pathToReloader[certPath]; exists {
+ svc.log.Debug().Str("certPath",
certPath).Int("srvIndex", srvIdx).
+ Msg("Reusing existing CA certificate
reloader")
+ continue
+ }
+
+ // create new Reloader for this unique path
+ reloader, reloaderErr :=
pkgtls.NewClientCertReloader(certPath, svc.log)
+ if reloaderErr != nil {
+ // clean up any already-created reloaders
+ for _, r := range svc.pathToReloader {
+ r.Stop()
+ }
+ return nil, fmt.Errorf("failed to initialize CA
certificate reloader for path %s (SRV index %d): %w",
+ certPath, srvIdx, reloaderErr)
+ }
+
+ svc.pathToReloader[certPath] = reloader
+ svc.log.Info().Str("certPath",
certPath).Int("srvIndex", srvIdx).
+ Str("srvAddress",
cfg.SRVAddresses[srvIdx]).Msg("Initialized DNS CA certificate reloader")
+ }
+ }
+
+ return svc, nil
+}
+
+// newServiceWithResolver creates a service with a custom resolver (for
testing).
+func newServiceWithResolver(cfg Config, resolver Resolver) (*Service, error) {
+ svc, err := NewService(cfg)
+ if err != nil {
+ return nil, err
+ }
+ svc.resolver = resolver
+ return svc, nil
+}
+
+func (s *Service) getTLSDialOptions(address string) ([]grpc.DialOption, error)
{
+ if !s.tlsEnabled {
+ return
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil
+ }
+
+ // look up which Reloader to use for this address
+ if len(s.pathToReloader) > 0 {
+ // Find which SRV address this resolved address came from
+ s.resolvedAddrMutex.RLock()
+ srvIdx, addrExists := s.resolvedAddrToSRVIdx[address]
+ s.resolvedAddrMutex.RUnlock()
+
+ if !addrExists {
+ return nil, fmt.Errorf("no SRV mapping found for
address %s", address)
+ }
+
+ // look up the cert path for this SRV index
+ certPath, pathExists := s.srvIndexToPath[srvIdx]
+ if !pathExists {
+ return nil, fmt.Errorf("no cert path found for SRV
index %d (address %s)", srvIdx, address)
+ }
+
+ // get the Reloader for this cert path
+ reloader, reloaderExists := s.pathToReloader[certPath]
+ if !reloaderExists {
+ return nil, fmt.Errorf("no reloader found for cert path
%s (address %s)", certPath, address)
+ }
+
+ // get fresh TLS config from the Reloader
+ tlsConfig, configErr := reloader.GetClientTLSConfig("")
+ if configErr != nil {
+ return nil, fmt.Errorf("failed to get TLS config from
reloader for address %s: %w", address, configErr)
+ }
+
+ creds := credentials.NewTLS(tlsConfig)
+ return []grpc.DialOption{grpc.WithTransportCredentials(creds)},
nil
+ }
+
+ // fallback to static TLS config (when no reloaders configured)
+ opts, err := grpchelper.SecureOptions(nil, s.tlsEnabled, false, "")
+ if err != nil {
+ return nil, fmt.Errorf("failed to load TLS config: %w", err)
+ }
+ return opts, nil
+}
+
+// Start begins the DNS discovery background process.
+func (s *Service) Start(ctx context.Context) error {
+ s.log.Debug().Msg("Starting DNS-based node discovery service")
+
+ // start all Reloaders
+ if len(s.pathToReloader) > 0 {
+ startedReloaders := make([]*pkgtls.Reloader, 0,
len(s.pathToReloader))
+
+ for certPath, reloader := range s.pathToReloader {
+ if startErr := reloader.Start(); startErr != nil {
+ // stop any already-started reloaders
+ for _, r := range startedReloaders {
+ r.Stop()
+ }
+ return fmt.Errorf("failed to start CA
certificate reloader for path %s: %w", certPath, startErr)
+ }
+ startedReloaders = append(startedReloaders, reloader)
+ s.log.Debug().Str("certPath", certPath).Msg("Started CA
certificate reloader")
+ }
+ }
+
+ go s.discoveryLoop(ctx)
+
+ return nil
+}
+
+func (s *Service) discoveryLoop(ctx context.Context) {
+ // add the init phase finish time
+ initPhaseEnd := time.Now().Add(s.initDuration)
+
+ for {
+ if err := s.queryDNSAndUpdateNodes(ctx); err != nil {
+ s.log.Err(err).Msg("failed to query DNS and update
nodes")
+ }
+
+ // wait for next interval
+ var interval time.Duration
+ if time.Now().Before(initPhaseEnd) {
+ interval = s.initInterval
+ } else {
+ interval = s.pollInterval
+ }
+
+ timer := time.NewTimer(interval)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-s.closer.CloseNotify():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // continue to next iteration
+ }
+ }
+}
+
+func (s *Service) queryDNSAndUpdateNodes(ctx context.Context) error {
+ // Record summary metrics
+ startTime := time.Now()
+ defer func() {
+ if s.metrics != nil {
+ duration := time.Since(startTime)
+ s.metrics.discoveryCount.Inc(1)
+ s.metrics.discoveryDuration.Observe(duration.Seconds())
+ s.metrics.discoveryTotalDuration.Inc(duration.Seconds())
+ }
+ }()
+
+ addresses, queryErr := s.queryAllSRVRecords(ctx)
+
+ if queryErr != nil {
+ s.log.Warn().Err(queryErr).Msg("DNS query failed, using last
successful cache")
+ addresses = s.lastSuccessfulDNS
+ if len(addresses) == 0 {
+ if s.metrics != nil {
+ s.metrics.discoveryFailedCount.Inc(1)
+ }
+ return fmt.Errorf("DNS query failed and no cached
addresses available: %w", queryErr)
+ }
+ } else {
+ s.lastSuccessfulDNS = addresses
Review Comment:
`lastSuccessfulDNS` is refreshed only on the success path (line 295). When a DNS query fails but cached addresses exist (lines 287-293), the cached snapshot is reused unchanged. As long as DNS keeps failing, discovery therefore reconciles the node cache against a possibly stale address list; the next successful query overwrites it. Consider documenting this fallback behavior, or making the assignment logic more explicit about when `lastSuccessfulDNS` should be refreshed.
##########
banyand/metadata/dns/dns.go:
##########
@@ -0,0 +1,603 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package dns implements DNS-based node discovery for distributed metadata
management.
+package dns
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials"
+ "google.golang.org/grpc/credentials/insecure"
+
+ databasev1
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
+ "github.com/apache/skywalking-banyandb/banyand/metadata/schema"
+ "github.com/apache/skywalking-banyandb/banyand/observability"
+ "github.com/apache/skywalking-banyandb/pkg/grpchelper"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
+ "github.com/apache/skywalking-banyandb/pkg/run"
+ pkgtls "github.com/apache/skywalking-banyandb/pkg/tls"
+)
+
+// Service implements DNS-based node discovery: it periodically resolves the configured DNS SRV records and caches node metadata keyed by each resolved "host:port" address.
+type Service struct {
+ lastQueryTime time.Time // set at the end of each discovery round that reaches the cache-update step; guarded by lastQueryMutex
+ resolver Resolver // performs the SRV lookups; defaults to defaultResolver
+ pathToReloader map[string]*pkgtls.Reloader // one CA certificate reloader per distinct cert path (TLS only)
+ srvIndexToPath map[int]string // index into srvAddresses -> CA cert path configured for that SRV address (TLS only)
+ resolvedAddrToSRVIdx map[string]int // resolved host:port -> index of the first SRV address that produced it; replaced after each successful query; guarded by resolvedAddrMutex
+ nodeCache map[string]*databasev1.Node // resolved host:port -> fetched node metadata; guarded by cacheMutex
+ closer *run.Closer // its CloseNotify channel stops discoveryLoop
+ log *logger.Logger
+ metrics *metrics // optional; every recording site checks for nil
+ handlers map[string]schema.EventHandler // registered event handlers; presumably guarded by handlersMutex (usage not shown here, confirm)
+ caCertPaths []string // CA certificate files, one per SRV address in the same order (TLS only)
+ srvAddresses []string // SRV record names resolved on every discovery round
+ lastSuccessfulDNS []string // addresses from the last query in which every SRV lookup succeeded; reused when a query fails
+ pollInterval time.Duration // delay between rounds after the init phase
+ initInterval time.Duration // delay between rounds during the init phase
+ initDuration time.Duration // length of the init phase, measured from the start of discoveryLoop
+ grpcTimeout time.Duration // NOTE(review): assumed to bound gRPC node-metadata fetches; fetchNodeMetadata is not shown, confirm
+ cacheMutex sync.RWMutex // guards nodeCache
+ handlersMutex sync.RWMutex // presumably guards handlers
+ lastQueryMutex sync.RWMutex // guards lastQueryTime
+ resolvedAddrMutex sync.RWMutex // guards resolvedAddrToSRVIdx
+ tlsEnabled bool // when true, discovered nodes are dialed with TLS via the per-SRV CA reloaders
+}
+
+// Config holds configuration for the DNS discovery service; NewService validates and consumes it.
+type Config struct {
+ CACertPaths []string // CA certificate files, one per SRV address in the same order; length must equal len(SRVAddresses) when TLSEnabled
+ SRVAddresses []string // SRV record names to resolve (e.g. _grpc._tcp.banyandb.svc.cluster.local); must be non-empty
+ InitInterval time.Duration // query interval during the init phase
+ InitDuration time.Duration // length of the init phase
+ PollInterval time.Duration // query interval after the init phase
+ GRPCTimeout time.Duration // timeout for gRPC calls that fetch node metadata
+ TLSEnabled bool // dial discovered nodes over TLS using CACertPaths
+}
+
+// Resolver defines the interface for DNS SRV lookups. A custom implementation can be injected with newServiceWithResolver (used by tests).
+type Resolver interface {
+ LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) // name is a complete SRV record name; returns the canonical name, the SRV records, and an error
+}
+
// defaultResolver satisfies Resolver by delegating to the process-wide
// net.DefaultResolver.
type defaultResolver struct{}

// LookupSRV resolves name directly as a complete SRV record name. Service and
// proto are passed empty, so name must already carry its "_service._proto."
// prefix (for example "_grpc._tcp.banyandb.svc.cluster.local").
func (*defaultResolver) LookupSRV(ctx context.Context, name string) (string, []*net.SRV, error) {
	resolver := net.DefaultResolver
	return resolver.LookupSRV(ctx, "", "", name)
}
+
+// NewService creates a new DNS discovery service.
+func NewService(cfg Config) (*Service, error) {
+ // validation
+ if len(cfg.SRVAddresses) == 0 {
+ return nil, errors.New("DNS SRV addresses cannot be empty")
+ }
+
+ // validate CA cert paths match SRV addresses when TLS is enabled
+ if cfg.TLSEnabled {
+ if len(cfg.CACertPaths) != len(cfg.SRVAddresses) {
+ return nil, fmt.Errorf("number of CA cert paths (%d)
must match number of SRV addresses (%d)",
+ len(cfg.CACertPaths), len(cfg.SRVAddresses))
+ }
+ }
+
+ svc := &Service{
+ srvAddresses: cfg.SRVAddresses,
+ initInterval: cfg.InitInterval,
+ initDuration: cfg.InitDuration,
+ pollInterval: cfg.PollInterval,
+ grpcTimeout: cfg.GRPCTimeout,
+ tlsEnabled: cfg.TLSEnabled,
+ caCertPaths: cfg.CACertPaths,
+ nodeCache: make(map[string]*databasev1.Node),
+ handlers: make(map[string]schema.EventHandler),
+ lastSuccessfulDNS: []string{},
+ pathToReloader: make(map[string]*pkgtls.Reloader),
+ srvIndexToPath: make(map[int]string),
+ resolvedAddrToSRVIdx: make(map[string]int),
+ closer: run.NewCloser(1),
+ log:
logger.GetLogger("metadata-discovery-dns"),
+ resolver: &defaultResolver{},
+ }
+
+ // create shared reloaders for CA certificates
+ if svc.tlsEnabled {
+ for srvIdx, certPath := range cfg.CACertPaths {
+ // Store the SRV index → cert path mapping
+ svc.srvIndexToPath[srvIdx] = certPath
+
+ // check if we already have a Reloader for this path
+ if _, exists := svc.pathToReloader[certPath]; exists {
+ svc.log.Debug().Str("certPath",
certPath).Int("srvIndex", srvIdx).
+ Msg("Reusing existing CA certificate
reloader")
+ continue
+ }
+
+ // create new Reloader for this unique path
+ reloader, reloaderErr :=
pkgtls.NewClientCertReloader(certPath, svc.log)
+ if reloaderErr != nil {
+ // clean up any already-created reloaders
+ for _, r := range svc.pathToReloader {
+ r.Stop()
+ }
+ return nil, fmt.Errorf("failed to initialize CA
certificate reloader for path %s (SRV index %d): %w",
+ certPath, srvIdx, reloaderErr)
+ }
+
+ svc.pathToReloader[certPath] = reloader
+ svc.log.Info().Str("certPath",
certPath).Int("srvIndex", srvIdx).
+ Str("srvAddress",
cfg.SRVAddresses[srvIdx]).Msg("Initialized DNS CA certificate reloader")
+ }
+ }
+
+ return svc, nil
+}
+
+// newServiceWithResolver creates a service with a custom resolver (for
testing).
+func newServiceWithResolver(cfg Config, resolver Resolver) (*Service, error) {
+ svc, err := NewService(cfg)
+ if err != nil {
+ return nil, err
+ }
+ svc.resolver = resolver
+ return svc, nil
+}
+
+func (s *Service) getTLSDialOptions(address string) ([]grpc.DialOption, error)
{
+ if !s.tlsEnabled {
+ return
[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, nil
+ }
+
+ // look up which Reloader to use for this address
+ if len(s.pathToReloader) > 0 {
+ // Find which SRV address this resolved address came from
+ s.resolvedAddrMutex.RLock()
+ srvIdx, addrExists := s.resolvedAddrToSRVIdx[address]
+ s.resolvedAddrMutex.RUnlock()
+
+ if !addrExists {
+ return nil, fmt.Errorf("no SRV mapping found for
address %s", address)
+ }
+
+ // look up the cert path for this SRV index
+ certPath, pathExists := s.srvIndexToPath[srvIdx]
+ if !pathExists {
+ return nil, fmt.Errorf("no cert path found for SRV
index %d (address %s)", srvIdx, address)
+ }
+
+ // get the Reloader for this cert path
+ reloader, reloaderExists := s.pathToReloader[certPath]
+ if !reloaderExists {
+ return nil, fmt.Errorf("no reloader found for cert path
%s (address %s)", certPath, address)
+ }
+
+ // get fresh TLS config from the Reloader
+ tlsConfig, configErr := reloader.GetClientTLSConfig("")
+ if configErr != nil {
+ return nil, fmt.Errorf("failed to get TLS config from
reloader for address %s: %w", address, configErr)
+ }
+
+ creds := credentials.NewTLS(tlsConfig)
+ return []grpc.DialOption{grpc.WithTransportCredentials(creds)},
nil
+ }
+
+ // fallback to static TLS config (when no reloaders configured)
+ opts, err := grpchelper.SecureOptions(nil, s.tlsEnabled, false, "")
+ if err != nil {
+ return nil, fmt.Errorf("failed to load TLS config: %w", err)
+ }
+ return opts, nil
+}
+
+// Start begins the DNS discovery background process.
+func (s *Service) Start(ctx context.Context) error {
+ s.log.Debug().Msg("Starting DNS-based node discovery service")
+
+ // start all Reloaders
+ if len(s.pathToReloader) > 0 {
+ startedReloaders := make([]*pkgtls.Reloader, 0,
len(s.pathToReloader))
+
+ for certPath, reloader := range s.pathToReloader {
+ if startErr := reloader.Start(); startErr != nil {
+ // stop any already-started reloaders
+ for _, r := range startedReloaders {
+ r.Stop()
+ }
+ return fmt.Errorf("failed to start CA
certificate reloader for path %s: %w", certPath, startErr)
+ }
+ startedReloaders = append(startedReloaders, reloader)
+ s.log.Debug().Str("certPath", certPath).Msg("Started CA
certificate reloader")
+ }
+ }
+
+ go s.discoveryLoop(ctx)
+
+ return nil
+}
+
+func (s *Service) discoveryLoop(ctx context.Context) {
+ // add the init phase finish time
+ initPhaseEnd := time.Now().Add(s.initDuration)
+
+ for {
+ if err := s.queryDNSAndUpdateNodes(ctx); err != nil {
+ s.log.Err(err).Msg("failed to query DNS and update
nodes")
+ }
+
+ // wait for next interval
+ var interval time.Duration
+ if time.Now().Before(initPhaseEnd) {
+ interval = s.initInterval
+ } else {
+ interval = s.pollInterval
+ }
+
+ timer := time.NewTimer(interval)
+ select {
+ case <-ctx.Done():
+ timer.Stop()
+ return
+ case <-s.closer.CloseNotify():
+ timer.Stop()
+ return
+ case <-timer.C:
+ // continue to next iteration
+ }
+ }
+}
+
+func (s *Service) queryDNSAndUpdateNodes(ctx context.Context) error {
+ // Record summary metrics
+ startTime := time.Now()
+ defer func() {
+ if s.metrics != nil {
+ duration := time.Since(startTime)
+ s.metrics.discoveryCount.Inc(1)
+ s.metrics.discoveryDuration.Observe(duration.Seconds())
+ s.metrics.discoveryTotalDuration.Inc(duration.Seconds())
+ }
+ }()
+
+ addresses, queryErr := s.queryAllSRVRecords(ctx)
+
+ if queryErr != nil {
+ s.log.Warn().Err(queryErr).Msg("DNS query failed, using last
successful cache")
+ addresses = s.lastSuccessfulDNS
+ if len(addresses) == 0 {
+ if s.metrics != nil {
+ s.metrics.discoveryFailedCount.Inc(1)
+ }
+ return fmt.Errorf("DNS query failed and no cached
addresses available: %w", queryErr)
+ }
+ } else {
+ s.lastSuccessfulDNS = addresses
+ if s.log.Debug().Enabled() {
+ s.log.Debug().
+ Int("count", len(addresses)).
+ Strs("addresses", addresses).
+ Strs("srv_addresses", s.srvAddresses).
+ Msg("DNS query successful")
+ }
+ }
+
+ // Update node cache based on DNS results
+ updateErr := s.updateNodeCache(ctx, addresses)
+ if updateErr != nil && s.metrics != nil {
+ s.metrics.discoveryFailedCount.Inc(1)
+ }
+ s.lastQueryMutex.Lock()
+ s.lastQueryTime = time.Now()
+ s.lastQueryMutex.Unlock()
+ return updateErr
+}
+
+func (s *Service) queryAllSRVRecords(ctx context.Context) ([]string, error) {
+ startTime := time.Now()
+ defer func() {
+ if s.metrics != nil {
+ duration := time.Since(startTime)
+ s.metrics.dnsQueryCount.Inc(1)
+ s.metrics.dnsQueryDuration.Observe(duration.Seconds())
+ s.metrics.dnsQueryTotalDuration.Inc(duration.Seconds())
+ }
+ }()
+
+ allAddresses := make(map[string]bool)
+ // track which SRV address (by index) each resolved address came from
+ newAddrToSRVIdx := make(map[string]int)
+ var queryErrors []error
+
+ for srvIdx, srvAddr := range s.srvAddresses {
+ _, addrs, lookupErr := s.resolver.LookupSRV(ctx, srvAddr)
+ if lookupErr != nil {
+ queryErrors = append(queryErrors, fmt.Errorf("lookup %s
failed: %w", srvAddr, lookupErr))
+ continue
+ }
+
+ for _, srv := range addrs {
+ address := fmt.Sprintf("%s:%d", srv.Target, srv.Port)
+ allAddresses[address] = true
+
+ // track which SRV address this resolved to (first-wins
strategy)
+ if _, exists := newAddrToSRVIdx[address]; !exists {
+ newAddrToSRVIdx[address] = srvIdx
+ }
+ }
+ }
+
+ // if there have any error occurred,
+ // then just return the query error to ignore the result to make sure
the cache correct
+ if len(queryErrors) > 0 {
+ if s.metrics != nil {
+ s.metrics.dnsQueryFailedCount.Inc(1)
+ }
+ return nil, errors.Join(queryErrors...)
+ }
+
+ // update the resolved address to SRV index mapping
+ s.resolvedAddrMutex.Lock()
+ s.resolvedAddrToSRVIdx = newAddrToSRVIdx
+ s.resolvedAddrMutex.Unlock()
+
+ // convert map to slice
+ result := make([]string, 0, len(allAddresses))
+ for addr := range allAddresses {
+ result = append(result, addr)
+ }
+
+ return result, nil
+}
+
+func (s *Service) updateNodeCache(ctx context.Context, addresses []string)
error {
+ addressSet := make(map[string]bool)
+ for _, addr := range addresses {
+ addressSet[addr] = true
+ }
+
+ var addErrors []error
+
+ for addr := range addressSet {
+ s.cacheMutex.RLock()
+ _, exists := s.nodeCache[addr]
+ s.cacheMutex.RUnlock()
+
+ if !exists {
+ // fetch node metadata from gRPC
+ node, fetchErr := s.fetchNodeMetadata(ctx, addr)
+ if fetchErr != nil {
+ s.log.Warn().
+ Err(fetchErr).
+ Str("address", addr).
+ Msg("Failed to fetch node metadata")
+ addErrors = append(addErrors, fetchErr)
+ continue
+ }
+
+ // update cache and notify handlers
+ s.cacheMutex.Lock()
+ s.nodeCache[addr] = node
+ s.cacheMutex.Unlock()
Review Comment:
There's a potential race condition in the cache check. Between releasing the
RLock on line 384 and acquiring the Lock on line 399, another goroutine could
have added the same node to the cache. This could lead to duplicate
fetchNodeMetadata calls and duplicate handler notifications for the same
address. Consider re-checking the cache after acquiring the write lock
(check-lock-check), or holding a single Lock for the entire operation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]