This is an automated email from the ASF dual-hosted git repository.

wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new 768baa12 Shard-level buffer (#339)
768baa12 is described below

commit 768baa127e08283109b8f26cab0abf3f18fe9292
Author: Gao Hongtao <hanahm...@gmail.com>
AuthorDate: Tue Oct 10 17:11:44 2023 +0800

    Shard-level buffer (#339)
---
 CHANGES.md                  |   1 +
 banyand/kv/badger.go        |  46 +++++++++++++-
 banyand/kv/badger_test.go   |  72 ++++++++++++++++++++++
 banyand/kv/kv.go            |  10 +++
 banyand/measure/tstable.go  |  67 ++++++---------------
 banyand/stream/tstable.go   |  60 ++++++------------
 banyand/tsdb/block.go       |   8 ++-
 banyand/tsdb/buffer.go      | 144 ++++++++++++++++++++++++++++++++++++++++----
 banyand/tsdb/buffer_test.go | 135 ++++++++++++++++++++++++++++++-----------
 banyand/tsdb/shard.go       |  30 ++++++---
 banyand/tsdb/shard_test.go  |   3 +
 banyand/tsdb/tsdb.go        |   8 ++-
 banyand/tsdb/tsdb_test.go   |   2 +-
 banyand/tsdb/tstable.go     |  26 +-------
 14 files changed, 433 insertions(+), 179 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 8edd1e3e..fd2e3767 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -18,6 +18,7 @@ Release Notes.
 - Add mod revision check to write requests.
 - Add TTL to the property.
 - Implement node selector (e.g. PickFirst Selector, Maglev Selector).
+- Unified the buffers separated in blocks to a single buffer in the shard.
 
 ### Bugs
 
diff --git a/banyand/kv/badger.go b/banyand/kv/badger.go
index d1f998ad..09f1d385 100644
--- a/banyand/kv/badger.go
+++ b/banyand/kv/badger.go
@@ -31,6 +31,7 @@ import (
 
        "github.com/apache/skywalking-banyandb/pkg/encoding"
        "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
 )
 
 var (
@@ -45,12 +46,16 @@ var (
 
 type badgerTSS struct {
        badger.TSet
-       db     *badger.DB
-       dbOpts badger.Options
+       timeRange timestamp.TimeRange
+       db        *badger.DB
+       dbOpts    badger.Options
 }
 
 func (b *badgerTSS) Handover(skl *skl.Skiplist) error {
-       return b.db.HandoverIterator(skl.NewUniIterator(false))
+       return b.db.HandoverIterator(&timeRangeIterator{
+               timeRange:   b.timeRange,
+               UniIterator: skl.NewUniIterator(false),
+       })
 }
 
 func (b *badgerTSS) Close() error {
@@ -119,6 +124,41 @@ func (i mergedIter) Value() y.ValueStruct {
        }
 }
 
+type timeRangeIterator struct {
+       *skl.UniIterator
+       timeRange timestamp.TimeRange
+}
+
+func (i *timeRangeIterator) Next() {
+       i.UniIterator.Next()
+       for !i.validTime() {
+               i.UniIterator.Next()
+       }
+}
+
+func (i *timeRangeIterator) Rewind() {
+       i.UniIterator.Rewind()
+       if !i.validTime() {
+               i.Next()
+       }
+}
+
+func (i *timeRangeIterator) Seek(key []byte) {
+       i.UniIterator.Seek(key)
+       if !i.validTime() {
+               i.Next()
+       }
+}
+
+func (i *timeRangeIterator) validTime() bool {
+       if !i.Valid() {
+               // If the underlying iterator is invalid, we should return true to stop iterating.
+               return true
+       }
+       ts := y.ParseTs(i.Key())
+       return i.timeRange.Contains(ts)
+}
+
 type badgerDB struct {
        db     *badger.DB
        dbOpts badger.Options
diff --git a/banyand/kv/badger_test.go b/banyand/kv/badger_test.go
new file mode 100644
index 00000000..5313bbfa
--- /dev/null
+++ b/banyand/kv/badger_test.go
@@ -0,0 +1,72 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package kv
+
+import (
+       "reflect"
+       "testing"
+       "time"
+
+       "github.com/dgraph-io/badger/v3/skl"
+       "github.com/dgraph-io/badger/v3/y"
+
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
+)
+
+var arenaSize = int64(4 << 20)
+
+func TestTimeRangeIterator(t *testing.T) {
+       // Create a new skiplist and insert some data
+       sl := skl.NewSkiplist(arenaSize)
+       sl.Put(y.KeyWithTs([]byte("key"), 1), y.ValueStruct{Value: []byte("value1"), Meta: 0})
+       sl.Put(y.KeyWithTs([]byte("key"), 2), y.ValueStruct{Value: []byte("value2"), Meta: 0})
+       sl.Put(y.KeyWithTs([]byte("key"), 3), y.ValueStruct{Value: []byte("value3"), Meta: 0})
+       sl.Put(y.KeyWithTs([]byte("key"), 4), y.ValueStruct{Value: []byte("value4"), Meta: 0})
+
+       // Create a new time range iterator for the skiplist
+       iter := &timeRangeIterator{
+               timeRange: timestamp.NewInclusiveTimeRange(
+                       timestamp.DefaultTimeRange.Begin.AsTime(),
+                       timestamp.DefaultTimeRange.End.AsTime()),
+               UniIterator: sl.NewUniIterator(false),
+       }
+
+       // Test Next() and Value() methods
+       var values []string
+       for iter.Rewind(); iter.Valid(); iter.Next() {
+               values = append(values, string(iter.Value().Value))
+       }
+       expectedValues := []string{"value4", "value3", "value2", "value1"}
+       if !reflect.DeepEqual(values, expectedValues) {
+               t.Errorf("unexpected values: %v, expected: %v", values, expectedValues)
+       }
+
+       // Test Next() method with time range filtering
+       iter = &timeRangeIterator{
+               timeRange:   timestamp.NewSectionTimeRange(time.Unix(0, 2), time.Unix(0, 3)),
+               UniIterator: sl.NewUniIterator(false),
+       }
+       values = nil
+       for iter.Rewind(); iter.Valid(); iter.Next() {
+               values = append(values, string(iter.Value().Value))
+       }
+       expectedValues = []string{"value2"}
+       if !reflect.DeepEqual(values, expectedValues) {
+               t.Errorf("unexpected values: %v, expected: %v", values, expectedValues)
+       }
+}
diff --git a/banyand/kv/kv.go b/banyand/kv/kv.go
index 49b105c3..28675730 100644
--- a/banyand/kv/kv.go
+++ b/banyand/kv/kv.go
@@ -30,6 +30,7 @@ import (
 
        "github.com/apache/skywalking-banyandb/pkg/encoding"
        "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
 )
 
 var (
@@ -147,6 +148,15 @@ func TSSWithMemTableSize(sizeInBytes int64) TimeSeriesOptions {
        }
 }
 
+// TSSWithTimeRange sets the time range of the time series.
+func TSSWithTimeRange(timeRange timestamp.TimeRange) TimeSeriesOptions {
+       return func(store TimeSeriesStore) {
+               if btss, ok := store.(*badgerTSS); ok {
+                       btss.timeRange = timeRange
+               }
+       }
+}
+
 // Iterator allows iterating the kv tables.
 // TODO: use generic to provide a unique iterator.
 type Iterator interface {
diff --git a/banyand/measure/tstable.go b/banyand/measure/tstable.go
index a52b2a78..28ca9561 100644
--- a/banyand/measure/tstable.go
+++ b/banyand/measure/tstable.go
@@ -21,6 +21,7 @@ import (
        "fmt"
        "io"
        "path"
+       "strings"
        "sync"
        "sync/atomic"
        "time"
@@ -49,15 +50,14 @@ const (
 var _ tsdb.TSTable = (*tsTable)(nil)
 
 type tsTable struct {
-       encoderSST kv.TimeSeriesStore
-       sst        kv.TimeSeriesStore
-       *tsdb.BlockExpiryTracker
+       encoderSST        kv.TimeSeriesStore
+       sst               kv.TimeSeriesStore
+       bufferSupplier    *tsdb.BufferSupplier
        l                 *logger.Logger
        encoderBuffer     *tsdb.Buffer
        buffer            *tsdb.Buffer
-       closeBufferTimer  *time.Timer
        position          common.Position
-       path              string
+       id                string
        bufferSize        int64
        encoderBufferSize int64
        lock              sync.RWMutex
@@ -74,40 +74,13 @@ func (t *tsTable) openBuffer() (err error) {
                return nil
        }
        bufferSize := int(t.encoderBufferSize / defaultNumBufferShards)
-       if t.encoderBuffer, err = tsdb.NewBufferWithWal(t.l, t.position, bufferSize,
-               defaultWriteConcurrency, defaultNumBufferShards, t.encoderFlush, defaultWriteWal, &t.path); err != nil {
+       if t.encoderBuffer, err = t.bufferSupplier.Borrow(encoded, t.id, bufferSize, t.encoderFlush); err != nil {
                return fmt.Errorf("failed to create encoder buffer: %w", err)
        }
        bufferSize = int(t.bufferSize / defaultNumBufferShards)
-       if t.buffer, err = tsdb.NewBufferWithWal(t.l, t.position, bufferSize,
-               defaultWriteConcurrency, defaultNumBufferShards, t.flush, defaultWriteWal, &t.path); err != nil {
+       if t.buffer, err = t.bufferSupplier.Borrow(plain, t.id, bufferSize, t.flush); err != nil {
                return fmt.Errorf("failed to create buffer: %w", err)
        }
-       end := t.EndTime()
-       now := time.Now()
-       closeAfter := end.Sub(now)
-       if now.After(end) {
-               closeAfter = t.BlockExpiryDuration()
-       }
-       t.closeBufferTimer = time.AfterFunc(closeAfter, func() {
-               if t.l.Debug().Enabled() {
-                       t.l.Debug().Msg("closing buffer")
-               }
-               t.lock.Lock()
-               defer t.lock.Unlock()
-               if t.encoderBuffer != nil {
-                       if err := t.encoderBuffer.Close(); err != nil {
-                               t.l.Error().Err(err).Msg("close encoder buffer error")
-                       }
-                       t.encoderBuffer = nil
-               }
-               if t.buffer != nil {
-                       if err := t.buffer.Close(); err != nil {
-                               t.l.Error().Err(err).Msg("close buffer error")
-                       }
-                       t.buffer = nil
-               }
-       })
        return nil
 }
 
@@ -119,6 +92,9 @@ func (t *tsTable) Close() (err error) {
                        err = multierr.Append(err, b.Close())
                }
        }
+       for _, b := range []string{encoded, plain} {
+               t.bufferSupplier.Return(b, t.id)
+       }
        return err
 }
 
@@ -203,7 +179,7 @@ type tsTableFactory struct {
        compressionMethod databasev1.CompressionMethod
 }
 
-func (ttf *tsTableFactory) NewTSTable(blockExpiryTracker tsdb.BlockExpiryTracker, root string, position common.Position, l *logger.Logger) (tsdb.TSTable, error) {
+func (ttf *tsTableFactory) NewTSTable(bufferSupplier *tsdb.BufferSupplier, root string, position common.Position, l *logger.Logger) (tsdb.TSTable, error) {
        encoderSST, err := kv.OpenTimeSeriesStore(
                path.Join(root, encoded),
                kv.TSSWithMemTableSize(ttf.bufferSize),
@@ -223,19 +199,14 @@ func (ttf *tsTableFactory) NewTSTable(blockExpiryTracker tsdb.BlockExpiryTracker
                return nil, fmt.Errorf("failed to create time series table: %w", err)
        }
        table := &tsTable{
-               bufferSize:         ttf.bufferSize,
-               encoderBufferSize:  ttf.encoderBufferSize,
-               l:                  l,
-               position:           position,
-               encoderSST:         encoderSST,
-               sst:                sst,
-               BlockExpiryTracker: &blockExpiryTracker,
-               path:               root,
-       }
-       if table.IsActive() {
-               if err := table.openBuffer(); err != nil {
-                       return nil, fmt.Errorf("failed to open buffer: %w", err)
-               }
+               bufferSize:        ttf.bufferSize,
+               encoderBufferSize: ttf.encoderBufferSize,
+               l:                 l,
+               position:          position,
+               encoderSST:        encoderSST,
+               sst:               sst,
+               id:                strings.Join([]string{position.Segment, position.Block}, "-"),
+               bufferSupplier:    bufferSupplier,
        }
        return table, nil
 }
diff --git a/banyand/stream/tstable.go b/banyand/stream/tstable.go
index 6dca480c..1cdc4fa7 100644
--- a/banyand/stream/tstable.go
+++ b/banyand/stream/tstable.go
@@ -20,6 +20,7 @@ package stream
 import (
        "fmt"
        "path"
+       "strings"
        "sync"
        "time"
 
@@ -43,14 +44,14 @@ const (
 var _ tsdb.TSTable = (*tsTable)(nil)
 
 type tsTable struct {
-       sst kv.TimeSeriesStore
-       *tsdb.BlockExpiryTracker
-       l                *logger.Logger
-       buffer           *tsdb.Buffer
-       closeBufferTimer *time.Timer
-       position         common.Position
-       bufferSize       int64
-       lock             sync.RWMutex
+       sst            kv.TimeSeriesStore
+       l              *logger.Logger
+       buffer         *tsdb.Buffer
+       bufferSupplier *tsdb.BufferSupplier
+       position       common.Position
+       id             string
+       bufferSize     int64
+       lock           sync.RWMutex
 }
 
 func (t *tsTable) SizeOnDisk() int64 {
@@ -64,30 +65,9 @@ func (t *tsTable) openBuffer() (err error) {
                return nil
        }
        bufferSize := int(t.bufferSize / defaultNumBufferShards)
-       if t.buffer, err = tsdb.NewBuffer(t.l, t.position, bufferSize,
-               defaultWriteConcurrency, defaultNumBufferShards, t.flush); err != nil {
+       if t.buffer, err = t.bufferSupplier.Borrow(id, t.id, bufferSize, t.flush); err != nil {
                return fmt.Errorf("failed to create buffer: %w", err)
        }
-       end := t.EndTime()
-       now := time.Now()
-       closeAfter := end.Sub(now)
-       if now.After(end) {
-               closeAfter = t.BlockExpiryDuration()
-       }
-       t.closeBufferTimer = time.AfterFunc(closeAfter, func() {
-               if t.l.Debug().Enabled() {
-                       t.l.Debug().Msg("closing buffer")
-               }
-               t.lock.Lock()
-               defer t.lock.Unlock()
-               if t.buffer == nil {
-                       return
-               }
-               if err := t.buffer.Close(); err != nil {
-                       t.l.Error().Err(err).Msg("close buffer error")
-               }
-               t.buffer = nil
-       })
        return nil
 }
 
@@ -95,7 +75,7 @@ func (t *tsTable) Close() (err error) {
        t.lock.Lock()
        defer t.lock.Unlock()
        if t.buffer != nil {
-               err = multierr.Append(err, t.buffer.Close())
+               t.bufferSupplier.Return(id, t.id)
        }
        return multierr.Combine(err, t.sst.Close())
 }
@@ -137,23 +117,19 @@ type tsTableFactory struct {
        chunkSize         int
 }
 
-func (ttf *tsTableFactory) NewTSTable(blockExpiryTracker tsdb.BlockExpiryTracker, root string, position common.Position, l *logger.Logger) (tsdb.TSTable, error) {
+func (ttf *tsTableFactory) NewTSTable(bufferSupplier *tsdb.BufferSupplier, root string, position common.Position, l *logger.Logger) (tsdb.TSTable, error) {
        sst, err := kv.OpenTimeSeriesStore(path.Join(root, id), kv.TSSWithMemTableSize(ttf.bufferSize), kv.TSSWithLogger(l.Named(id)),
                kv.TSSWithZSTDCompression(ttf.chunkSize))
        if err != nil {
                return nil, fmt.Errorf("failed to create time series table: 
%w", err)
        }
        table := &tsTable{
-               bufferSize:         ttf.bufferSize,
-               l:                  l,
-               position:           position,
-               sst:                sst,
-               BlockExpiryTracker: &blockExpiryTracker,
-       }
-       if table.IsActive() {
-               if err := table.openBuffer(); err != nil {
-                       return nil, fmt.Errorf("failed to open buffer: %w", err)
-               }
+               bufferSupplier: bufferSupplier,
+               bufferSize:     ttf.bufferSize,
+               l:              l,
+               position:       position,
+               sst:            sst,
+               id:             strings.Join([]string{position.Segment, position.Block}, "-"),
        }
        return table, nil
 }
diff --git a/banyand/tsdb/block.go b/banyand/tsdb/block.go
index 4f67416b..32e4d521 100644
--- a/banyand/tsdb/block.go
+++ b/banyand/tsdb/block.go
@@ -108,6 +108,7 @@ type block struct {
 
 type openOpts struct {
        tsTableFactory TSTableFactory
+       bufferSupplier *BufferSupplier
        inverted       *inverted.StoreOpts
        lsm            lsm.StoreOpts
 }
@@ -179,6 +180,11 @@ func options(ctx context.Context, root string, l *logger.Logger) (openOpts, erro
        if opts.tsTableFactory == nil {
                return opts, errors.New("ts table factory is nil")
        }
+       bs := ctx.Value(bufferSupplierKey)
+       if bs == nil {
+               return opts, errors.New("buffer supplier not found")
+       }
+       opts.bufferSupplier = bs.(*BufferSupplier)
        return opts, nil
 }
 
@@ -195,7 +201,7 @@ func (b *block) openSafely() (err error) {
 }
 
 func (b *block) open() (err error) {
-       if b.tsTable, err = b.openOpts.tsTableFactory.NewTSTable(BlockExpiryTracker{ttl: b.End, clock: b.clock},
+       if b.tsTable, err = b.openOpts.tsTableFactory.NewTSTable(b.openOpts.bufferSupplier,
                b.path, b.position, b.l); err != nil {
                return err
        }
diff --git a/banyand/tsdb/buffer.go b/banyand/tsdb/buffer.go
index e1fe1c50..49dbf4d7 100644
--- a/banyand/tsdb/buffer.go
+++ b/banyand/tsdb/buffer.go
@@ -89,7 +89,7 @@ type bufferShardBucket struct {
 
 // Buffer is an exported struct that represents a buffer composed of multiple shard buckets.
 type Buffer struct {
-       onFlushFn      onFlush
+       onFlushFn      sync.Map
        entryCloser    *run.Closer
        log            *logger.Logger
        buckets        []bufferShardBucket
@@ -101,18 +101,17 @@ type Buffer struct {
 }
 
 // NewBuffer creates a new Buffer instance with the given parameters.
-func NewBuffer(log *logger.Logger, position common.Position, flushSize, writeConcurrency, numShards int, onFlushFn onFlush) (*Buffer, error) {
-       return NewBufferWithWal(log, position, flushSize, writeConcurrency, numShards, onFlushFn, false, nil)
+func NewBuffer(log *logger.Logger, position common.Position, flushSize, writeConcurrency, numShards int) (*Buffer, error) {
+       return NewBufferWithWal(log, position, flushSize, writeConcurrency, numShards, false, "")
 }
 
 // NewBufferWithWal creates a new Buffer instance with the given parameters.
-func NewBufferWithWal(log *logger.Logger, position common.Position, flushSize, writeConcurrency, numShards int, onFlushFn onFlush, enableWal bool, walPath *string,
+func NewBufferWithWal(log *logger.Logger, position common.Position, flushSize, writeConcurrency, numShards int, enableWal bool, walPath string,
 ) (*Buffer, error) {
        buckets := make([]bufferShardBucket, numShards)
        buffer := &Buffer{
                buckets:     buckets,
                numShards:   numShards,
-               onFlushFn:   onFlushFn,
                entryCloser: run.NewCloser(1),
                log:         log.Named("buffer"),
                enableWal:   enableWal,
@@ -133,12 +132,12 @@ func NewBufferWithWal(log *logger.Logger, position common.Position, flushSize, w
                        shardLabelValues: position.ShardLabelValues(),
                        enableWal:        enableWal,
                }
-               buckets[i].start(onFlushFn)
+               buckets[i].start(buffer.flushers)
                if enableWal {
-                       if walPath == nil {
+                       if walPath == "" {
                                return nil, errors.New("wal path is required")
                        }
-                       shardWalPath := fmt.Sprintf("%s/buffer-%d", *walPath, i)
+                       shardWalPath := fmt.Sprintf("%s/buffer-%d", walPath, i)
                        if err := buckets[i].startWal(shardWalPath, defaultWalSyncMode); err != nil {
                                return nil, errors.Wrap(err, "failed to start wal")
                        }
@@ -148,6 +147,29 @@ func NewBufferWithWal(log *logger.Logger, position common.Position, flushSize, w
        return buffer, nil
 }
 
+// Register registers a callback function that will be called when a shard bucket is flushed.
+func (b *Buffer) Register(id string, onFlushFn onFlush) {
+       b.onFlushFn.LoadOrStore(id, onFlushFn)
+}
+
+// Unregister unregisters a callback function that will be called when a shard bucket is flushed.
+func (b *Buffer) Unregister(id string) {
+       b.onFlushFn.Delete(id)
+}
+
+func (b *Buffer) flushers() []onFlush {
+       var flushers []onFlush
+       b.onFlushFn.Range(func(key, value interface{}) bool {
+               flushers = append(flushers, value.(onFlush))
+               return true
+       })
+       return flushers
+}
+
+func (b *Buffer) isEmpty() bool {
+       return len(b.flushers()) == 0
+}
+
 // Write adds a key-value pair with a timestamp to the appropriate shard bucket in the buffer.
 func (b *Buffer) Write(key, value []byte, timestamp time.Time) error {
        if b == nil || !b.entryCloser.AddRunning() {
@@ -206,8 +228,11 @@ func (b *Buffer) Close() error {
                }
                b.writeWaitGroup.Wait()
                for i := 0; i < b.numShards; i++ {
-                       if err := b.onFlushFn(i, b.buckets[i].mutable); err != nil {
-                               b.buckets[i].log.Err(err).Msg("flushing mutable buffer failed")
+                       ff := b.flushers()
+                       for _, fn := range ff {
+                               if err := fn(i, b.buckets[i].mutable); err != nil {
+                                       b.buckets[i].log.Err(err).Msg("flushing mutable buffer failed")
+                               }
                        }
                        b.buckets[i].mutable.DecrRef()
                }
@@ -246,7 +271,7 @@ func (bsb *bufferShardBucket) getAll() ([]*skl.Skiplist, func()) {
        }
 }
 
-func (bsb *bufferShardBucket) start(onFlushFn onFlush) {
+func (bsb *bufferShardBucket) start(flushers func() []onFlush) {
        go func() {
                defer func() {
                        for _, g := range []meter.Gauge{maxBytes, mutableBytes} {
@@ -259,12 +284,21 @@ func (bsb *bufferShardBucket) start(onFlushFn onFlush) {
                        memSize := oldSkipList.MemSize()
                        onFlushFnDone := false
                        t1 := time.Now()
+                       ff := flushers()
                        for {
                                if !onFlushFnDone {
-                                       if err := onFlushFn(bsb.index, oldSkipList); err != nil {
-                                               bsb.log.Err(err).Msg("flushing immutable buffer failed. Retrying...")
+                                       failedFns := make([]onFlush, 0)
+                                       for i := 0; i < len(ff); i++ {
+                                               fn := ff[i]
+                                               if err := fn(bsb.index, oldSkipList); err != nil {
+                                                       bsb.log.Err(err).Msg("flushing immutable buffer failed. Retrying...")
+                                                       failedFns = append(failedFns, fn)
+                                               }
+                                       }
+                                       if len(failedFns) > 0 {
                                                flushNum.Inc(1, append(bsb.labelValues[:2], "true")...)
                                                time.Sleep(time.Second)
+                                               ff = failedFns
                                                continue
                                        }
                                        onFlushFnDone = true
@@ -445,3 +479,87 @@ func (bsb *bufferShardBucket) writeWal(key, value []byte, timestamp time.Time) e
        wg.Wait()
        return walErr
 }
+
+// BufferSupplier lends a Buffer to a caller and returns it when the caller is done with it.
+type BufferSupplier struct {
+       l                *logger.Logger
+       p                common.Position
+       buffers          sync.Map
+       path             string
+       writeConcurrency int
+       numShards        int
+       enableWAL        bool
+}
+
+// NewBufferSupplier creates a new BufferSupplier instance with the given parameters.
+func NewBufferSupplier(l *logger.Logger, p common.Position, writeConcurrency, numShards int, enableWAL bool, path string) *BufferSupplier {
+       return &BufferSupplier{
+               l:                l.Named("buffer-supplier"),
+               p:                p,
+               writeConcurrency: writeConcurrency,
+               numShards:        numShards,
+               enableWAL:        enableWAL,
+               path:             path,
+       }
+}
+
+// Borrow borrows a Buffer from the BufferSupplier.
+func (b *BufferSupplier) Borrow(bufferName, name string, bufferSize int, onFlushFn onFlush) (buffer *Buffer, err error) {
+       if bufferName == "" || name == "" {
+               return nil, errors.New("bufferName and name are required")
+       }
+       if onFlushFn == nil {
+               return nil, errors.New("onFlushFn is required")
+       }
+       defer func() {
+               if buffer != nil {
+                       buffer.Register(name, onFlushFn)
+               }
+       }()
+       if v, ok := b.buffers.Load(bufferName); ok {
+               buffer = v.(*Buffer)
+               return v.(*Buffer), nil
+       }
+       if buffer, err = NewBufferWithWal(b.l.Named("buffer-"+bufferName), b.p,
+               bufferSize, b.writeConcurrency, b.numShards, b.enableWAL, b.path); err != nil {
+               return nil, err
+       }
+       if v, ok := b.buffers.LoadOrStore(bufferName, buffer); ok {
+               _ = buffer.Close()
+               buffer = v.(*Buffer)
+               return buffer, nil
+       }
+       return buffer, nil
+}
+
+// Return returns a Buffer to the BufferSupplier.
+func (b *BufferSupplier) Return(bufferName, name string) {
+       if v, ok := b.buffers.Load(bufferName); ok {
+               buffer := v.(*Buffer)
+               buffer.Unregister(name)
+               if buffer.isEmpty() {
+                       b.buffers.Delete(bufferName)
+                       _ = buffer.Close()
+               }
+       }
+}
+
+// Volume returns the number of Buffers in the BufferSupplier.
+func (b *BufferSupplier) Volume() int {
+       volume := 0
+       b.buffers.Range(func(key, value interface{}) bool {
+               volume++
+               return true
+       })
+       return volume
+}
+
+// Close closes all Buffers in the BufferSupplier.
+func (b *BufferSupplier) Close() error {
+       b.buffers.Range(func(key, value interface{}) bool {
+               buffer := value.(*Buffer)
+               _ = buffer.Close()
+               return true
+       })
+       return nil
+}
diff --git a/banyand/tsdb/buffer_test.go b/banyand/tsdb/buffer_test.go
index 6b46eea3..82df4251 100644
--- a/banyand/tsdb/buffer_test.go
+++ b/banyand/tsdb/buffer_test.go
@@ -38,6 +38,10 @@ import (
        "github.com/apache/skywalking-banyandb/pkg/test/flags"
 )
 
+var emptyFn = func(shardIndex int, skl *skl.Skiplist) error {
+       return nil
+}
+
 var _ = Describe("Buffer", func() {
        var (
                buffer *tsdb.Buffer
@@ -54,10 +58,9 @@ var _ = Describe("Buffer", func() {
        Context("Write and Read", func() {
                BeforeEach(func() {
                        var err error
-                       buffer, err = tsdb.NewBuffer(log, common.Position{}, 1024*1024, 16, 4, func(shardIndex int, skl *skl.Skiplist) error {
-                               return nil
-                       })
+                       buffer, err = tsdb.NewBuffer(log, common.Position{}, 1024*1024, 16, 4)
                        Expect(err).ToNot(HaveOccurred())
+                       buffer.Register("test", emptyFn)
                })
 
                AfterEach(func() {
@@ -122,11 +125,12 @@ var _ = Describe("Buffer", func() {
                        }
 
                        var err error
-                       buffer, err = tsdb.NewBuffer(log, common.Position{}, 1024, 16, numShards, onFlushFn)
+                       buffer, err = tsdb.NewBuffer(log, common.Position{}, 1024, 16, numShards)
                        defer func() {
                                _ = buffer.Close()
                        }()
                        Expect(err).ToNot(HaveOccurred())
+                       buffer.Register("test", onFlushFn)
 
                        randInt := func() int {
                                n, err := rand.Int(rand.Reader, big.NewInt(1000))
@@ -181,25 +185,24 @@ var _ = Describe("Buffer", func() {
                                flushSize,
                                writeConcurrency,
                                numShards,
-                               func(shardIndex int, skl *skl.Skiplist) error {
-                                       flushMutex.Lock()
-                                       defer flushMutex.Unlock()
-
-                                       shardWalDir := filepath.Join(path, "buffer-"+strconv.Itoa(shardIndex))
-                                       var shardWalList []os.DirEntry
-                                       shardWalList, err = os.ReadDir(shardWalDir)
-                                       Expect(err).ToNot(HaveOccurred())
-                                       for _, shardWalFile := range shardWalList {
-                                               Expect(shardWalFile.IsDir()).To(BeFalse())
-                                               Expect(shardWalFile.Name()).To(HaveSuffix(".wal"))
-                                               shardWalFileHistory[shardIndex] = append(shardWalFileHistory[shardIndex], shardWalFile.Name())
-                                       }
-                                       return nil
-                               },
                                true,
-                               &path)
+                               path)
                        Expect(err).ToNot(HaveOccurred())
+                       buffer.Register("test", func(shardIndex int, skl *skl.Skiplist) error {
+                               flushMutex.Lock()
+                               defer flushMutex.Unlock()
 
+                               shardWalDir := filepath.Join(path, "buffer-"+strconv.Itoa(shardIndex))
+                               var shardWalList []os.DirEntry
+                               shardWalList, err = os.ReadDir(shardWalDir)
+                               Expect(err).ToNot(HaveOccurred())
+                               for _, shardWalFile := range shardWalList {
+                                       Expect(shardWalFile.IsDir()).To(BeFalse())
+                                       Expect(shardWalFile.Name()).To(HaveSuffix(".wal"))
+                                       shardWalFileHistory[shardIndex] = append(shardWalFileHistory[shardIndex], shardWalFile.Name())
+                               }
+                               return nil
+                       })
                        // Write buffer & wal
                        var wg sync.WaitGroup
                        wg.Add(writeConcurrency)
@@ -245,18 +248,18 @@ var _ = Describe("Buffer", func() {
                                flushSize,
                                writeConcurrency,
                                numShards,
-                               func(shardIndex int, skl *skl.Skiplist) error {
-                                       flushMutex.Lock()
-                                       defer flushMutex.Unlock()
-
-                                       if !bufferFlushed {
-                                               bufferFlushed = true
-                                       }
-                                       return nil
-                               },
                                true,
-                               &path)
+                               path)
                        Expect(err).ToNot(HaveOccurred())
+                       buffer.Register("test", func(shardIndex int, skl *skl.Skiplist) error {
+                               flushMutex.Lock()
+                               defer flushMutex.Unlock()
+
+                               if !bufferFlushed {
+                                       bufferFlushed = true
+                               }
+                               return nil
+                       })
 
                        // Write buffer & wal
                        for i := 0; i < numShards; i++ {
@@ -278,13 +281,11 @@ var _ = Describe("Buffer", func() {
                                flushSize,
                                writeConcurrency,
                                numShards,
-                               func(shardIndex int, skl *skl.Skiplist) error {
-                                       return nil
-                               },
                                true,
-                               &path)
+                               path)
                        Expect(err).ToNot(HaveOccurred())
                        defer buffer.Close()
+                       buffer.Register("test", emptyFn)
 
                        // Check buffer was recovered from wal
                        for i := 0; i < numShards; i++ {
@@ -298,3 +299,69 @@ var _ = Describe("Buffer", func() {
                })
        })
 })
+
+var _ = Describe("bufferSupplier", func() {
+       var (
+               b     *tsdb.BufferSupplier
+               goods []gleak.Goroutine
+       )
+
+       BeforeEach(func() {
+               goods = gleak.Goroutines()
+               b = tsdb.NewBufferSupplier(logger.GetLogger("buffer-supplier-test"), common.Position{}, 16, 4, false, "")
+       })
+       AfterEach(func() {
+               b.Close()
+               Eventually(gleak.Goroutines, flags.EventuallyTimeout).ShouldNot(gleak.HaveLeaked(goods))
+       })
+
+       Describe("Borrow", func() {
+               Context("when borrowing a buffer with a new name", func() {
+                       It("should return a new buffer instance", func() {
+                               buf, err := b.Borrow("buffer", "test", 1024*1024, emptyFn)
+                               Expect(err).ToNot(HaveOccurred())
+                               Expect(buf).ToNot(BeNil())
+                               Expect(b.Volume()).To(Equal(1))
+                               b.Return("buffer", "test")
+                               Expect(b.Volume()).To(Equal(0))
+                       })
+               })
+
+               Context("when borrowing a buffer with an existing name", func() {
+                       It("should return the same buffer instance", func() {
+                               buf1, err := b.Borrow("buffer", "test", 1024*1024, emptyFn)
+                               Expect(err).ToNot(HaveOccurred())
+                               Expect(buf1).ToNot(BeNil())
+                               Expect(b.Volume()).To(Equal(1))
+
+                               buf2, err := b.Borrow("buffer", "test", 1024*1024, emptyFn)
+                               Expect(err).ToNot(HaveOccurred())
+                               Expect(buf2).ToNot(BeNil())
+                               Expect(b.Volume()).To(Equal(1))
+
+                               Expect(buf2).To(Equal(buf1))
+                               b.Return("buffer", "test")
+                               Expect(b.Volume()).To(Equal(0))
+                       })
+               })
+
+               Context("when borrowing a buffer from different buffer pools", func() {
+                       It("should return different buffer instances", func() {
+                               buf1, err := b.Borrow("buffer1", "test", 1024*1024, emptyFn)
+                               Expect(err).ToNot(HaveOccurred())
+                               Expect(buf1).ToNot(BeNil())
+                               Expect(b.Volume()).To(Equal(1))
+
+                               buf2, err := b.Borrow("buffer2", "test", 1024*1024, emptyFn)
+                               Expect(err).ToNot(HaveOccurred())
+                               Expect(buf2).ToNot(BeNil())
+                               Expect(b.Volume()).To(Equal(2))
+
+                               Expect(buf2).ToNot(Equal(buf1))
+                               b.Return("buffer1", "test")
+                               b.Return("buffer2", "test")
+                               Expect(b.Volume()).To(Equal(0))
+                       })
+               })
+       })
+})
diff --git a/banyand/tsdb/shard.go b/banyand/tsdb/shard.go
index af9369af..b7841837 100644
--- a/banyand/tsdb/shard.go
+++ b/banyand/tsdb/shard.go
@@ -41,6 +41,8 @@ const (
        defaultBlockQueueSize    = 2
        defaultMaxBlockQueueSize = 64
        defaultKVMemorySize      = 4 << 20
+       defaultNumBufferShards   = 2
+       defaultWriteConcurrency  = 1000
 )
 
 var (
@@ -58,6 +60,10 @@ var (
        onDiskBytesGauge     meter.Gauge
 )
 
+type contextBufferSupplierKey struct{}
+
+var bufferSupplierKey = contextBufferSupplierKey{}
+
 func init() {
        labelNames := common.ShardLabelNames()
        diskStateGauge = shardProvider.Gauge("disk_state", append(labelNames, "kind")...)
@@ -76,6 +82,7 @@ type shard struct {
        segmentController     *segmentController
        segmentManageStrategy *bucket.Strategy
        scheduler             *timestamp.Scheduler
+       bufferSupplier        *BufferSupplier
        position              common.Position
        closeOnce             sync.Once
        id                    common.ShardID
@@ -83,7 +90,7 @@ type shard struct {
 
 // OpenShard returns an existed Shard or create a new one if not existed.
 func OpenShard(ctx context.Context, id common.ShardID,
-       root string, segmentSize, blockSize, ttl IntervalRule, openedBlockSize, maxOpenedBlockSize int,
+       root string, segmentSize, blockSize, ttl IntervalRule, openedBlockSize, maxOpenedBlockSize int, enableWAL bool,
 ) (Shard, error) {
        path, err := mkdir(shardTemplate, root, int(id))
        if err != nil {
@@ -100,17 +107,20 @@ func OpenShard(ctx context.Context, id common.ShardID,
                return p
        })
        clock, _ := timestamp.GetClock(shardCtx)
+
        scheduler := timestamp.NewScheduler(l, clock)
-       sc, err := newSegmentController(shardCtx, path, segmentSize, blockSize, openedBlockSize, maxOpenedBlockSize, l, scheduler)
+       s := &shard{
+               id:        id,
+               l:         l,
+               scheduler: scheduler,
+               position:  common.GetPosition(shardCtx),
+       }
+       s.bufferSupplier = NewBufferSupplier(l, s.position, defaultWriteConcurrency, defaultNumBufferShards, enableWAL, path)
+       shardCtx = context.WithValue(shardCtx, bufferSupplierKey, s.bufferSupplier)
+       s.segmentController, err = newSegmentController(shardCtx, path, segmentSize, blockSize, openedBlockSize, maxOpenedBlockSize, l, scheduler)
        if err != nil {
                return nil, errors.Wrapf(err, "create the segment controller of the shard %d", int(id))
        }
-       s := &shard{
-               id:                id,
-               segmentController: sc,
-               l:                 l,
-               scheduler:         scheduler,
-       }
        err = s.segmentController.open()
        if err != nil {
                return nil, err
@@ -131,7 +141,6 @@ func OpenShard(ctx context.Context, id common.ShardID,
                return nil, err
        }
        s.segmentManageStrategy.Run()
-       s.position = common.GetPosition(shardCtx)
        retentionTask := newRetentionTask(s.segmentController, ttl)
        if err := scheduler.Register("retention", retentionTask.option, 
retentionTask.expr, retentionTask.run); err != nil {
                return nil, err
@@ -235,11 +244,12 @@ func (s *shard) TriggerSchedule(task string) bool {
 
 func (s *shard) Close() (err error) {
        s.closeOnce.Do(func() {
+               _ = s.bufferSupplier.Close()
                s.scheduler.Close()
                s.segmentManageStrategy.Close()
                ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
                defer cancel()
-               err = multierr.Combine(s.segmentController.close(ctx), s.seriesDatabase.Close())
+               err = multierr.Combine(s.bufferSupplier.Close(), s.segmentController.close(ctx), s.seriesDatabase.Close())
        })
        return err
 }
diff --git a/banyand/tsdb/shard_test.go b/banyand/tsdb/shard_test.go
index d952f85c..ae9abfd8 100644
--- a/banyand/tsdb/shard_test.go
+++ b/banyand/tsdb/shard_test.go
@@ -90,6 +90,7 @@ var _ = Describe("Shard", func() {
                                },
                                2,
                                3,
+                               false,
                        )
                        Expect(err).NotTo(HaveOccurred())
                        started("BlockID-19700101-1970010100-1", 
"SegID-19700101-1")
@@ -433,6 +434,7 @@ var _ = Describe("Shard", func() {
                                },
                                10,
                                15,
+                               false,
                        )
                        Expect(err).NotTo(HaveOccurred())
                        started("BlockID-19700101-1970010100-1", 
"SegID-19700101-1", "retention")
@@ -556,6 +558,7 @@ var _ = Describe("Shard", func() {
                                },
                                2,
                                3,
+                               false,
                        )
                        Expect(err).NotTo(HaveOccurred())
                        started("BlockID-19700101-1970010101-1", 
"SegID-19700101-1", "retention")
diff --git a/banyand/tsdb/tsdb.go b/banyand/tsdb/tsdb.go
index c47f2d72..e8b8a311 100644
--- a/banyand/tsdb/tsdb.go
+++ b/banyand/tsdb/tsdb.go
@@ -121,9 +121,10 @@ type DatabaseOpts struct {
        BlockInvertedIndex InvertedIndexOpts
        SeriesMemSize      run.Bytes
        GlobalIndexMemSize run.Bytes
+       IndexGranularity   IndexGranularity
        ShardNum           uint32
        EnableGlobalIndex  bool
-       IndexGranularity   IndexGranularity
+       EnableWAL          bool
 }
 
 // InvertedIndexOpts wraps options to create the block inverted index.
@@ -212,6 +213,7 @@ type database struct {
        sync.RWMutex
        shardNum           uint32
        shardCreationState uint32
+       enableWAL          bool
 }
 
 func (d *database) CreateShardsAndGetByID(id common.ShardID) (Shard, error) {
@@ -292,6 +294,7 @@ func OpenDatabase(ctx context.Context, opts DatabaseOpts) (Database, error) {
                segmentSize: opts.SegmentInterval,
                blockSize:   opts.BlockInterval,
                ttl:         opts.TTL,
+               enableWAL:   opts.EnableWAL,
        }
        db.logger.Info().Str("path", opts.Location).Msg("initialized")
        var entries []os.DirEntry
@@ -313,7 +316,7 @@ func createDatabase(db *database, startID int) (Database, error) {
        for i := startID; i < int(db.shardNum); i++ {
                db.logger.Info().Int("shard_id", i).Msg("creating a shard")
                so, errNewShard := OpenShard(db.shardCreationCtx, common.ShardID(i),
-                       db.location, db.segmentSize, db.blockSize, db.ttl, defaultBlockQueueSize, defaultMaxBlockQueueSize)
+                       db.location, db.segmentSize, db.blockSize, db.ttl, defaultBlockQueueSize, defaultMaxBlockQueueSize, db.enableWAL)
                if errNewShard != nil {
                        err = multierr.Append(err, errNewShard)
                        continue
@@ -346,6 +349,7 @@ func loadDatabase(ctx context.Context, db *database) (Database, error) {
                        db.ttl,
                        defaultBlockQueueSize,
                        defaultMaxBlockQueueSize,
+                       db.enableWAL,
                )
                if errOpenShard != nil {
                        return errOpenShard
diff --git a/banyand/tsdb/tsdb_test.go b/banyand/tsdb/tsdb_test.go
index 9046b8c7..ebe299b7 100644
--- a/banyand/tsdb/tsdb_test.go
+++ b/banyand/tsdb/tsdb_test.go
@@ -141,7 +141,7 @@ func NewByPassTSTableFactory() TSTableFactory {
        return bypassTSTableFactory{}
 }
 
-func (bypassTSTableFactory) NewTSTable(_ BlockExpiryTracker, _ string, _ common.Position, _ *logger.Logger) (TSTable, error) {
+func (bypassTSTableFactory) NewTSTable(_ *BufferSupplier, _ string, _ common.Position, _ *logger.Logger) (TSTable, error) {
        return newBypassTSTable()
 }
 
diff --git a/banyand/tsdb/tstable.go b/banyand/tsdb/tstable.go
index a10d1de9..9aab9bb0 100644
--- a/banyand/tsdb/tstable.go
+++ b/banyand/tsdb/tstable.go
@@ -25,11 +25,8 @@ import (
 
        "github.com/apache/skywalking-banyandb/api/common"
        "github.com/apache/skywalking-banyandb/pkg/logger"
-       "github.com/apache/skywalking-banyandb/pkg/timestamp"
 )
 
-const maxBlockAge = time.Hour
-
 // TSTable is time series table.
 type TSTable interface {
        // Put a value with a timestamp/version
@@ -46,26 +43,5 @@ type TSTable interface {
 // TSTableFactory is the factory of TSTable.
 type TSTableFactory interface {
        // NewTSTable creates a new TSTable.
-       NewTSTable(bufferLifecycle BlockExpiryTracker, root string, position common.Position, l *logger.Logger) (TSTable, error)
-}
-
-// BlockExpiryTracker tracks the expiry of the buffer.
-type BlockExpiryTracker struct {
-       clock timestamp.Clock
-       ttl   time.Time
-}
-
-// IsActive checks if the buffer is active.
-func (bl *BlockExpiryTracker) IsActive() bool {
-       return !bl.clock.Now().After(bl.EndTime())
-}
-
-// EndTime returns the end time of the buffer.
-func (bl *BlockExpiryTracker) EndTime() time.Time {
-       return bl.ttl.Add(maxBlockAge)
-}
-
-// BlockExpiryDuration returns the expiry duration of the buffer.
-func (bl *BlockExpiryTracker) BlockExpiryDuration() time.Duration {
-       return maxBlockAge
+       NewTSTable(bufferSupplier *BufferSupplier, root string, position common.Position, l *logger.Logger) (TSTable, error)
 }
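
For orientation, here is a minimal, illustrative sketch (not part of the commit) of the borrow/return lifecycle the new BufferSupplier introduces: tables that pass the same buffer name share one shard-level Buffer, and the supplier closes that buffer once its last registered flush callback has been returned. The names "encoded" and "seg0-block0", the sizes, and the no-op flush callback below are placeholders, not values from the commit.

package example

import (
	"github.com/dgraph-io/badger/v3/skl"

	"github.com/apache/skywalking-banyandb/api/common"
	"github.com/apache/skywalking-banyandb/banyand/tsdb"
	"github.com/apache/skywalking-banyandb/pkg/logger"
)

func borrowAndReturn() error {
	l := logger.GetLogger("example")
	// One supplier per shard; WAL disabled and an empty path for brevity.
	supplier := tsdb.NewBufferSupplier(l, common.Position{}, 1000, 2, false, "")
	defer supplier.Close()

	// Placeholder flush callback: a real table would hand the immutable
	// memtable over to its SST store here.
	flush := func(shardIndex int, memTable *skl.Skiplist) error { return nil }

	// Borrow registers the callback; a second Borrow with the same buffer
	// name returns the same shared Buffer instance.
	buf, err := supplier.Borrow("encoded", "seg0-block0", 4<<20, flush)
	if err != nil {
		return err
	}
	_ = buf // writes go through buf.Write(key, value, timestamp)

	// Returning the last borrower closes and removes the shared buffer.
	supplier.Return("encoded", "seg0-block0")
	return nil
}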

