This is an automated email from the ASF dual-hosted git repository.

wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a7795c6 Fix the oom issue when loading too many unnecessary parts into memory (#631)
0a7795c6 is described below

commit 0a7795c69c740ac29ffc3bf8cb1c4e744dd7aa7d
Author: Gao Hongtao <[email protected]>
AuthorDate: Wed Mar 26 11:34:17 2025 +0800

    Fix the oom issue when loading too many unnecessary parts into memory (#631)
---
 .asf.yaml                       |   1 +
 CHANGES.md                      |   1 +
 banyand/stream/block_scanner.go | 185 ++++++++++++++++++++++------------------
 banyand/stream/metadata_test.go |   2 +-
 banyand/stream/query_by_ts.go   |  72 +++++++++-------
 banyand/stream/snapshot.go      |  44 ++++++++++
 banyand/stream/snapshot_test.go |  23 +++++
 7 files changed, 214 insertions(+), 114 deletions(-)

diff --git a/.asf.yaml b/.asf.yaml
index c2dd73ef..c0a7bee3 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -41,5 +41,6 @@ github:
           - Continuous Integration
       required_pull_request_reviews:
         dismiss_stale_reviews: true
+        required_approving_review_count: 1
     # Protect 0.7.x release branch
     v0.7.x: {}
diff --git a/CHANGES.md b/CHANGES.md
index de8bd8d0..335db585 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -54,6 +54,7 @@ Release Notes.
 - UI: Implement TopNAggregation data query page.
 - UI: Update BanyanDB UI to Integrate New Property Query API.
 - UI: Fix the Stream List.
+- Fix the oom issue when loading too many unnecessary parts into memory.
 
 ### Documentation
 
diff --git a/banyand/stream/block_scanner.go b/banyand/stream/block_scanner.go
index 7e3d89df..d7347265 100644
--- a/banyand/stream/block_scanner.go
+++ b/banyand/stream/block_scanner.go
@@ -21,13 +21,11 @@ import (
        "context"
        "fmt"
        "sort"
-       "sync"
 
        "github.com/apache/skywalking-banyandb/api/common"
        modelv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
        "github.com/apache/skywalking-banyandb/banyand/internal/storage"
        "github.com/apache/skywalking-banyandb/banyand/protector"
-       "github.com/apache/skywalking-banyandb/pkg/cgroups"
        "github.com/apache/skywalking-banyandb/pkg/index/posting"
        "github.com/apache/skywalking-banyandb/pkg/index/posting/roaring"
        "github.com/apache/skywalking-banyandb/pkg/logger"
@@ -80,90 +78,130 @@ func releaseBlockScanResultBatch(bsb *blockScanResultBatch) {
 
 var blockScanResultBatchPool = pool.Register[*blockScanResultBatch]("stream-blockScannerBatch")
 
-var shardScanConcurrencyCh = make(chan struct{}, cgroups.CPUs())
-
-type blockScanner struct {
-       segment   storage.Segment[*tsTable, *option]
-       pm        *protector.Memory
-       l         *logger.Logger
-       series    []*pbv1.Series
-       seriesIDs []uint64
-       qo        queryOptions
-}
-
-func (q *blockScanner) searchSeries(ctx context.Context) error {
+func searchSeries(ctx context.Context, qo queryOptions, segment storage.Segment[*tsTable, *option], series []*pbv1.Series) (queryOptions, error) {
        seriesFilter := roaring.NewPostingList()
-       sl, err := q.segment.Lookup(ctx, q.series)
+       sl, err := segment.Lookup(ctx, series)
        if err != nil {
-               return err
+               return qo, err
        }
        for i := range sl {
                if seriesFilter.Contains(uint64(sl[i].ID)) {
                        continue
                }
                seriesFilter.Insert(uint64(sl[i].ID))
-               if q.qo.seriesToEntity == nil {
-                       q.qo.seriesToEntity = make(map[common.SeriesID][]*modelv1.TagValue)
+               if qo.seriesToEntity == nil {
+                       qo.seriesToEntity = make(map[common.SeriesID][]*modelv1.TagValue)
                }
-               q.qo.seriesToEntity[sl[i].ID] = sl[i].EntityValues
-               q.qo.sortedSids = append(q.qo.sortedSids, sl[i].ID)
+               qo.seriesToEntity[sl[i].ID] = sl[i].EntityValues
+               qo.sortedSids = append(qo.sortedSids, sl[i].ID)
        }
        if seriesFilter.IsEmpty() {
-               return nil
+               return qo, nil
        }
-       q.seriesIDs = seriesFilter.ToSlice()
-       sort.Slice(q.qo.sortedSids, func(i, j int) bool { return q.qo.sortedSids[i] < q.qo.sortedSids[j] })
-       return nil
+       sort.Slice(qo.sortedSids, func(i, j int) bool { return qo.sortedSids[i] < qo.sortedSids[j] })
+       return qo, nil
 }
 
-func (q *blockScanner) scanShardsInParallel(ctx context.Context, wg *sync.WaitGroup, blockCh chan *blockScanResultBatch) []scanFinalizer {
-       tabs := q.segment.Tables()
-       finalizers := make([]scanFinalizer, len(tabs))
+func getBlockScanner(ctx context.Context, segment storage.Segment[*tsTable, *option], qo queryOptions,
+       l *logger.Logger, pm *protector.Memory,
+) (bc *blockScanner, err error) {
+       tabs := segment.Tables()
+       finalizers := make([]scanFinalizer, 0, len(tabs)+1)
+       finalizers = append(finalizers, segment.DecRef)
+       defer func() {
+               if bc == nil || err != nil {
+                       for i := range finalizers {
+                               finalizers[i]()
+                       }
+               }
+       }()
+       var parts []*part
+       var size, offset int
+       filterIndex := make(map[uint64]posting.List)
        for i := range tabs {
-               select {
-               case shardScanConcurrencyCh <- struct{}{}:
-               case <-ctx.Done():
-                       return finalizers
+               snp := tabs[i].currentSnapshot()
+               parts, size = snp.getParts(parts, qo.minTimestamp, qo.maxTimestamp)
+               if size < 1 {
+                       snp.decRef()
+                       continue
                }
-               wg.Add(1)
-               go func(idx int, tab *tsTable) {
-                       finalizers[idx] = q.scanBlocks(ctx, q.seriesIDs, tab, blockCh)
-                       wg.Done()
-                       <-shardScanConcurrencyCh
-               }(i, tabs[i])
-       }
-       return finalizers
+               finalizers = append(finalizers, snp.decRef)
+               filter, err := search(ctx, qo, qo.sortedSids, tabs[i])
+               if err != nil {
+                       return nil, err
+               }
+               for j := offset; j < offset+size; j++ {
+                       filterIndex[parts[j].partMetadata.ID] = filter
+               }
+               offset += size
+       }
+       if len(parts) < 1 {
+               return nil, nil
+       }
+       var asc bool
+       if qo.Order == nil {
+               asc = true
+       } else {
+               asc = qo.Order.Sort == modelv1.Sort_SORT_ASC || qo.Order.Sort == modelv1.Sort_SORT_UNSPECIFIED
+       }
+       return &blockScanner{
+               parts:       getDisjointParts(parts, asc),
+               filterIndex: filterIndex,
+               qo:          qo,
+               asc:         asc,
+               l:           l,
+               pm:          pm,
+               finalizers:  finalizers,
+       }, nil
 }
 
-func (q *blockScanner) scanBlocks(ctx context.Context, seriesList []uint64, tab *tsTable, blockCh chan *blockScanResultBatch) (sf scanFinalizer) {
-       s := tab.currentSnapshot()
-       if s == nil {
-               return nil
+func search(ctx context.Context, qo queryOptions, seriesList []common.SeriesID, tw *tsTable) (posting.List, error) {
+       if qo.Filter == nil || qo.Filter == logicalstream.ENode {
+               return nil, nil
+       }
+       sid := make([]uint64, len(seriesList))
+       for i := range seriesList {
+               sid[i] = uint64(seriesList[i])
        }
-       sf = s.decRef
-       filter, err := q.indexSearch(ctx, seriesList, tab)
+       pl, err := tw.Index().Search(ctx, sid, qo.Filter)
        if err != nil {
-               select {
-               case blockCh <- &blockScanResultBatch{err: err}:
-               case <-ctx.Done():
-               }
-               return
+               return nil, err
        }
-       select {
-       case <-ctx.Done():
-               return
-       default:
+       if pl == nil {
+               return roaring.DummyPostingList, nil
        }
+       return pl, nil
+}
+
+type scanFinalizer func()
 
-       parts, n := s.getParts(nil, q.qo.minTimestamp, q.qo.maxTimestamp)
-       if n < 1 {
+type blockScanner struct {
+       filterIndex map[uint64]posting.List
+       l           *logger.Logger
+       pm          *protector.Memory
+       parts       [][]*part
+       finalizers  []scanFinalizer
+       qo          queryOptions
+       asc         bool
+}
+
+func (bsn *blockScanner) scan(ctx context.Context, blockCh chan *blockScanResultBatch) {
+       if len(bsn.parts) < 1 {
                return
        }
+       var parts []*part
+       if bsn.asc {
+               parts = bsn.parts[0]
+               bsn.parts = bsn.parts[1:]
+       } else {
+               parts = bsn.parts[len(bsn.parts)-1]
+               bsn.parts = bsn.parts[:len(bsn.parts)-1]
+       }
        bma := generateBlockMetadataArray()
        defer releaseBlockMetadataArray(bma)
        ti := generateTstIter()
        defer releaseTstIter(ti)
-       ti.init(bma, parts, q.qo.sortedSids, q.qo.minTimestamp, q.qo.maxTimestamp)
+       ti.init(bma, parts, bsn.qo.sortedSids, bsn.qo.minTimestamp, bsn.qo.maxTimestamp)
        batch := generateBlockScanResultBatch()
        if ti.Error() != nil {
                batch.err = fmt.Errorf("cannot init tstIter: %w", ti.Error())
@@ -171,7 +209,7 @@ func (q *blockScanner) scanBlocks(ctx context.Context, seriesList []uint64, tab
                case blockCh <- batch:
                case <-ctx.Done():
                        releaseBlockScanResultBatch(batch)
-                       q.l.Warn().Err(ti.Error()).Msg("cannot init tstIter")
+                       bsn.l.Warn().Err(ti.Error()).Msg("cannot init tstIter")
                }
                return
        }
@@ -181,21 +219,21 @@ func (q *blockScanner) scanBlocks(ctx context.Context, seriesList []uint64, tab
                        p: p.p,
                })
                bs := &batch.bss[len(batch.bss)-1]
-               bs.qo.copyFrom(&q.qo)
-               bs.qo.elementFilter = filter
+               bs.qo.copyFrom(&bsn.qo)
+               bs.qo.elementFilter = bsn.filterIndex[p.p.partMetadata.ID]
                bs.bm.copyFrom(p.curBlock)
                if len(batch.bss) >= cap(batch.bss) {
                        var totalBlockBytes uint64
                        for i := range batch.bss {
                                totalBlockBytes += batch.bss[i].bm.uncompressedSizeBytes
                        }
-                       if err := q.pm.AcquireResource(ctx, totalBlockBytes); err != nil {
+                       if err := bsn.pm.AcquireResource(ctx, totalBlockBytes); err != nil {
                                batch.err = fmt.Errorf("cannot acquire resource: %w", err)
                                select {
                                case blockCh <- batch:
                                case <-ctx.Done():
                                        releaseBlockScanResultBatch(batch)
-                                       q.l.Warn().Err(err).Msg("cannot acquire resource")
+                                       bsn.l.Warn().Err(err).Msg("cannot acquire resource")
                                }
                                return
                        }
@@ -203,7 +241,7 @@ func (q *blockScanner) scanBlocks(ctx context.Context, seriesList []uint64, tab
                        case blockCh <- batch:
                        case <-ctx.Done():
                                releaseBlockScanResultBatch(batch)
-                               q.l.Warn().Int("batch.len", len(batch.bss)).Msg("context canceled while sending block")
+                               bsn.l.Warn().Int("batch.len", len(batch.bss)).Msg("context canceled while sending block")
                                return
                        }
                        batch = generateBlockScanResultBatch()
@@ -228,25 +266,10 @@ func (q *blockScanner) scanBlocks(ctx context.Context, seriesList []uint64, tab
                return
        }
        releaseBlockScanResultBatch(batch)
-       return
 }
 
-func (q *blockScanner) indexSearch(ctx context.Context, seriesList []uint64, tw *tsTable) (posting.List, error) {
-       if q.qo.Filter == nil || q.qo.Filter == logicalstream.ENode {
-               return nil, nil
+func (bsn *blockScanner) close() {
+       for i := range bsn.finalizers {
+               bsn.finalizers[i]()
        }
-       pl, err := tw.Index().Search(ctx, seriesList, q.qo.Filter)
-       if err != nil {
-               return nil, err
-       }
-       if pl == nil {
-               return roaring.DummyPostingList, nil
-       }
-       return pl, nil
-}
-
-func (q *blockScanner) close() {
-       q.segment.DecRef()
 }
-
-type scanFinalizer func()
diff --git a/banyand/stream/metadata_test.go b/banyand/stream/metadata_test.go
index 98ec8ec9..1119b6f4 100644
--- a/banyand/stream/metadata_test.go
+++ b/banyand/stream/metadata_test.go
@@ -333,7 +333,7 @@ func queryAllMeasurements(svcs *services, expectedSize int, newTag []string, new
                switch d := data.(type) {
                case *streamv1.QueryResponse:
                        if len(d.Elements) != expectedSize {
-                               GinkgoWriter.Printf("actual: %s", d.Elements)
+                               GinkgoWriter.Printf("expected: %d actual: %d \n", expectedSize, len(d.Elements))
                                return false
                        }
                        resp = d
diff --git a/banyand/stream/query_by_ts.go b/banyand/stream/query_by_ts.go
index d45a46ef..7b38b47e 100644
--- a/banyand/stream/query_by_ts.go
+++ b/banyand/stream/query_by_ts.go
@@ -39,6 +39,7 @@ type tsResult struct {
        sm       *stream
        pm       *protector.Memory
        l        *logger.Logger
+       ts       *blockScanner
        segments []storage.Segment[*tsTable, option]
        series   []*pbv1.Series
        shards   []*model.StreamResult
@@ -46,26 +47,21 @@ type tsResult struct {
        asc      bool
 }
 
-func (t *tsResult) Pull(ctx context.Context) *model.StreamResult {
-       if len(t.segments) == 0 {
+func (t *tsResult) Pull(ctx context.Context) (r *model.StreamResult) {
+       if len(t.segments) == 0 && t.ts == nil {
                return nil
        }
-       if err := t.scanSegment(ctx); err != nil {
-               return &model.StreamResult{Error: err}
-       }
        var err error
-       for i := range t.shards {
-               if t.shards[i].Error != nil {
-                       err = multierr.Append(err, t.shards[i].Error)
-               }
-       }
-       if err != nil {
+       if r, err = t.scan(ctx); err != nil {
                return &model.StreamResult{Error: err}
        }
-       return model.MergeStreamResults(t.shards, t.qo.MaxElementSize, t.asc)
+       return r
 }
 
-func (t *tsResult) scanSegment(ctx context.Context) error {
+func (t *tsResult) scan(ctx context.Context) (*model.StreamResult, error) {
+       if t.ts != nil {
+               return t.runTabScanner(ctx)
+       }
        var segment storage.Segment[*tsTable, option]
        if t.asc {
                segment = t.segments[len(t.segments)-1]
@@ -74,18 +70,22 @@ func (t *tsResult) scanSegment(ctx context.Context) error {
                segment = t.segments[0]
                t.segments = t.segments[1:]
        }
-
-       bs := blockScanner{
-               segment: segment,
-               qo:      t.qo,
-               series:  t.series,
-               pm:      t.pm,
-               l:       t.l,
+       qo, err := searchSeries(ctx, t.qo, segment, t.series)
+       if err != nil {
+               return nil, err
+       }
+       ts, err := getBlockScanner(ctx, segment, qo, t.l, t.pm)
+       if err != nil {
+               return nil, err
        }
-       defer bs.close()
-       if err := bs.searchSeries(ctx); err != nil {
-               return err
+       if ts == nil {
+               return nil, nil
        }
+       t.ts = ts
+       return t.runTabScanner(ctx)
+}
+
+func (t *tsResult) runTabScanner(ctx context.Context) (*model.StreamResult, error) {
        workerSize := cgroups.CPUs()
        var workerWg sync.WaitGroup
        batchCh := make(chan *blockScanResultBatch, workerSize)
@@ -100,7 +100,7 @@ func (t *tsResult) scanSegment(ctx context.Context) error {
                        t.shards[i].Reset()
                }
        }
-       for i := 0; i < workerSize; i++ {
+       for i := range workerSize {
                go func(workerID int) {
                        tmpBlock := generateBlock()
                        defer releaseBlock(tmpBlock)
@@ -132,18 +132,23 @@ func (t *tsResult) scanSegment(ctx context.Context) error {
                        workerWg.Done()
                }(i)
        }
-
-       var scannerWg sync.WaitGroup
-       finalizers := bs.scanShardsInParallel(ctx, &scannerWg, batchCh)
-       scannerWg.Wait()
+       t.ts.scan(ctx, batchCh)
        close(batchCh)
        workerWg.Wait()
-       for i := range finalizers {
-               if finalizers[i] != nil {
-                       finalizers[i]()
+       if len(t.ts.parts) == 0 {
+               t.ts.close()
+               t.ts = nil
+       }
+       var err error
+       for i := range t.shards {
+               if t.shards[i].Error != nil {
+                       err = multierr.Append(err, t.shards[i].Error)
                }
        }
-       return nil
+       if err != nil {
+               return nil, err
+       }
+       return model.MergeStreamResults(t.shards, t.qo.MaxElementSize, t.asc), nil
 }
 
 func loadBlockCursor(bc *blockCursor, tmpBlock *block, qo queryOptions, sm *stream) bool {
@@ -188,6 +193,9 @@ func loadBlockCursor(bc *blockCursor, tmpBlock *block, qo queryOptions, sm *stre
 }
 
 func (t *tsResult) Release() {
+       if t.ts != nil {
+               t.ts.close()
+       }
        for i := range t.segments {
                t.segments[i].DecRef()
        }
diff --git a/banyand/stream/snapshot.go b/banyand/stream/snapshot.go
index 02955948..bb4b5d3b 100644
--- a/banyand/stream/snapshot.go
+++ b/banyand/stream/snapshot.go
@@ -22,6 +22,7 @@ import (
        "encoding/json"
        "fmt"
        "path/filepath"
+       "sort" // added for sorting parts
        "sync"
        "sync/atomic"
        "time"
@@ -134,6 +135,49 @@ func (s *snapshot) remove(nextEpoch uint64, merged map[uint64]struct{}) snapshot
        return result
 }
 
+func getDisjointParts(parts []*part, asc bool) [][]*part {
+       if len(parts) == 0 {
+               return nil
+       }
+       sort.Slice(parts, func(i, j int) bool {
+               return parts[i].partMetadata.MinTimestamp < parts[j].partMetadata.MinTimestamp
+       })
+
+       var groups [][]*part
+       var currentGroup []*part
+       var boundary int64
+       for _, p := range parts {
+               pMin := p.partMetadata.MinTimestamp
+               pMax := p.partMetadata.MaxTimestamp
+               if len(currentGroup) == 0 {
+                       currentGroup = append(currentGroup, p)
+                       boundary = pMax
+               } else {
+                       if pMin <= boundary {
+                               currentGroup = append(currentGroup, p)
+                               if pMax > boundary {
+                                       boundary = pMax
+                               }
+                       } else {
+                               groups = append(groups, currentGroup)
+                               currentGroup = []*part{p}
+                               boundary = pMax
+                       }
+               }
+       }
+
+       if len(currentGroup) > 0 {
+               groups = append(groups, currentGroup)
+       }
+
+       if !asc {
+               for i, j := 0, len(groups)-1; i < j; i, j = i+1, j-1 {
+                       groups[i], groups[j] = groups[j], groups[i]
+               }
+       }
+       return groups
+}
+
 func snapshotName(snapshot uint64) string {
        return fmt.Sprintf("%016x%s", snapshot, snapshotSuffix)
 }
diff --git a/banyand/stream/snapshot_test.go b/banyand/stream/snapshot_test.go
index a61dc039..8739ee26 100644
--- a/banyand/stream/snapshot_test.go
+++ b/banyand/stream/snapshot_test.go
@@ -534,3 +534,26 @@ func TestSnapshotFunctionality(t *testing.T) {
                }
        }
 }
+
+func TestGetDisjointParts(t *testing.T) {
+       p1 := &part{partMetadata: partMetadata{ID: 1, MinTimestamp: 1, MaxTimestamp: 3}}
+       p2 := &part{partMetadata: partMetadata{ID: 2, MinTimestamp: 2, MaxTimestamp: 4}}
+       p3 := &part{partMetadata: partMetadata{ID: 3, MinTimestamp: 5, MaxTimestamp: 7}}
+       p4 := &part{partMetadata: partMetadata{ID: 4, MinTimestamp: 6, MaxTimestamp: 8}}
+
+       parts := []*part{p1, p2, p3, p4}
+
+       groupsAsc := getDisjointParts(parts, true)
+       require.Equal(t, 2, len(groupsAsc), "expected 2 groups in ascending order")
+       require.Equal(t, 2, len(groupsAsc[0]), "first group should have 2 parts")
+       require.Equal(t, 2, len(groupsAsc[1]), "second group should have 2 parts")
+       require.Equal(t, p1, groupsAsc[0][0])
+       require.Equal(t, p2, groupsAsc[0][1])
+       require.Equal(t, p3, groupsAsc[1][0])
+       require.Equal(t, p4, groupsAsc[1][1])
+
+       groupsDesc := getDisjointParts(parts, false)
+       require.Equal(t, 2, len(groupsDesc), "expected 2 groups in descending order")
+       require.Equal(t, groupsAsc[1], groupsDesc[0], "first group in descending order should match second group in ascending order")
+       require.Equal(t, groupsAsc[0], groupsDesc[1], "second group in descending order should match first group in ascending order")
+}

Reply via email to