This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch replica
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
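Before the patch itself, a minimal sketch of the write fan-out this commit introduces: a group's replica count yields copies = replicas + 1, and the liaison picks one node per copy via Pick(group, name, shardID, replicaID). The fakeSelector type and the node and group names below are illustrative stand-ins, not part of the patch; only the copies arithmetic and the Pick signature follow the diff.

package main

import "fmt"

// selector mirrors the Pick signature that the patch adds to node.Selector.
type selector interface {
	Pick(group, name string, shardID, replicaID uint32) (string, error)
}

// fakeSelector is an illustrative stand-in for the real round-robin selector.
type fakeSelector struct{ nodes []string }

var _ selector = fakeSelector{}

func (f fakeSelector) Pick(_, _ string, shardID, replicaID uint32) (string, error) {
	if len(f.nodes) == 0 {
		return "", fmt.Errorf("no nodes available")
	}
	// Spread copies of the same shard across distinct nodes.
	return f.nodes[(int(shardID)+int(replicaID))%len(f.nodes)], nil
}

func main() {
	// replicas = 1 means one primary shard plus one replica, so two copies in total.
	var replicas uint32 = 1
	copies := replicas + 1

	sel := fakeSelector{nodes: []string{"data-node-0", "data-node-1", "data-node-2"}}
	shardID := uint32(0)
	for replicaID := uint32(0); replicaID < copies; replicaID++ {
		nodeID, err := sel.Pick("sw_metrics", "instance_traffic", shardID, replicaID)
		if err != nil {
			fmt.Println("pick failed:", err)
			continue
		}
		fmt.Printf("shard %d, copy %d -> %s\n", shardID, replicaID, nodeID)
	}
}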
commit 3acbe1b804a45c2a6418f270b4c7df387b4e636c Author: Gao Hongtao <hanahm...@gmail.com> AuthorDate: Mon May 26 08:45:45 2025 +0800 Support configurable replica count on Group Signed-off-by: Gao Hongtao <hanahm...@gmail.com> --- CHANGES.md | 7 +- api/proto/banyandb/common/v1/common.proto | 11 ++ banyand/backup/lifecycle/service.go | 24 ++-- banyand/backup/lifecycle/steps.go | 96 +++++++------ banyand/backup/lifecycle/steps_test.go | 32 +++-- banyand/liaison/grpc/discovery.go | 10 ++ banyand/liaison/grpc/measure.go | 227 +++++++++++++++++++----------- banyand/liaison/grpc/node.go | 10 +- banyand/liaison/grpc/node_test.go | 2 +- banyand/liaison/grpc/property.go | 2 +- banyand/liaison/grpc/stream.go | 85 ++++++----- banyand/property/db.go | 33 ++++- docs/api-reference.md | 4 +- go.mod | 4 +- go.sum | 5 - pkg/logger/logger.go | 7 - pkg/meter/native/collection.go | 4 +- pkg/node/interface.go | 6 +- pkg/node/maglev.go | 123 ---------------- pkg/node/maglev_test.go | 133 ----------------- pkg/node/round_robin.go | 50 ++++--- pkg/node/round_robin_test.go | 40 +++--- 22 files changed, 402 insertions(+), 513 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index d4b39179..503d42c9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,11 +9,15 @@ Release Notes. - Add sharding_key for TopNAggregation source measure - API: Update the data matching rule from the node selector to the stage name. - Add dynamical TLS load for the gRPC and HTTP server. +- Implement multiple groups query in one request. +- Replica: Replace Any with []byte Between Liaison and Data Nodes +- Replica: Support configurable replica count on Group. ### Bug Fixes - Fix the deadlock issue when loading a closed segment. - Fix the issue that the etcd watcher gets the historical node registration events. +- Fix the crash when collecting the metrics from a closed segment. ## 0.8.0 @@ -50,8 +54,6 @@ Release Notes. - UI: Add the `stages` to groups. - Add time range return value from stream local index filter. - Deduplicate the documents on building the series index. -- Implement multiple groups query in one request. -- Replica: Replace Any with []byte Between Liaison and Data Nodes ### Bug Fixes @@ -74,7 +76,6 @@ Release Notes. - UI: Fix the Stream List. - Fix the oom issue when loading too many unnecessary parts into memory. - bydbctl: Fix the bug that the bydbctl can't parse the absolute time flag. -- Fix the crash when collecting the metrics from a closed segment. ### Documentation diff --git a/api/proto/banyandb/common/v1/common.proto b/api/proto/banyandb/common/v1/common.proto index 2c6e5d75..733c68c2 100644 --- a/api/proto/banyandb/common/v1/common.proto +++ b/api/proto/banyandb/common/v1/common.proto @@ -82,6 +82,12 @@ message LifecycleStage { // Indicates whether segments that are no longer live should be closed. bool close = 6; + + // replicas is the number of replicas for this stage. + // This is an optional field and defaults to 0. + // A value of 0 means no replicas, while a value of 1 means one primary shard and one replica. + // Higher values indicate more replicas. + uint32 replicas = 7; } message ResourceOpts { @@ -95,6 +101,11 @@ message ResourceOpts { repeated LifecycleStage stages = 4; // default_stages is the name of the default stage repeated string default_stages = 5; + // replicas is the number of replicas. This is used to ensure high availability and fault tolerance. + // This is an optional field and defaults to 0. + // A value of 0 means no replicas, while a value of 1 means one primary shard and one replica. 
+ // Higher values indicate more replicas. + uint32 replicas = 6; } // Group is an internal object for Group management diff --git a/banyand/backup/lifecycle/service.go b/banyand/backup/lifecycle/service.go index 8217c0a7..7338065d 100644 --- a/banyand/backup/lifecycle/service.go +++ b/banyand/backup/lifecycle/service.go @@ -195,7 +195,7 @@ func (l *lifecycleService) getGroupsToProcess(ctx context.Context, progress *Pro func (l *lifecycleService) processStreamGroup(ctx context.Context, g *commonv1.Group, streamSVC stream.Service, nodes []*databasev1.Node, labels map[string]string, progress *Progress, ) { - shardNum, selector, client, err := parseGroup(ctx, g, labels, nodes, l.l, l.metadata) + shardNum, replicas, selector, client, err := parseGroup(ctx, g, labels, nodes, l.l, l.metadata) if err != nil { l.l.Error().Err(err).Msgf("failed to parse group %s", g.Metadata.Name) return @@ -210,13 +210,13 @@ func (l *lifecycleService) processStreamGroup(ctx context.Context, g *commonv1.G tr := streamSVC.GetRemovalSegmentsTimeRange(g.Metadata.Name) - l.processStreams(ctx, g, ss, streamSVC, tr, shardNum, selector, client, progress) + l.processStreams(ctx, g, ss, streamSVC, tr, shardNum, replicas, selector, client, progress) l.deleteExpiredStreamSegments(ctx, g, tr, progress) } func (l *lifecycleService) processStreams(ctx context.Context, g *commonv1.Group, streams []*databasev1.Stream, - streamSVC stream.Service, tr *timestamp.TimeRange, shardNum uint32, selector node.Selector, client queue.Client, progress *Progress, + streamSVC stream.Service, tr *timestamp.TimeRange, shardNum uint32, replicas uint32, selector node.Selector, client queue.Client, progress *Progress, ) { for _, s := range streams { if progress.IsStreamCompleted(g.Metadata.Name, s.Metadata.Name) { @@ -224,7 +224,7 @@ func (l *lifecycleService) processStreams(ctx context.Context, g *commonv1.Group continue } - if sum, err := l.processSingleStream(ctx, s, streamSVC, tr, shardNum, selector, client); err == nil { + if sum, err := l.processSingleStream(ctx, s, streamSVC, tr, shardNum, replicas, selector, client); err == nil { l.l.Info().Msgf("migrated %d elements in stream %s", sum, s.Metadata.Name) } @@ -234,7 +234,7 @@ func (l *lifecycleService) processStreams(ctx context.Context, g *commonv1.Group } func (l *lifecycleService) processSingleStream(ctx context.Context, s *databasev1.Stream, - streamSVC stream.Service, tr *timestamp.TimeRange, shardNum uint32, selector node.Selector, client queue.Client, + streamSVC stream.Service, tr *timestamp.TimeRange, shardNum uint32, replicas uint32, selector node.Selector, client queue.Client, ) (int, error) { q, err := streamSVC.Stream(s.Metadata) if err != nil { @@ -268,7 +268,7 @@ func (l *lifecycleService) processSingleStream(ctx context.Context, s *databasev l.l.Error().Err(err).Msgf("failed to query stream %s", s.Metadata.Name) return 0, err } - return migrateStream(ctx, s, result, shardNum, selector, client, l.l), nil + return migrateStream(ctx, s, result, shardNum, replicas, selector, client, l.l), nil } func (l *lifecycleService) deleteExpiredStreamSegments(ctx context.Context, g *commonv1.Group, tr *timestamp.TimeRange, progress *Progress) { @@ -300,7 +300,7 @@ func (l *lifecycleService) deleteExpiredStreamSegments(ctx context.Context, g *c func (l *lifecycleService) processMeasureGroup(ctx context.Context, g *commonv1.Group, measureSVC measure.Service, nodes []*databasev1.Node, labels map[string]string, progress *Progress, ) { - shardNum, selector, client, err := parseGroup(ctx, g, 
labels, nodes, l.l, l.metadata) + shardNum, replicas, selector, client, err := parseGroup(ctx, g, labels, nodes, l.l, l.metadata) if err != nil { l.l.Error().Err(err).Msgf("failed to parse group %s", g.Metadata.Name) return @@ -315,13 +315,13 @@ func (l *lifecycleService) processMeasureGroup(ctx context.Context, g *commonv1. tr := measureSVC.GetRemovalSegmentsTimeRange(g.Metadata.Name) - l.processMeasures(ctx, g, mm, measureSVC, tr, shardNum, selector, client, progress) + l.processMeasures(ctx, g, mm, measureSVC, tr, shardNum, replicas, selector, client, progress) l.deleteExpiredMeasureSegments(ctx, g, tr, progress) } func (l *lifecycleService) processMeasures(ctx context.Context, g *commonv1.Group, measures []*databasev1.Measure, - measureSVC measure.Service, tr *timestamp.TimeRange, shardNum uint32, selector node.Selector, client queue.Client, progress *Progress, + measureSVC measure.Service, tr *timestamp.TimeRange, shardNum uint32, replicas uint32, selector node.Selector, client queue.Client, progress *Progress, ) { for _, m := range measures { if progress.IsMeasureCompleted(g.Metadata.Name, m.Metadata.Name) { @@ -329,7 +329,7 @@ func (l *lifecycleService) processMeasures(ctx context.Context, g *commonv1.Grou continue } - if sum, err := l.processSingleMeasure(ctx, m, measureSVC, tr, shardNum, selector, client); err == nil { + if sum, err := l.processSingleMeasure(ctx, m, measureSVC, tr, shardNum, replicas, selector, client); err == nil { l.l.Info().Msgf("migrated %d elements in measure %s", sum, m.Metadata.Name) } @@ -339,7 +339,7 @@ func (l *lifecycleService) processMeasures(ctx context.Context, g *commonv1.Grou } func (l *lifecycleService) processSingleMeasure(ctx context.Context, m *databasev1.Measure, - measureSVC measure.Service, tr *timestamp.TimeRange, shardNum uint32, selector node.Selector, client queue.Client, + measureSVC measure.Service, tr *timestamp.TimeRange, shardNum uint32, replicas uint32, selector node.Selector, client queue.Client, ) (int, error) { q, err := measureSVC.Measure(m.Metadata) if err != nil { @@ -378,7 +378,7 @@ func (l *lifecycleService) processSingleMeasure(ctx context.Context, m *database return 0, err } - return migrateMeasure(ctx, m, result, shardNum, selector, client, l.l), nil + return migrateMeasure(ctx, m, result, shardNum, replicas, selector, client, l.l), nil } func (l *lifecycleService) deleteExpiredMeasureSegments(ctx context.Context, g *commonv1.Group, tr *timestamp.TimeRange, progress *Progress) { diff --git a/banyand/backup/lifecycle/steps.go b/banyand/backup/lifecycle/steps.go index 620be9e9..b4e8db6d 100644 --- a/banyand/backup/lifecycle/steps.go +++ b/banyand/backup/lifecycle/steps.go @@ -111,26 +111,26 @@ func (l *lifecycleService) setupQuerySvc(ctx context.Context, streamDir, measure func parseGroup(ctx context.Context, g *commonv1.Group, nodeLabels map[string]string, nodes []*databasev1.Node, l *logger.Logger, metadata metadata.Repo, -) (uint32, node.Selector, queue.Client, error) { +) (uint32, uint32, node.Selector, queue.Client, error) { ro := g.ResourceOpts if ro == nil { - return 0, nil, nil, fmt.Errorf("no resource opts in group %s", g.Metadata.Name) + return 0, 0, nil, nil, fmt.Errorf("no resource opts in group %s", g.Metadata.Name) } if len(ro.Stages) == 0 { - return 0, nil, nil, fmt.Errorf("no stages in group %s", g.Metadata.Name) + return 0, 0, nil, nil, fmt.Errorf("no stages in group %s", g.Metadata.Name) } var nst *commonv1.LifecycleStage for i, st := range ro.Stages { selector, err := 
pub.ParseLabelSelector(st.NodeSelector) if err != nil { - return 0, nil, nil, errors.WithMessagef(err, "failed to parse node selector %s", st.NodeSelector) + return 0, 0, nil, nil, errors.WithMessagef(err, "failed to parse node selector %s", st.NodeSelector) } if !selector.Matches(nodeLabels) { continue } if i+1 >= len(ro.Stages) { l.Info().Msgf("no next stage for group %s at stage %s", g.Metadata.Name, st.Name) - return 0, nil, nil, nil + return 0, 0, nil, nil, nil } nst = ro.Stages[i+1] l.Info().Msgf("migrating group %s at stage %s to stage %s", g.Metadata.Name, st.Name, nst.Name) @@ -141,11 +141,11 @@ func parseGroup(ctx context.Context, g *commonv1.Group, nodeLabels map[string]st } nsl, err := pub.ParseLabelSelector(nst.NodeSelector) if err != nil { - return 0, nil, nil, errors.WithMessagef(err, "failed to parse node selector %s", nst.NodeSelector) + return 0, 0, nil, nil, errors.WithMessagef(err, "failed to parse node selector %s", nst.NodeSelector) } nodeSel := node.NewRoundRobinSelector("", metadata) if err = nodeSel.PreRun(ctx); err != nil { - return 0, nil, nil, errors.WithMessage(err, "failed to run node selector") + return 0, 0, nil, nil, errors.WithMessage(err, "failed to run node selector") } client := pub.NewWithoutMetadata() if g.Catalog == commonv1.Catalog_CATALOG_STREAM { @@ -170,18 +170,19 @@ func parseGroup(ctx context.Context, g *commonv1.Group, nodeLabels map[string]st } } if !existed { - return 0, nil, nil, errors.New("no nodes matched") + return 0, 0, nil, nil, errors.New("no nodes matched") } - return nst.ShardNum, nodeSel, client, nil + return nst.ShardNum, nst.Replicas, nodeSel, client, nil } func migrateStream(ctx context.Context, s *databasev1.Stream, result model.StreamQueryResult, - shardNum uint32, selector node.Selector, client queue.Client, l *logger.Logger, + shardNum uint32, replicas uint32, selector node.Selector, client queue.Client, l *logger.Logger, ) (sum int) { if result == nil { return 0 } defer result.Release() + copies := replicas + 1 entityLocator := partition.NewEntityLocator(s.TagFamilies, s.Entity, 0) @@ -208,22 +209,26 @@ func migrateStream(ctx context.Context, s *databasev1.Stream, result model.Strea l.Error().Err(err).Msg("failed to locate entity") continue } - nodeID, err := selector.Pick(s.Metadata.Group, s.Metadata.Name, uint32(shardID)) - if err != nil { - l.Error().Err(err).Msg("failed to pick node") - continue - } - iwr := &streamv1.InternalWriteRequest{ - Request: writeEntity, - ShardId: uint32(shardID), - SeriesHash: pbv1.HashEntity(entity), - EntityValues: tagValues[1:].Encode(), - } - message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) - _, err = batch.Publish(ctx, data.TopicStreamWrite, message) - if err != nil { - l.Error().Err(err).Msg("failed to publish message") - continue + + // Write to multiple replicas + for replicaID := uint32(0); replicaID < copies; replicaID++ { + nodeID, err := selector.Pick(s.Metadata.Group, s.Metadata.Name, uint32(shardID), replicaID) + if err != nil { + l.Error().Err(err).Msg("failed to pick node") + continue + } + iwr := &streamv1.InternalWriteRequest{ + Request: writeEntity, + ShardId: uint32(shardID), + SeriesHash: pbv1.HashEntity(entity), + EntityValues: tagValues[1:].Encode(), + } + message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) + _, err = batch.Publish(ctx, data.TopicStreamWrite, message) + if err != nil { + l.Error().Err(err).Msg("failed to publish message") + continue + } } sum++ } @@ -232,12 +237,13 @@ func 
migrateStream(ctx context.Context, s *databasev1.Stream, result model.Strea } func migrateMeasure(ctx context.Context, m *databasev1.Measure, result model.MeasureQueryResult, - shardNum uint32, selector node.Selector, client queue.Client, l *logger.Logger, + shardNum uint32, replicas uint32, selector node.Selector, client queue.Client, l *logger.Logger, ) (sum int) { if result == nil { return 0 } defer result.Release() + copies := replicas + 1 entityLocator := partition.NewEntityLocator(m.TagFamilies, m.Entity, 0) @@ -272,26 +278,28 @@ func migrateMeasure(ctx context.Context, m *databasev1.Measure, result model.Mea continue } - nodeID, err := selector.Pick(m.Metadata.Group, m.Metadata.Name, uint32(shardID)) - if err != nil { - l.Error().Err(err).Msg("failed to pick node") - continue - } + // Write to multiple replicas + for replicaID := uint32(0); replicaID < copies; replicaID++ { + nodeID, err := selector.Pick(m.Metadata.Group, m.Metadata.Name, uint32(shardID), replicaID) + if err != nil { + l.Error().Err(err).Msg("failed to pick node") + continue + } - iwr := &measurev1.InternalWriteRequest{ - Request: writeRequest, - ShardId: uint32(shardID), - SeriesHash: pbv1.HashEntity(entity), - EntityValues: tagValues[1:].Encode(), - } + iwr := &measurev1.InternalWriteRequest{ + Request: writeRequest, + ShardId: uint32(shardID), + SeriesHash: pbv1.HashEntity(entity), + EntityValues: tagValues[1:].Encode(), + } - message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) - _, err = batch.Publish(ctx, data.TopicMeasureWrite, message) - if err != nil { - l.Error().Err(err).Msg("failed to publish message") - } else { - sum++ + message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) + _, err = batch.Publish(ctx, data.TopicMeasureWrite, message) + if err != nil { + l.Error().Err(err).Msg("failed to publish message") + } } + sum++ } } return sum diff --git a/banyand/backup/lifecycle/steps_test.go b/banyand/backup/lifecycle/steps_test.go index cecb3799..9c3b2c7f 100644 --- a/banyand/backup/lifecycle/steps_test.go +++ b/banyand/backup/lifecycle/steps_test.go @@ -119,6 +119,7 @@ func TestMigrateStream(t *testing.T) { } shardNum := uint32(2) + replicas := uint32(0) // Create stream result data with expected timestamps timestamp1 := int64(1672531200000000000) // First timestamp in nanoseconds @@ -167,8 +168,7 @@ func TestMigrateStream(t *testing.T) { mockClient.EXPECT().NewBatchPublisher(gomock.Any()).Return(mockBatchPublisher) mockBatchPublisher.EXPECT().Close().Return(nil, nil) - mockSelector.EXPECT().Pick(stream.Metadata.Group, stream.Metadata.Name, gomock.Any()). 
- Return("node-1", nil).Times(2) + mockSelector.EXPECT().Pick(stream.Metadata.Group, stream.Metadata.Name, gomock.Any(), gomock.Any()).Return("node-1", nil).Times(2) // Expected element IDs encoded as base64 strings expectedElementID1 := base64.StdEncoding.EncodeToString(convert.Uint64ToBytes(elementID1)) @@ -218,7 +218,7 @@ func TestMigrateStream(t *testing.T) { return "", nil }).Times(2) - migrateStream(ctx, stream, queryResult, shardNum, mockSelector, mockClient, l) + migrateStream(ctx, stream, queryResult, shardNum, replicas, mockSelector, mockClient, l) assert.Equal(t, 1, queryResult.index) assert.Equal(t, 2, callCount, "Expected exactly 2 elements to be processed") @@ -276,6 +276,7 @@ func TestMigrateMeasure(t *testing.T) { } shardNum := uint32(2) + replicas := uint32(0) // Create measure result data with expected timestamps timestamp1 := int64(1672531200000000000) // First timestamp in nanoseconds @@ -338,7 +339,7 @@ func TestMigrateMeasure(t *testing.T) { mockClient.EXPECT().NewBatchPublisher(gomock.Any()).Return(mockBatchPublisher) mockBatchPublisher.EXPECT().Close().Return(nil, nil) - mockSelector.EXPECT().Pick(measure.Metadata.Group, measure.Metadata.Name, gomock.Any()). + mockSelector.EXPECT().Pick(measure.Metadata.Group, measure.Metadata.Name, gomock.Any(), gomock.Any()). Return("node-1", nil).Times(2) // Use a counter to check the correct element for each call @@ -390,7 +391,7 @@ func TestMigrateMeasure(t *testing.T) { return "", nil }).Times(2) - migrateMeasure(ctx, measure, queryResult, shardNum, mockSelector, mockClient, l) + migrateMeasure(ctx, measure, queryResult, shardNum, replicas, mockSelector, mockClient, l) assert.Equal(t, 1, queryResult.index) assert.Equal(t, 2, callCount, "Expected exactly 2 elements to be processed") @@ -403,14 +404,15 @@ func TestParseGroup(t *testing.T) { defer ctrl.Finish() tests := []struct { - group *commonv1.Group - nodeLabels map[string]string - name string - errorMessage string - nodes []*databasev1.Node - expectShard uint32 - expectError bool - expectResult bool + group *commonv1.Group + nodeLabels map[string]string + name string + errorMessage string + nodes []*databasev1.Node + expectShard uint32 + expectReplicas uint32 + expectError bool + expectResult bool }{ { name: "no resource opts", @@ -591,7 +593,7 @@ func TestParseGroup(t *testing.T) { t.Run(tt.name, func(t *testing.T) { mockRepo := metadata.NewMockRepo(ctrl) mockRepo.EXPECT().RegisterHandler("", schema.KindGroup, gomock.Any()).MaxTimes(1) - shardNum, selector, client, err := parseGroup(context.Background(), tt.group, tt.nodeLabels, tt.nodes, l, mockRepo) + shardNum, replicas, selector, client, err := parseGroup(context.Background(), tt.group, tt.nodeLabels, tt.nodes, l, mockRepo) if tt.expectError { require.Error(t, err) @@ -604,6 +606,7 @@ func TestParseGroup(t *testing.T) { require.Nil(t, selector) require.Nil(t, client) require.Equal(t, uint32(0), shardNum) + require.Equal(t, uint32(0), replicas) return } @@ -611,6 +614,7 @@ func TestParseGroup(t *testing.T) { require.NotNil(t, selector) require.NotNil(t, client) require.Equal(t, tt.expectShard, shardNum) + require.Equal(t, tt.expectReplicas, replicas) }) } } diff --git a/banyand/liaison/grpc/discovery.go b/banyand/liaison/grpc/discovery.go index 0df7b6ef..a7b77d8b 100644 --- a/banyand/liaison/grpc/discovery.go +++ b/banyand/liaison/grpc/discovery.go @@ -162,6 +162,16 @@ func (s *groupRepo) shardNum(groupName string) (uint32, bool) { return r.ShardNum, true } +func (s *groupRepo) copies(groupName string) (uint32, bool) { + 
s.RWMutex.RLock() + defer s.RWMutex.RUnlock() + r, ok := s.resourceOpts[groupName] + if !ok { + return 0, false + } + return r.Replicas + 1, true +} + func getID(metadata *commonv1.Metadata) identity { return identity{ name: metadata.GetName(), diff --git a/banyand/liaison/grpc/measure.go b/banyand/liaison/grpc/measure.go index 7dd38590..fb8dc9b6 100644 --- a/banyand/liaison/grpc/measure.go +++ b/banyand/liaison/grpc/measure.go @@ -47,13 +47,13 @@ type measureService struct { pipeline queue.Client broadcaster queue.Client *discoveryService - sampled *logger.Logger + l *logger.Logger metrics *metrics writeTimeout time.Duration } func (ms *measureService) setLogger(log *logger.Logger) { - ms.sampled = log.Sampled(10) + ms.l = log } func (ms *measureService) activeIngestionAccessLog(root string) (err error) { @@ -65,120 +65,189 @@ func (ms *measureService) activeIngestionAccessLog(root string) (err error) { } func (ms *measureService) Write(measure measurev1.MeasureService_WriteServer) error { - reply := func(metadata *commonv1.Metadata, status modelv1.Status, messageId uint64, measure measurev1.MeasureService_WriteServer, logger *logger.Logger) { - if status != modelv1.Status_STATUS_SUCCEED { - ms.metrics.totalStreamMsgReceivedErr.Inc(1, metadata.Group, "measure", "write") - } - ms.metrics.totalStreamMsgSent.Inc(1, metadata.Group, "measure", "write") - if errResp := measure.Send(&measurev1.WriteResponse{Metadata: metadata, Status: status.String(), MessageId: messageId}); errResp != nil { - logger.Debug().Err(errResp).Msg("failed to send measure write response") - ms.metrics.totalStreamMsgSentErr.Inc(1, metadata.Group, "measure", "write") - } - } ctx := measure.Context() publisher := ms.pipeline.NewBatchPublisher(ms.writeTimeout) ms.metrics.totalStreamStarted.Inc(1, "measure", "write") start := time.Now() var succeedSent []succeedSentMessage - defer func() { - cee, err := publisher.Close() - for _, s := range succeedSent { - code := modelv1.Status_STATUS_SUCCEED - if cee != nil { - if ce, ok := cee[s.node]; ok { - code = ce.Status() - } - } - reply(s.metadata, code, s.messageID, measure, ms.sampled) - } - if err != nil { - ms.sampled.Error().Err(err).Msg("failed to close the publisher") - } - ms.metrics.totalStreamFinished.Inc(1, "measure", "write") - ms.metrics.totalStreamLatency.Inc(time.Since(start).Seconds(), "measure", "write") - }() + + defer ms.handleWriteCleanup(publisher, &succeedSent, measure, start) + for { select { case <-ctx.Done(): return ctx.Err() default: } + writeRequest, err := measure.Recv() if errors.Is(err, io.EOF) { return nil } if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - ms.sampled.Error().Err(err).Stringer("written", writeRequest).Msg("failed to receive message") + ms.l.Error().Err(err).Stringer("written", writeRequest).Msg("failed to receive message") } return err } + ms.metrics.totalStreamMsgReceived.Inc(1, writeRequest.Metadata.Group, "measure", "write") - if errTime := timestamp.CheckPb(writeRequest.DataPoint.Timestamp); errTime != nil { - ms.sampled.Error().Err(errTime).Stringer("written", writeRequest).Msg("the data point time is invalid") - reply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INVALID_TIMESTAMP, writeRequest.GetMessageId(), measure, ms.sampled) + + if status := ms.validateWriteRequest(writeRequest, measure); status != modelv1.Status_STATUS_SUCCEED { continue } - if writeRequest.Metadata.ModRevision > 0 { - measureCache, existed := 
ms.entityRepo.getLocator(getID(writeRequest.GetMetadata())) - if !existed { - ms.sampled.Error().Err(err).Stringer("written", writeRequest).Msg("failed to measure schema not found") - reply(writeRequest.GetMetadata(), modelv1.Status_STATUS_NOT_FOUND, writeRequest.GetMessageId(), measure, ms.sampled) - continue - } - if writeRequest.Metadata.ModRevision != measureCache.ModRevision { - ms.sampled.Error().Stringer("written", writeRequest).Msg("the measure schema is expired") - reply(writeRequest.GetMetadata(), modelv1.Status_STATUS_EXPIRED_SCHEMA, writeRequest.GetMessageId(), measure, ms.sampled) - continue - } - } - entity, tagValues, shardID, err := ms.navigate(writeRequest.GetMetadata(), writeRequest.GetDataPoint().GetTagFamilies()) - if err != nil { - ms.sampled.Error().Err(err).RawJSON("written", logger.Proto(writeRequest)).Msg("failed to navigate to the write target") - reply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure, ms.sampled) + + if err := ms.processAndPublishRequest(ctx, writeRequest, publisher, &succeedSent, measure); err != nil { continue } - if writeRequest.DataPoint.Version == 0 { - if writeRequest.MessageId == 0 { - writeRequest.MessageId = uint64(time.Now().UnixNano()) - } - writeRequest.DataPoint.Version = int64(writeRequest.MessageId) + } +} + +func (ms *measureService) validateWriteRequest(writeRequest *measurev1.WriteRequest, measure measurev1.MeasureService_WriteServer) modelv1.Status { + if errTime := timestamp.CheckPb(writeRequest.DataPoint.Timestamp); errTime != nil { + ms.l.Error().Err(errTime).Stringer("written", writeRequest).Msg("the data point time is invalid") + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INVALID_TIMESTAMP, writeRequest.GetMessageId(), measure) + return modelv1.Status_STATUS_INVALID_TIMESTAMP + } + + if writeRequest.Metadata.ModRevision > 0 { + measureCache, existed := ms.entityRepo.getLocator(getID(writeRequest.GetMetadata())) + if !existed { + ms.l.Error().Stringer("written", writeRequest).Msg("failed to measure schema not found") + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_NOT_FOUND, writeRequest.GetMessageId(), measure) + return modelv1.Status_STATUS_NOT_FOUND } - if ms.ingestionAccessLog != nil { - if errAccessLog := ms.ingestionAccessLog.Write(writeRequest); errAccessLog != nil { - ms.sampled.Error().Err(errAccessLog).RawJSON("written", logger.Proto(writeRequest)).Msg("failed to write access log") - } + if writeRequest.Metadata.ModRevision != measureCache.ModRevision { + ms.l.Error().Stringer("written", writeRequest).Msg("the measure schema is expired") + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_EXPIRED_SCHEMA, writeRequest.GetMessageId(), measure) + return modelv1.Status_STATUS_EXPIRED_SCHEMA + } + } + + return modelv1.Status_STATUS_SUCCEED +} + +func (ms *measureService) processAndPublishRequest(ctx context.Context, writeRequest *measurev1.WriteRequest, + publisher queue.BatchPublisher, succeedSent *[]succeedSentMessage, measure measurev1.MeasureService_WriteServer, +) error { + entity, tagValues, shardID, err := ms.navigate(writeRequest.GetMetadata(), writeRequest.GetDataPoint().GetTagFamilies()) + if err != nil { + ms.l.Error().Err(err).RawJSON("written", logger.Proto(writeRequest)).Msg("failed to navigate to the write target") + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure) + return err + } + + if writeRequest.DataPoint.Version == 0 { + if 
writeRequest.MessageId == 0 { + writeRequest.MessageId = uint64(time.Now().UnixNano()) } - iwr := &measurev1.InternalWriteRequest{ - Request: writeRequest, - ShardId: uint32(shardID), - SeriesHash: pbv1.HashEntity(entity), - EntityValues: tagValues[1:].Encode(), + writeRequest.DataPoint.Version = int64(writeRequest.MessageId) + } + + if ms.ingestionAccessLog != nil { + if errAccessLog := ms.ingestionAccessLog.Write(writeRequest); errAccessLog != nil { + ms.l.Error().Err(errAccessLog).RawJSON("written", logger.Proto(writeRequest)).Msg("failed to write access log") } - nodeID, errPickNode := ms.nodeRegistry.Locate(writeRequest.GetMetadata().GetGroup(), writeRequest.GetMetadata().GetName(), uint32(shardID)) + } + + iwr := &measurev1.InternalWriteRequest{ + Request: writeRequest, + ShardId: uint32(shardID), + SeriesHash: pbv1.HashEntity(entity), + EntityValues: tagValues[1:].Encode(), + } + + nodes, err := ms.publishToNodes(ctx, writeRequest, iwr, publisher, uint32(shardID), measure) + if err != nil { + return err + } + + *succeedSent = append(*succeedSent, succeedSentMessage{ + metadata: writeRequest.GetMetadata(), + messageID: writeRequest.GetMessageId(), + nodes: nodes, + }) + + return nil +} + +func (ms *measureService) publishToNodes(ctx context.Context, writeRequest *measurev1.WriteRequest, iwr *measurev1.InternalWriteRequest, + publisher queue.BatchPublisher, shardID uint32, measure measurev1.MeasureService_WriteServer, +) ([]string, error) { + copies, ok := ms.groupRepo.copies(writeRequest.Metadata.GetGroup()) + if !ok { + ms.l.Error().RawJSON("written", logger.Proto(writeRequest)).Msg("failed to get the group copies") + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure) + return nil, errors.New("failed to get group copies") + } + + nodes := make([]string, 0, copies) + for i := range copies { + nodeID, errPickNode := ms.nodeRegistry.Locate(writeRequest.GetMetadata().GetGroup(), writeRequest.GetMetadata().GetName(), shardID, i) if errPickNode != nil { - ms.sampled.Error().Err(errPickNode).RawJSON("written", logger.Proto(writeRequest)).Msg("failed to pick an available node") - reply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure, ms.sampled) - continue + ms.l.Error().Err(errPickNode).RawJSON("written", logger.Proto(writeRequest)).Msg("failed to pick an available node") + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure) + return nil, errPickNode } + message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) _, errWritePub := publisher.Publish(ctx, data.TopicMeasureWrite, message) if errWritePub != nil { - ms.sampled.Error().Err(errWritePub).RawJSON("written", logger.Proto(writeRequest)).Str("nodeID", nodeID).Msg("failed to send a message") + ms.l.Error().Err(errWritePub).RawJSON("written", logger.Proto(writeRequest)).Str("nodeID", nodeID).Msg("failed to send a message") var ce *common.Error if errors.As(errWritePub, &ce) { - reply(writeRequest.GetMetadata(), ce.Status(), writeRequest.GetMessageId(), measure, ms.sampled) - continue + ms.sendReply(writeRequest.GetMetadata(), ce.Status(), writeRequest.GetMessageId(), measure) + return nil, errWritePub + } + ms.sendReply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure) + return nil, errWritePub + } + nodes = append(nodes, nodeID) + } + + return nodes, nil +} + +func (ms 
*measureService) sendReply(metadata *commonv1.Metadata, status modelv1.Status, messageID uint64, measure measurev1.MeasureService_WriteServer) { + if status != modelv1.Status_STATUS_SUCCEED { + ms.metrics.totalStreamMsgReceivedErr.Inc(1, metadata.Group, "measure", "write") + } + ms.metrics.totalStreamMsgSent.Inc(1, metadata.Group, "measure", "write") + if errResp := measure.Send(&measurev1.WriteResponse{Metadata: metadata, Status: status.String(), MessageId: messageID}); errResp != nil { + if dl := ms.l.Debug(); dl.Enabled() { + dl.Err(errResp).Msg("failed to send measure write response") + } + ms.metrics.totalStreamMsgSentErr.Inc(1, metadata.Group, "measure", "write") + } +} + +func (ms *measureService) handleWriteCleanup(publisher queue.BatchPublisher, succeedSent *[]succeedSentMessage, + measure measurev1.MeasureService_WriteServer, start time.Time, +) { + cee, err := publisher.Close() + for _, s := range *succeedSent { + code := modelv1.Status_STATUS_SUCCEED + if cee != nil { + for _, node := range s.nodes { + if ce, ok := cee[node]; ok { + code = ce.Status() + if ce.Status() == modelv1.Status_STATUS_SUCCEED { + code = modelv1.Status_STATUS_SUCCEED + break + } + } } - reply(writeRequest.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeRequest.GetMessageId(), measure, ms.sampled) - continue } - succeedSent = append(succeedSent, succeedSentMessage{ - metadata: writeRequest.GetMetadata(), - messageID: writeRequest.GetMessageId(), - node: nodeID, - }) + ms.sendReply(s.metadata, code, s.messageID, measure) + } + if err != nil { + ms.l.Error().Err(err).Msg("failed to close the publisher") + } + if dl := ms.l.Debug(); dl.Enabled() { + dl.Int("total_requests", len(*succeedSent)).Msg("completed measure write batch") } + ms.metrics.totalStreamFinished.Inc(1, "measure", "write") + ms.metrics.totalStreamLatency.Inc(time.Since(start).Seconds(), "measure", "write") } var emptyMeasureQueryResponse = &measurev1.QueryResponse{DataPoints: make([]*measurev1.DataPoint, 0)} @@ -280,6 +349,6 @@ func (ms *measureService) Close() error { type succeedSentMessage struct { metadata *commonv1.Metadata - node string + nodes []string messageID uint64 } diff --git a/banyand/liaison/grpc/node.go b/banyand/liaison/grpc/node.go index 6436ce8c..60879973 100644 --- a/banyand/liaison/grpc/node.go +++ b/banyand/liaison/grpc/node.go @@ -39,7 +39,7 @@ var ( // NodeRegistry is for locating data node with group/name of the metadata // together with the shardID calculated from the incoming data. type NodeRegistry interface { - Locate(group, name string, shardID uint32) (string, error) + Locate(group, name string, shardID, replicaID uint32) (string, error) fmt.Stringer } @@ -66,10 +66,10 @@ func NewClusterNodeRegistry(topic bus.Topic, pipeline queue.Client, selector nod return nr } -func (n *clusterNodeService) Locate(group, name string, shardID uint32) (string, error) { - nodeID, err := n.sel.Pick(group, name, shardID) +func (n *clusterNodeService) Locate(group, name string, shardID, replicaID uint32) (string, error) { + nodeID, err := n.sel.Pick(group, name, shardID, replicaID) if err != nil { - return "", errors.Wrapf(err, "fail to locate %s/%s(%d)", group, name, shardID) + return "", errors.Wrapf(err, "fail to locate %s/%s(%d,%d)", group, name, shardID, replicaID) } return nodeID, nil } @@ -114,6 +114,6 @@ func NewLocalNodeRegistry() NodeRegistry { } // Locate of localNodeService always returns local. 
-func (localNodeService) Locate(_, _ string, _ uint32) (string, error) { +func (localNodeService) Locate(_, _ string, _, _ uint32) (string, error) { return "local", nil } diff --git a/banyand/liaison/grpc/node_test.go b/banyand/liaison/grpc/node_test.go index b39c60b5..999d78d1 100644 --- a/banyand/liaison/grpc/node_test.go +++ b/banyand/liaison/grpc/node_test.go @@ -55,7 +55,7 @@ func TestClusterNodeRegistry(t *testing.T) { }, }, }) - nodeID, err := cnr.Locate("metrics", "instance_traffic", 0) + nodeID, err := cnr.Locate("metrics", "instance_traffic", 0, 0) assert.NoError(t, err) assert.Equal(t, fakeNodeID, nodeID) } diff --git a/banyand/liaison/grpc/property.go b/banyand/liaison/grpc/property.go index 63a2cd88..b26569be 100644 --- a/banyand/liaison/grpc/property.go +++ b/banyand/liaison/grpc/property.go @@ -167,7 +167,7 @@ func (ps *propertyServer) Apply(ctx context.Context, req *propertyv1.ApplyReques if err != nil { return nil, err } - node, err := ps.nodeRegistry.Locate(g, entity, uint32(id)) + node, err := ps.nodeRegistry.Locate(g, entity, uint32(id), 0) if err != nil { return nil, err } diff --git a/banyand/liaison/grpc/stream.go b/banyand/liaison/grpc/stream.go index c7c0c245..3f467ca7 100644 --- a/banyand/liaison/grpc/stream.go +++ b/banyand/liaison/grpc/stream.go @@ -47,13 +47,13 @@ type streamService struct { pipeline queue.Client broadcaster queue.Client *discoveryService - sampled *logger.Logger + l *logger.Logger metrics *metrics writeTimeout time.Duration } func (s *streamService) setLogger(log *logger.Logger) { - s.sampled = log.Sampled(10) + s.l = log } func (s *streamService) activeIngestionAccessLog(root string) (err error) { @@ -69,9 +69,11 @@ func (s *streamService) Write(stream streamv1.StreamService_WriteServer) error { if status != modelv1.Status_STATUS_SUCCEED { s.metrics.totalStreamMsgReceivedErr.Inc(1, metadata.Group, "stream", "write") } - s.metrics.totalStreamMsgReceived.Inc(1, metadata.Group, "stream", "write") + s.metrics.totalStreamMsgSent.Inc(1, metadata.Group, "stream", "write") if errResp := stream.Send(&streamv1.WriteResponse{Metadata: metadata, Status: status.String(), MessageId: messageId}); errResp != nil { - logger.Debug().Err(errResp).Msg("failed to send stream write response") + if dl := logger.Debug(); dl.Enabled() { + dl.Err(errResp).Msg("failed to send stream write response") + } s.metrics.totalStreamMsgSentErr.Inc(1, metadata.Group, "stream", "write") } } @@ -79,19 +81,26 @@ func (s *streamService) Write(stream streamv1.StreamService_WriteServer) error { publisher := s.pipeline.NewBatchPublisher(s.writeTimeout) start := time.Now() var succeedSent []succeedSentMessage + requestCount := 0 defer func() { cee, err := publisher.Close() for _, ssm := range succeedSent { code := modelv1.Status_STATUS_SUCCEED if cee != nil { - if ce, ok := cee[ssm.node]; ok { - code = ce.Status() + for _, node := range ssm.nodes { + if ce, ok := cee[node]; ok { + code = ce.Status() + break + } } } - reply(ssm.metadata, code, ssm.messageID, stream, s.sampled) + reply(ssm.metadata, code, ssm.messageID, stream, s.l) } if err != nil { - s.sampled.Error().Err(err).Msg("failed to close the publisher") + s.l.Error().Err(err).Msg("failed to close the publisher") + } + if dl := s.l.Debug(); dl.Enabled() { + dl.Int("total_requests", requestCount).Msg("completed stream write batch") } s.metrics.totalStreamFinished.Inc(1, "stream", "write") s.metrics.totalStreamLatency.Inc(time.Since(start).Seconds(), "stream", "write") @@ -109,38 +118,39 @@ func (s *streamService) Write(stream 
streamv1.StreamService_WriteServer) error { } if err != nil { if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - s.sampled.Error().Stringer("written", writeEntity).Err(err).Msg("failed to receive message") + s.l.Error().Stringer("written", writeEntity).Err(err).Msg("failed to receive message") } return err } + requestCount++ s.metrics.totalStreamMsgReceived.Inc(1, writeEntity.Metadata.Group, "stream", "write") if errTime := timestamp.CheckPb(writeEntity.GetElement().Timestamp); errTime != nil { - s.sampled.Error().Stringer("written", writeEntity).Err(errTime).Msg("the element time is invalid") - reply(nil, modelv1.Status_STATUS_INVALID_TIMESTAMP, writeEntity.GetMessageId(), stream, s.sampled) + s.l.Error().Stringer("written", writeEntity).Err(errTime).Msg("the element time is invalid") + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INVALID_TIMESTAMP, writeEntity.GetMessageId(), stream, s.l) continue } if writeEntity.Metadata.ModRevision > 0 { streamCache, existed := s.entityRepo.getLocator(getID(writeEntity.GetMetadata())) if !existed { - s.sampled.Error().Err(err).Stringer("written", writeEntity).Msg("failed to stream schema not found") - reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_NOT_FOUND, writeEntity.GetMessageId(), stream, s.sampled) + s.l.Error().Err(err).Stringer("written", writeEntity).Msg("failed to stream schema not found") + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_NOT_FOUND, writeEntity.GetMessageId(), stream, s.l) continue } if writeEntity.Metadata.ModRevision != streamCache.ModRevision { - s.sampled.Error().Stringer("written", writeEntity).Msg("the stream schema is expired") - reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_EXPIRED_SCHEMA, writeEntity.GetMessageId(), stream, s.sampled) + s.l.Error().Stringer("written", writeEntity).Msg("the stream schema is expired") + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_EXPIRED_SCHEMA, writeEntity.GetMessageId(), stream, s.l) continue } } entity, tagValues, shardID, err := s.navigate(writeEntity.GetMetadata(), writeEntity.GetElement().GetTagFamilies()) if err != nil { - s.sampled.Error().Err(err).RawJSON("written", logger.Proto(writeEntity)).Msg("failed to navigate to the write target") - reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.sampled) + s.l.Error().Err(err).RawJSON("written", logger.Proto(writeEntity)).Msg("failed to navigate to the write target") + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.l) continue } if s.ingestionAccessLog != nil { if errAccessLog := s.ingestionAccessLog.Write(writeEntity); errAccessLog != nil { - s.sampled.Error().Err(errAccessLog).Msg("failed to write ingestion access log") + s.l.Error().Err(errAccessLog).Msg("failed to write ingestion access log") } } iwr := &streamv1.InternalWriteRequest{ @@ -149,28 +159,39 @@ func (s *streamService) Write(stream streamv1.StreamService_WriteServer) error { SeriesHash: pbv1.HashEntity(entity), EntityValues: tagValues[1:].Encode(), } - nodeID, errPickNode := s.nodeRegistry.Locate(writeEntity.GetMetadata().GetGroup(), writeEntity.GetMetadata().GetName(), uint32(shardID)) - if errPickNode != nil { - s.sampled.Error().Err(errPickNode).RawJSON("written", logger.Proto(writeEntity)).Msg("failed to pick an available node") - reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.sampled) + copies, ok := 
s.groupRepo.copies(writeEntity.Metadata.GetGroup()) + if !ok { + s.l.Error().RawJSON("written", logger.Proto(writeEntity)).Msg("failed to get the group copies") + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.l) continue } - message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) - _, errWritePub := publisher.Publish(ctx, data.TopicStreamWrite, message) - if errWritePub != nil { - var ce *common.Error - if errors.As(errWritePub, &ce) { - reply(writeEntity.GetMetadata(), ce.Status(), writeEntity.GetMessageId(), stream, s.sampled) + nodes := make([]string, 0, copies) + for i := range copies { + nodeID, errPickNode := s.nodeRegistry.Locate(writeEntity.GetMetadata().GetGroup(), writeEntity.GetMetadata().GetName(), uint32(shardID), i) + if errPickNode != nil { + s.l.Error().Err(errPickNode).RawJSON("written", logger.Proto(writeEntity)).Msg("failed to pick an available node") + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.l) continue } - s.sampled.Error().Err(errWritePub).RawJSON("written", logger.Proto(writeEntity)).Str("nodeID", nodeID).Msg("failed to send a message") - reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.sampled) - continue + message := bus.NewBatchMessageWithNode(bus.MessageID(time.Now().UnixNano()), nodeID, iwr) + _, errWritePub := publisher.Publish(ctx, data.TopicStreamWrite, message) + if errWritePub != nil { + s.l.Error().Err(errWritePub).RawJSON("written", logger.Proto(writeEntity)).Str("nodeID", nodeID).Msg("failed to send a message") + var ce *common.Error + if errors.As(errWritePub, &ce) { + reply(writeEntity.GetMetadata(), ce.Status(), writeEntity.GetMessageId(), stream, s.l) + continue + } + reply(writeEntity.GetMetadata(), modelv1.Status_STATUS_INTERNAL_ERROR, writeEntity.GetMessageId(), stream, s.l) + continue + } + nodes = append(nodes, nodeID) } + succeedSent = append(succeedSent, succeedSentMessage{ metadata: writeEntity.GetMetadata(), messageID: writeEntity.GetMessageId(), - node: nodeID, + nodes: nodes, }) } } diff --git a/banyand/property/db.go b/banyand/property/db.go index 1f41dbdd..8cc28629 100644 --- a/banyand/property/db.go +++ b/banyand/property/db.go @@ -22,6 +22,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -55,6 +56,7 @@ type database struct { location string flushInterval time.Duration closed atomic.Bool + mu sync.RWMutex } func openDB(ctx context.Context, location string, flushInterval time.Duration, omr observability.MetricsRegistry, lfs fs.FileSystem) (*database, error) { @@ -142,18 +144,22 @@ func (db *database) query(ctx context.Context, req *propertyv1.QueryRequest) ([] } func (db *database) loadShard(ctx context.Context, id common.ShardID) (*shard, error) { - sLst := db.sLst.Load() - if sLst != nil { - for _, s := range *sLst { - if s.id == id { - return s, nil - } - } + if db.closed.Load() { + return nil, errors.New("database is closed") + } + if s, ok := db.getShard(id); ok { + return s, nil + } + db.mu.Lock() + defer db.mu.Unlock() + if s, ok := db.getShard(id); ok { + return s, nil } sd, err := db.newShard(context.WithValue(ctx, logger.ContextKey, db.logger), id, int64(db.flushInterval.Seconds())) if err != nil { return nil, err } + sLst := db.sLst.Load() if sLst == nil { sLst = &[]*shard{} } @@ -162,6 +168,19 @@ func (db *database) loadShard(ctx context.Context, id common.ShardID) 
(*shard, e return sd, nil } +func (db *database) getShard(id common.ShardID) (*shard, bool) { + sLst := db.sLst.Load() + if sLst == nil { + return nil, false + } + for _, s := range *sLst { + if s.id == id { + return s, true + } + } + return nil, false +} + func (db *database) close() error { if db.closed.Swap(true) { return nil diff --git a/docs/api-reference.md b/docs/api-reference.md index 760ce0a0..302ebdfa 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -441,6 +441,7 @@ IntervalRule is a structured duration | ttl | [IntervalRule](#banyandb-common-v1-IntervalRule) | | Specifies the time-to-live for data in this stage before moving to the next. This is also a required field using the IntervalRule structure. | | node_selector | [string](#string) | | Node selector specifying target nodes for this stage. Optional; if provided, it must be a non-empty string. | | close | [bool](#bool) | | Indicates whether segments that are no longer live should be closed. | +| replicas | [uint32](#uint32) | | replicas is the number of replicas for this stage. This is an optional field and defaults to 0. A value of 0 means no replicas, while a value of 1 means one primary shard and one replica. Higher values indicate more replicas. | @@ -474,11 +475,12 @@ Metadata is for multi-tenant, multi-model use | Field | Type | Label | Description | | ----- | ---- | ----- | ----------- | -| shard_num | [uint32](#uint32) | | shard_num is the number of shards | +| shard_num | [uint32](#uint32) | | shard_num is the number of primary shards | | segment_interval | [IntervalRule](#banyandb-common-v1-IntervalRule) | | segment_interval indicates the length of a segment | | ttl | [IntervalRule](#banyandb-common-v1-IntervalRule) | | ttl indicates time to live, how long the data will be cached | | stages | [LifecycleStage](#banyandb-common-v1-LifecycleStage) | repeated | stages defines the ordered lifecycle stages. Data progresses through these stages sequentially. | | default_stages | [string](#string) | repeated | default_stages is the name of the default stage | +| replicas | [uint32](#uint32) | | replicas is the number of replicas. This is used to ensure high availability and fault tolerance. This is an optional field and defaults to 0. A value of 0 means no replicas, while a value of 1 means one primary shard and one replica. Higher values indicate more replicas. 
| diff --git a/go.mod b/go.mod index 8e796dcd..bff48b0c 100644 --- a/go.mod +++ b/go.mod @@ -16,11 +16,9 @@ require ( github.com/go-chi/chi/v5 v5.2.1 github.com/go-resty/resty/v2 v2.16.5 github.com/google/go-cmp v0.7.0 - github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.3.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 github.com/hashicorp/golang-lru v1.0.2 - github.com/kkdai/maglev v0.2.0 github.com/minio/minio-go/v7 v7.0.90 github.com/montanaflynn/stats v0.7.1 github.com/oklog/run v1.1.0 @@ -85,6 +83,7 @@ require ( github.com/goccy/go-json v0.10.5 // indirect github.com/golang/snappy v1.0.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/kamstrup/intmap v0.5.1 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/machinebox/graphql v0.2.2 // indirect @@ -123,7 +122,6 @@ require ( github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dchest/siphash v1.2.3 // indirect github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 // indirect github.com/dustin/go-humanize v1.0.1 github.com/fsnotify/fsnotify v1.9.0 diff --git a/go.sum b/go.sum index f46b46f5..4662c1e1 100644 --- a/go.sum +++ b/go.sum @@ -112,9 +112,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/siphash v1.2.2/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= -github.com/dchest/siphash v1.2.3 h1:QXwFc8cFOR2dSa/gE6o/HokBMWtLUaNDVd+22aKHeEA= -github.com/dchest/siphash v1.2.3/go.mod h1:0NvQU092bT0ipiFN++/rXm69QG9tVxLAlQHIXMPAkHc= github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33 h1:ucRHb6/lvW/+mTEIGbvhcYU3S8+uSNkuMjx/qZFfhtM= github.com/dgryski/go-metro v0.0.0-20250106013310-edb8663e5e33/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/docker/cli v27.4.1+incompatible h1:VzPiUlRJ/xh+otB75gva3r05isHMo5wXDfPRi5/b4hI= @@ -240,8 +237,6 @@ github.com/kamstrup/intmap v0.5.1 h1:ENGAowczZA+PJPYYlreoqJvWgQVtAmX1l899WfYFVK0 github.com/kamstrup/intmap v0.5.1/go.mod h1:gWUVWHKzWj8xpJVFf5GC0O26bWmv3GqdnIX/LMT6Aq4= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kkdai/maglev v0.2.0 h1:w6DCW0kAA6fstZqXkrBrlgIC3jeIRXkjOYea/m6EK/Y= -github.com/kkdai/maglev v0.2.0/go.mod h1:d+mt8Lmt3uqi9aRb/BnPjzD0fy+ETs1vVXiGRnqHVZ4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index aa17d57b..518a4f4c 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -85,13 +85,6 @@ func (l *Logger) Named(name ...string) *Logger { return &Logger{module: module, modules: l.modules, development: l.development, Logger: &subLogger, isDefaultLevel: isDefaultLevel} } -// 
Sampled return a Logger with a sampler that will send every Nth events.
-func (l *Logger) Sampled(n uint32) *Logger {
-	sampled := l.Logger.Sample(&zerolog.BasicSampler{N: n})
-	l.Logger = &sampled
-	return l
-}
-
 // ToZapConfig outputs the zap config is derived from l.
 func (l *Logger) ToZapConfig() zap.Config {
 	level, err := zap.ParseAtomicLevel(l.GetLevel().String())
diff --git a/pkg/meter/native/collection.go b/pkg/meter/native/collection.go
index 9eb0a5a3..59c83542 100644
--- a/pkg/meter/native/collection.go
+++ b/pkg/meter/native/collection.go
@@ -35,7 +35,7 @@ import (

 // NodeSelector has Locate method to select a nodeId.
 type NodeSelector interface {
-	Locate(group, name string, shardID uint32) (string, error)
+	Locate(group, name string, shardID, replicaID uint32) (string, error)
 	fmt.Stringer
 }

@@ -79,7 +79,7 @@ func (m *MetricCollection) FlushMetrics() {
 		var err error
 		// only liaison node has a non-nil nodeSelector
 		if m.nodeSelector != nil {
-			nodeID, err = m.nodeSelector.Locate(iwr.GetRequest().GetMetadata().GetGroup(), iwr.GetRequest().GetMetadata().GetName(), uint32(0))
+			nodeID, err = m.nodeSelector.Locate(iwr.GetRequest().GetMetadata().GetGroup(), iwr.GetRequest().GetMetadata().GetName(), uint32(0), uint32(0))
 			if err != nil {
 				log.Error().Err(err).Msg("Failed to locate nodeID")
 			}
diff --git a/pkg/node/interface.go b/pkg/node/interface.go
index 6c6dc766..c0e91490 100644
--- a/pkg/node/interface.go
+++ b/pkg/node/interface.go
@@ -45,7 +45,7 @@ type Selector interface {
 	AddNode(node *databasev1.Node)
 	RemoveNode(node *databasev1.Node)
 	SetNodeSelector(selector *pub.LabelSelector)
-	Pick(group, name string, shardID uint32) (string, error)
+	Pick(group, name string, shardID, replicaID uint32) (string, error)
 	run.PreRunner
 	fmt.Stringer
 }
@@ -69,7 +69,7 @@ func (p *pickFirstSelector) SetNodeSelector(_ *pub.LabelSelector) {}

 // String implements Selector.
 func (p *pickFirstSelector) String() string {
-	n, err := p.Pick("", "", 0)
+	n, err := p.Pick("", "", 0, 0)
 	if err != nil {
 		return fmt.Sprintf("%v", err)
 	}
@@ -119,7 +119,7 @@ func (p *pickFirstSelector) RemoveNode(node *databasev1.Node) {
 	p.nodeIDs = slices.Delete(p.nodeIDs, idx, idx+1)
 }

-func (p *pickFirstSelector) Pick(_, _ string, _ uint32) (string, error) {
+func (p *pickFirstSelector) Pick(_, _ string, _, _ uint32) (string, error) {
 	p.mu.RLock()
 	defer p.mu.RUnlock()
 	if len(p.nodeIDs) == 0 {
diff --git a/pkg/node/maglev.go b/pkg/node/maglev.go
deleted file mode 100644
index 2327790b..00000000
--- a/pkg/node/maglev.go
+++ /dev/null
@@ -1,123 +0,0 @@
-// Licensed to Apache Software Foundation (ASF) under one or more contributor
-// license agreements. See the NOTICE file distributed with
-// this work for additional information regarding copyright
-// ownership. Apache Software Foundation (ASF) licenses this file to you under
-// the Apache License, Version 2.0 (the "License"); you may
-// not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package node
-
-import (
-	"context"
-	"fmt"
-	"sort"
-	"strconv"
-	"sync"
-
-	"github.com/kkdai/maglev"
-
-	databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
-	"github.com/apache/skywalking-banyandb/banyand/queue/pub"
-)
-
-const lookupTableSize = 65537
-
-var _ Selector = (*maglevSelector)(nil)
-
-type maglevSelector struct {
-	routers sync.Map
-	nodes   []string
-	mutex   sync.RWMutex
-}
-
-func (m *maglevSelector) SetNodeSelector(_ *pub.LabelSelector) {}
-
-// String implements Selector.
-func (m *maglevSelector) String() string {
-	var groups []string
-	m.routers.Range(func(key, _ any) bool {
-		groups = append(groups, key.(string))
-		return true
-	})
-	m.mutex.RLock()
-	defer m.mutex.Unlock()
-	return fmt.Sprintf("nodes:%s groups:%s", m.nodes, groups)
-}
-
-func (m *maglevSelector) Name() string {
-	return "maglev-selector"
-}
-
-func (m *maglevSelector) PreRun(context.Context) error {
-	return nil
-}
-
-func (m *maglevSelector) AddNode(node *databasev1.Node) {
-	m.mutex.Lock()
-	defer m.mutex.Unlock()
-	for i := range m.nodes {
-		if m.nodes[i] == node.GetMetadata().GetName() {
-			return
-		}
-	}
-	m.nodes = append(m.nodes, node.GetMetadata().GetName())
-	sort.StringSlice(m.nodes).Sort()
-	m.routers.Range(func(_, value any) bool {
-		_ = value.(*maglev.Maglev).Set(m.nodes)
-		return true
-	})
-}
-
-func (m *maglevSelector) RemoveNode(node *databasev1.Node) {
-	m.mutex.Lock()
-	defer m.mutex.Unlock()
-	for i := range m.nodes {
-		if m.nodes[i] == node.GetMetadata().GetName() {
-			m.nodes = append(m.nodes[:i], m.nodes[i+1:]...)
-			break
-		}
-	}
-	m.routers.Range(func(_, value any) bool {
-		_ = value.(*maglev.Maglev).Set(m.nodes)
-		return true
-	})
-}
-
-func (m *maglevSelector) Pick(group, name string, shardID uint32) (string, error) {
-	router, ok := m.routers.Load(group)
-	if ok {
-		return router.(*maglev.Maglev).Get(formatSearchKey(name, shardID))
-	}
-	m.mutex.Lock()
-	defer m.mutex.Unlock()
-	router, ok = m.routers.Load(group)
-	if ok {
-		return router.(*maglev.Maglev).Get(formatSearchKey(name, shardID))
-	}
-
-	mTab, err := maglev.NewMaglev(m.nodes, lookupTableSize)
-	if err != nil {
-		return "", err
-	}
-	m.routers.Store(group, mTab)
-	return mTab.Get(formatSearchKey(name, shardID))
-}
-
-// NewMaglevSelector creates a new backend selector based on Maglev hashing algorithm.
-func NewMaglevSelector() Selector {
-	return &maglevSelector{}
-}
-
-func formatSearchKey(name string, shardID uint32) string {
-	return name + "-" + strconv.FormatUint(uint64(shardID), 10)
-}
diff --git a/pkg/node/maglev_test.go b/pkg/node/maglev_test.go
deleted file mode 100644
index 74505a3a..00000000
--- a/pkg/node/maglev_test.go
+++ /dev/null
@@ -1,133 +0,0 @@
-// Licensed to Apache Software Foundation (ASF) under one or more contributor
-// license agreements. See the NOTICE file distributed with
-// this work for additional information regarding copyright
-// ownership. Apache Software Foundation (ASF) licenses this file to you under
-// the Apache License, Version 2.0 (the "License"); you may
-// not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package node
-
-import (
-	"fmt"
-	"testing"
-
-	"github.com/google/uuid"
-	"github.com/stretchr/testify/assert"
-
-	commonv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/common/v1"
-	databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
-)
-
-const (
-	dataNodeTemplate = "data-node-%d"
-	targetEpsilon    = 0.1
-)
-
-func TestMaglevSelector(t *testing.T) {
-	sel := NewMaglevSelector()
-	sel.AddNode(&databasev1.Node{
-		Metadata: &commonv1.Metadata{
-			Name: "data-node-1",
-		},
-	})
-	sel.AddNode(&databasev1.Node{
-		Metadata: &commonv1.Metadata{
-			Name: "data-node-2",
-		},
-	})
-	nodeID1, err := sel.Pick("sw_metrics", "traffic_instance", 0)
-	assert.NoError(t, err)
-	assert.Contains(t, []string{"data-node-1", "data-node-2"}, nodeID1)
-	nodeID2, err := sel.Pick("sw_metrics", "traffic_instance", 0)
-	assert.NoError(t, err)
-	assert.Equal(t, nodeID2, nodeID1)
-}
-
-func TestMaglevSelector_EvenDistribution(t *testing.T) {
-	sel := NewMaglevSelector()
-	dataNodeNum := 10
-	for i := 0; i < dataNodeNum; i++ {
-		sel.AddNode(&databasev1.Node{
-			Metadata: &commonv1.Metadata{
-				Name: fmt.Sprintf(dataNodeTemplate, i),
-			},
-		})
-	}
-	counterMap := make(map[string]int)
-	trialCount := 100_000
-	for j := 0; j < trialCount; j++ {
-		dataNodeID, _ := sel.Pick("sw_metrics", uuid.NewString(), 0)
-		val, exist := counterMap[dataNodeID]
-		if !exist {
-			counterMap[dataNodeID] = 1
-		} else {
-			counterMap[dataNodeID] = val + 1
-		}
-	}
-	assert.Len(t, counterMap, dataNodeNum)
-	for _, count := range counterMap {
-		assert.InEpsilon(t, trialCount/dataNodeNum, count, targetEpsilon)
-	}
-}
-
-func TestMaglevSelector_DiffNode(t *testing.T) {
-	fullSel := NewMaglevSelector()
-	brokenSel := NewMaglevSelector()
-	dataNodeNum := 10
-	for i := 0; i < dataNodeNum; i++ {
-		fullSel.AddNode(&databasev1.Node{
-			Metadata: &commonv1.Metadata{
-				Name: fmt.Sprintf(dataNodeTemplate, i),
-			},
-		})
-		if i != dataNodeNum-1 {
-			brokenSel.AddNode(&databasev1.Node{
-				Metadata: &commonv1.Metadata{
-					Name: fmt.Sprintf(dataNodeTemplate, i),
-				},
-			})
-		}
-	}
-	diff := 0
-	trialCount := 100_000
-	for j := 0; j < trialCount; j++ {
-		metricName := uuid.NewString()
-		fullDataNodeID, _ := fullSel.Pick("sw_metrics", metricName, 0)
-		brokenDataNodeID, _ := brokenSel.Pick("sw_metrics", metricName, 0)
-		if fullDataNodeID != brokenDataNodeID {
-			diff++
-		}
-	}
-	assert.InEpsilon(t, trialCount/dataNodeNum, diff, targetEpsilon*2)
-}
-
-func BenchmarkMaglevSelector_Pick(b *testing.B) {
-	sel := NewMaglevSelector()
-	dataNodeNum := 10
-	for i := 0; i < dataNodeNum; i++ {
-		sel.AddNode(&databasev1.Node{
-			Metadata: &commonv1.Metadata{
-				Name: fmt.Sprintf(dataNodeTemplate, i),
-			},
-		})
-	}
-	metricsCount := 10_000
-	metricNames := make([]string, 0, metricsCount)
-	for i := 0; i < metricsCount; i++ {
-		metricNames = append(metricNames, uuid.NewString())
-	}
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		_, _ = sel.Pick("sw_metrics", metricNames[i%metricsCount], 0)
-	}
-}
diff --git a/pkg/node/round_robin.go b/pkg/node/round_robin.go
index 354539a0..60201db4 100644
--- a/pkg/node/round_robin.go
+++ b/pkg/node/round_robin.go
@@ -51,13 +51,16 @@ func (r *roundRobinSelector) String() string {
 	defer r.mu.RUnlock()
 	result := make(map[string]string)
 	for _, entry := range r.lookupTable {
-		n, err := r.Pick(entry.group, "", entry.shardID)
-		key := fmt.Sprintf("%s-%d", entry.group, entry.shardID)
-		if err != nil {
-			result[key] = fmt.Sprintf("%v", err)
-			continue
+		copies := entry.replicas + 1
+		for i := range copies {
+			n, err := r.Pick(entry.group, "", entry.shardID, i)
+			key := fmt.Sprintf("%s-%d-%d", entry.group, entry.shardID, i)
+			if err != nil {
+				result[key] = fmt.Sprintf("%v", err)
+				continue
+			}
+			result[key] = n
 		}
-		result[key] = n
 	}
 	if len(result) < 1 {
 		return ""
 	}
@@ -105,8 +108,7 @@ func (r *roundRobinSelector) OnAddOrUpdate(schemaMetadata schema.Metadata) {
 	defer r.mu.Unlock()
 	r.removeGroup(group.Metadata.Name)
 	for i := uint32(0); i < group.ResourceOpts.ShardNum; i++ {
-		k := key{group: group.Metadata.Name, shardID: i}
-		r.lookupTable = append(r.lookupTable, k)
+		r.lookupTable = append(r.lookupTable, newKey(group.Metadata.Name, i, group.ResourceOpts.Replicas))
 	}
 	r.sortEntries()
 }
@@ -157,8 +159,7 @@ func (r *roundRobinSelector) OnInit(kinds []schema.Kind) (bool, []int64) {
 			revision = g.Metadata.ModRevision
 		}
 		for i := uint32(0); i < g.ResourceOpts.ShardNum; i++ {
-			k := key{group: g.Metadata.Name, shardID: i}
-			r.lookupTable = append(r.lookupTable, k)
+			r.lookupTable = append(r.lookupTable, newKey(g.Metadata.Name, i, g.ResourceOpts.Replicas))
 		}
 	}
 	r.sortEntries()
@@ -186,7 +187,7 @@ func (r *roundRobinSelector) RemoveNode(node *databasev1.Node) {
 	}
 }

-func (r *roundRobinSelector) Pick(group, _ string, shardID uint32) (string, error) {
+func (r *roundRobinSelector) Pick(group, _ string, shardID, replicaID uint32) (string, error) {
 	r.mu.RLock()
 	defer r.mu.RUnlock()
 	k := key{group: group, shardID: shardID}
@@ -199,8 +200,8 @@ func (r *roundRobinSelector) Pick(group, _ string, shardID uint32) (string, erro
 		}
 		return r.lookupTable[i].group > group
 	})
-	if i < len(r.lookupTable) && r.lookupTable[i] == k {
-		return r.selectNode(i), nil
+	if i < len(r.lookupTable) && r.lookupTable[i].equal(k) {
+		return r.selectNode(i, replicaID), nil
 	}
 	return "", fmt.Errorf("%s-%d is a unknown shard", group, shardID)
 }
@@ -215,9 +216,9 @@ func (r *roundRobinSelector) sortEntries() {
 	})
 }

-func (r *roundRobinSelector) selectNode(entry any) string {
-	index := entry.(int)
-	return r.nodes[index%len(r.nodes)]
+func (r *roundRobinSelector) selectNode(index int, replicasID uint32) string {
+	adjustedIndex := index + int(replicasID)
+	return r.nodes[adjustedIndex%len(r.nodes)]
 }

 func validateGroup(group *commonv1.Group) bool {
@@ -231,6 +232,19 @@ func validateGroup(group *commonv1.Group) bool {
 }

 type key struct {
-	group   string
-	shardID uint32
+	group    string
+	shardID  uint32
+	replicas uint32
+}
+
+func (k key) equal(other key) bool {
+	return k.group == other.group && k.shardID == other.shardID
+}
+
+func newKey(group string, shardID, replicas uint32) key {
+	return key{
+		group:    group,
+		shardID:  shardID,
+		replicas: replicas,
+	}
 }
diff --git a/pkg/node/round_robin_test.go b/pkg/node/round_robin_test.go
index 6eb37334..1668946b 100644
--- a/pkg/node/round_robin_test.go
+++ b/pkg/node/round_robin_test.go
@@ -30,16 +30,16 @@ import (
 func TestPickEmptySelector(t *testing.T) {
 	selector := NewRoundRobinSelector("test", nil)
 	setupGroup(selector)
-	_, err := selector.Pick("group1", "", 0)
+	_, err := selector.Pick("group1", "", 0, 0)
 	assert.Error(t, err)
 }

 func TestPickUnknownGroup(t *testing.T) {
 	selector := NewRoundRobinSelector("test", nil)
-	_, err := selector.Pick("group1", "", 0)
+	_, err := selector.Pick("group1", "", 0, 0)
 	assert.Error(t, err)
 	setupGroup(selector)
-	_, err = selector.Pick("group1", "", 100)
+	_, err = selector.Pick("group1", "", 100, 0)
 	assert.Error(t, err)
 }

@@ -47,7 +47,7 @@ func TestPickSingleSelection(t *testing.T) {
 	selector := NewRoundRobinSelector("test", nil)
 	setupGroup(selector)
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
-	node, err := selector.Pick("group1", "", 0)
+	node, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
 	assert.Equal(t, "node1", node)
 }
@@ -58,11 +58,11 @@ func TestPickMultipleSelections(t *testing.T) {
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node2"}})

-	_, err := selector.Pick("group1", "", 1)
+	_, err := selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
-	node1, err := selector.Pick("group1", "", 0)
+	node1, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
-	node2, err := selector.Pick("group1", "", 1)
+	node2, err := selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
 	assert.NotEqual(t, node1, node2, "Different shardIDs in the same group should not result in the same node")
 }
@@ -73,7 +73,7 @@ func TestPickNodeRemoval(t *testing.T) {
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node2"}})
 	selector.RemoveNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
-	node, err := selector.Pick("group1", "", 0)
+	node, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
 	assert.Equal(t, "node2", node)
 }
@@ -84,15 +84,15 @@ func TestPickConsistentSelectionAfterRemoval(t *testing.T) {
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node2"}})
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node3"}})
-	_, err := selector.Pick("group1", "", 0)
+	_, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
-	_, err = selector.Pick("group1", "", 1)
+	_, err = selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
-	node, err := selector.Pick("group1", "", 1)
+	node, err := selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
 	assert.Equal(t, "node2", node)
 	selector.RemoveNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node2"}})
-	node, err = selector.Pick("group1", "", 1)
+	node, err = selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
 	assert.Equal(t, "node3", node)
 }
@@ -104,10 +104,10 @@ func TestCleanupGroup(t *testing.T) {
 	setupGroup(selector)
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node2"}})
-	_, err := selector.Pick("group1", "", 0)
+	_, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
 	selector.OnDelete(groupSchema)
-	_, err = selector.Pick("group1", "", 0)
+	_, err = selector.Pick("group1", "", 0, 0)
 	assert.Error(t, err)
 }

@@ -139,21 +139,21 @@ func TestChangeShard(t *testing.T) {
 	setupGroup(selector)
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node1"}})
 	selector.AddNode(&databasev1.Node{Metadata: &commonv1.Metadata{Name: "node2"}})
-	_, err := selector.Pick("group1", "", 0)
+	_, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
-	_, err = selector.Pick("group1", "", 1)
+	_, err = selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
 	// Reduce shard number to 1
 	selector.OnAddOrUpdate(groupSchema1)
-	_, err = selector.Pick("group1", "", 0)
+	_, err = selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
-	_, err = selector.Pick("group1", "", 1)
+	_, err = selector.Pick("group1", "", 1, 0)
 	assert.Error(t, err)
 	// Restore shard number to 2
 	setupGroup(selector)
-	node1, err := selector.Pick("group1", "", 0)
+	node1, err := selector.Pick("group1", "", 0, 0)
 	assert.NoError(t, err)
-	node2, err := selector.Pick("group1", "", 1)
+	node2, err := selector.Pick("group1", "", 1, 0)
 	assert.NoError(t, err)
 	assert.NotEqual(t, node1, node2)
 }
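For reviewers, the following standalone Go sketch (not part of the patch) illustrates the selection scheme behind the new Pick(group, name, shardID, replicaID) signature: an entry's position in the sorted lookup table picks the primary node, and replicaID offsets that index so each copy of a shard lands on a different node, assuming the replica count stays below the node count. The entry type, pickReplica helper, and the hard-coded table below are illustrative names only, not APIs from the repository.

package main

import (
	"fmt"
	"sort"
)

// entry mirrors the (group, shardID, replicas) lookup-table key kept by the
// round-robin selector; the field names here are illustrative.
type entry struct {
	group    string
	shardID  uint32
	replicas uint32
}

// pickReplica is an illustrative stand-in for roundRobinSelector.Pick: the
// entry's index in the sorted table selects the primary node, and replicaID
// shifts that index so each copy of a shard maps to a different node.
func pickReplica(table []entry, nodes []string, group string, shardID, replicaID uint32) (string, error) {
	if len(nodes) == 0 {
		return "", fmt.Errorf("no nodes available")
	}
	// Binary-search the table, which is sorted by (group, shardID).
	i := sort.Search(len(table), func(i int) bool {
		if table[i].group == group {
			return table[i].shardID >= shardID
		}
		return table[i].group > group
	})
	if i >= len(table) || table[i].group != group || table[i].shardID != shardID {
		return "", fmt.Errorf("%s-%d is an unknown shard", group, shardID)
	}
	return nodes[(i+int(replicaID))%len(nodes)], nil
}

func main() {
	nodes := []string{"node1", "node2", "node3"}
	// replicas=1 means one primary and one replica per shard.
	table := []entry{
		{group: "group1", shardID: 0, replicas: 1},
		{group: "group1", shardID: 1, replicas: 1},
	}
	for _, e := range table {
		for r := uint32(0); r <= e.replicas; r++ {
			n, _ := pickReplica(table, nodes, e.group, e.shardID, r)
			fmt.Printf("%s shard=%d replica=%d -> %s\n", e.group, e.shardID, r, n)
		}
	}
}

Under this layout a liaison fans a replicated write out to replicaIDs 0 through replicas of a shard, and the offset-by-replicaID rule keeps a shard's copies on distinct nodes without any extra hashing state.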