This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch stream-load
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git

commit 157cc9d3681e08cb299424d07c74b30e413be7b3
Author: Gao Hongtao <hanahm...@gmail.com>
AuthorDate: Wed Apr 24 07:48:21 2024 +0000

    Remove empty shards
    
    Signed-off-by: Gao Hongtao <hanahm...@gmail.com>
---
 banyand/internal/storage/retention.go        | 10 +--
 banyand/internal/storage/tsdb.go             | 94 +++++++++++++++++-----------
 banyand/stream/iter_builder.go               |  6 +-
 banyand/stream/write.go                      |  6 +-
 test/cases/stream/data/want/sort_filter.yaml | 17 +++++
 test/stress/trace/trace_suite_test.go        |  4 +-
 6 files changed, 87 insertions(+), 50 deletions(-)

diff --git a/banyand/internal/storage/retention.go b/banyand/internal/storage/retention.go
index 445b7f56..87e25ff0 100644
--- a/banyand/internal/storage/retention.go
+++ b/banyand/internal/storage/retention.go
@@ -51,13 +51,13 @@ func newRetentionTask[T TSTable, O any](database *database[T, O], ttl IntervalRu
 }
 
 func (rc *retentionTask[T, O]) run(now time.Time, l *logger.Logger) bool {
-       var shardList []*shard[T, O]
-       rc.database.RLock()
-       shardList = append(shardList, rc.database.sLst...)
-       rc.database.RUnlock()
+       shardList := rc.database.sLst.Load()
+       if shardList == nil {
+               return false
+       }
        deadline := now.Add(-rc.duration)
 
-       for _, shard := range shardList {
+       for _, shard := range *shardList {
                if err := shard.segmentController.remove(deadline); err != nil {
                        l.Error().Err(err)
                }
diff --git a/banyand/internal/storage/tsdb.go b/banyand/internal/storage/tsdb.go
index 7479b5de..698827a6 100644
--- a/banyand/internal/storage/tsdb.go
+++ b/banyand/internal/storage/tsdb.go
@@ -74,18 +74,20 @@ type database[T TSTable, O any] struct {
        scheduler       *timestamp.Scheduler
        p               common.Position
        location        string
-       sLst            []*shard[T, O]
+       sLst            atomic.Pointer[[]*shard[T, O]]
        opts            TSDBOpts[T, O]
        sync.RWMutex
-       sLen uint32
 }
 
 func (d *database[T, O]) Close() error {
        d.Lock()
        defer d.Unlock()
        d.scheduler.Close()
-       for _, s := range d.sLst {
-               s.close()
+       sLst := d.sLst.Load()
+       if sLst != nil {
+               for _, s := range *sLst {
+                       s.close()
+               }
        }
        d.lock.Close()
        if err := lfs.DeleteFile(d.lock.Path()); err != nil {
@@ -139,58 +141,77 @@ func OpenTSDB[T TSTable, O any](ctx context.Context, opts TSDBOpts[T, O]) (TSDB[
 }
 
 func (d *database[T, O]) CreateTSTableIfNotExist(shardID common.ShardID, ts 
time.Time) (TSTableWrapper[T], error) {
-       id := uint32(shardID)
-       if id >= atomic.LoadUint32(&d.sLen) {
-               return func() (TSTableWrapper[T], error) {
-                       d.Lock()
-                       defer d.Unlock()
-                       if int(id) >= len(d.sLst) {
-                               for i := len(d.sLst); i <= int(id); i++ {
-                                       d.logger.Info().Int("shard_id", 
i).Msg("creating a shard")
-                                       if err := d.registerShard(i); err != 
nil {
-                                               return nil, err
-                                       }
-                               }
+       if s, ok := d.getShard(shardID); ok {
+               d.RLock()
+               defer d.RUnlock()
+               return d.createTSTTable(s, ts)
+       }
+       d.Lock()
+       defer d.Unlock()
+       if s, ok := d.getShard(shardID); ok {
+               return d.createTSTTable(s, ts)
+       }
+       d.logger.Info().Int("shard_id", int(shardID)).Msg("creating a shard")
+       s, err := d.registerShard(shardID)
+       if err != nil {
+               return nil, err
+       }
+       return d.createTSTTable(s, ts)
+}
+
+func (d *database[T, O]) getShard(shardID common.ShardID) (*shard[T, O], bool) 
{
+       sLst := d.sLst.Load()
+       if sLst != nil {
+               for _, s := range *sLst {
+                       if s.id == shardID {
+                               return s, true
                        }
-                       return d.createTSTTable(shardID, ts)
-               }()
+               }
        }
-       d.RLock()
-       defer d.RUnlock()
-       return d.createTSTTable(shardID, ts)
+       return nil, false
 }
 
-func (d *database[T, O]) createTSTTable(shardID common.ShardID, ts time.Time) 
(TSTableWrapper[T], error) {
+func (d *database[T, O]) createTSTTable(shard *shard[T, O], ts time.Time) 
(TSTableWrapper[T], error) {
        timeRange := timestamp.NewInclusiveTimeRange(ts, ts)
-       ss := d.sLst[shardID].segmentController.selectTSTables(timeRange)
+       ss := shard.segmentController.selectTSTables(timeRange)
        if len(ss) > 0 {
                return ss[0], nil
        }
-       return d.sLst[shardID].segmentController.createTSTable(ts)
+       return shard.segmentController.createTSTable(ts)
 }
 
 func (d *database[T, O]) SelectTSTables(timeRange timestamp.TimeRange) 
[]TSTableWrapper[T] {
        var result []TSTableWrapper[T]
-       d.RLock()
-       for i := range d.sLst {
-               result = append(result, 
d.sLst[i].segmentController.selectTSTables(timeRange)...)
+       sLst := d.sLst.Load()
+       if sLst == nil {
+               return result
+       }
+       for _, s := range *sLst {
+               result = append(result, 
s.segmentController.selectTSTables(timeRange)...)
        }
-       d.RUnlock()
        return result
 }
 
-func (d *database[T, O]) registerShard(id int) error {
+func (d *database[T, O]) registerShard(id common.ShardID) (*shard[T, O], 
error) {
+       if s, ok := d.getShard(id); ok {
+               return s, nil
+       }
        ctx := context.WithValue(context.Background(), logger.ContextKey, 
d.logger)
        ctx = common.SetPosition(ctx, func(p common.Position) common.Position {
                return d.p
        })
-       so, err := d.openShard(ctx, common.ShardID(id))
+       so, err := d.openShard(ctx, id)
        if err != nil {
-               return err
+               return nil, err
        }
-       d.sLst = append(d.sLst, so)
-       d.sLen++
-       return nil
+       var shardList []*shard[T, O]
+       sLst := d.sLst.Load()
+       if sLst != nil {
+               shardList = *sLst
+       }
+       shardList = append(shardList, so)
+       d.sLst.Store(&shardList)
+       return so, nil
 }
 
 func (d *database[T, O]) loadDatabase() error {
@@ -204,8 +225,9 @@ func (d *database[T, O]) loadDatabase() error {
                if shardID >= int(d.opts.ShardNum) {
                        return nil
                }
-               d.logger.Info().Int("shard_id", shardID).Msg("opening a existed 
shard")
-               return d.registerShard(shardID)
+               d.logger.Info().Int("shard_id", shardID).Msg("loaded a existed 
shard")
+               _, err = d.registerShard(common.ShardID(shardID))
+               return err
        })
 }
 
diff --git a/banyand/stream/iter_builder.go b/banyand/stream/iter_builder.go
index da6c0179..f51acb7f 100644
--- a/banyand/stream/iter_builder.go
+++ b/banyand/stream/iter_builder.go
@@ -25,7 +25,6 @@ import (
        databasev1 
"github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
        "github.com/apache/skywalking-banyandb/banyand/internal/storage"
        "github.com/apache/skywalking-banyandb/pkg/index"
-       "github.com/apache/skywalking-banyandb/pkg/index/posting"
        pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
 )
 
@@ -53,15 +52,14 @@ func (s *stream) buildSeriesByIndex(tableWrappers []storage.TSTableWrapper[*tsTa
        if !tl.valid() {
                return nil, fmt.Errorf("sorted tag %s not found in tag 
projection", sortedTag)
        }
-       var pl posting.List
        for _, tw := range tableWrappers {
                seriesFilter := make(map[common.SeriesID]filterFn)
                if sso.Filter != nil {
                        for i := range sids {
-                               pl, err = sso.Filter.Execute(func(ruleType 
databasev1.IndexRule_Type) (index.Searcher, error) {
+                               pl, errExe := sso.Filter.Execute(func(ruleType 
databasev1.IndexRule_Type) (index.Searcher, error) {
                                        return tw.Table().Index().store, nil
                                }, sids[i])
-                               if err != nil {
+                               if errExe != nil {
                                        return nil, err
                                }
 
diff --git a/banyand/stream/write.go b/banyand/stream/write.go
index fead18c0..bd0b1f31 100644
--- a/banyand/stream/write.go
+++ b/banyand/stream/write.go
@@ -78,13 +78,13 @@ func (w *writeCallback) handle(dst map[string]*elementsInGroup, writeEvent *stre
        }
        shardID := common.ShardID(writeEvent.ShardId)
        if et == nil {
-               tstb, err := tsdb.CreateTSTableIfNotExist(shardID, t)
+               tsdb, err := tsdb.CreateTSTableIfNotExist(shardID, t)
                if err != nil {
                        return nil, fmt.Errorf("cannot create ts table: %w", 
err)
                }
                et = &elementsInTable{
-                       timeRange: tstb.GetTimeRange(),
-                       tsTable:   tstb,
+                       timeRange: tsdb.GetTimeRange(),
+                       tsTable:   tsdb,
                }
                eg.tables = append(eg.tables, et)
        }
diff --git a/test/cases/stream/data/want/sort_filter.yaml b/test/cases/stream/data/want/sort_filter.yaml
index 6726bea8..50de8723 100644
--- a/test/cases/stream/data/want/sort_filter.yaml
+++ b/test/cases/stream/data/want/sort_filter.yaml
@@ -1,3 +1,20 @@
+# Licensed to Apache Software Foundation (ASF) under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Apache Software Foundation (ASF) licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
 elements:
   - elementId: "4"
     tagFamilies:
diff --git a/test/stress/trace/trace_suite_test.go b/test/stress/trace/trace_suite_test.go
index ee033708..0aaee55b 100644
--- a/test/stress/trace/trace_suite_test.go
+++ b/test/stress/trace/trace_suite_test.go
@@ -70,10 +70,10 @@ var _ = Describe("Query", func() {
        })
 
        It("Metric", func() {
-               query.ServiceList(basePath, timeout, 6, fs)
+               query.ServiceList(basePath, timeout, 1, fs)
        })
 
        It("TopN", func() {
-               query.TopN(basePath, timeout, 6, fs)
+               query.TopN(basePath, timeout, 1, fs)
        })
 })

Reply via email to