This is an automated email from the ASF dual-hosted git repository.
wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
The following commit(s) were added to refs/heads/main by this push:
new 5d6fd736c Fix the Merged Files Exhausting the Disk (#967)
5d6fd736c is described below
commit 5d6fd736c0ab80f412a3a020491dbc2612416b4c
Author: Gao Hongtao <[email protected]>
AuthorDate: Sat Feb 7 11:15:07 2026 +0800
Fix the Merged Files Exhausting the Disk (#967)
* feat(snapshot): introduce generic snapshot coordination for atomic
transitions
- Added a new package for managing atomic snapshot transitions across trace
and sidx components.
- Enhanced the SIDX interface with snapshot transaction support, allowing
for coordinated updates.
- Updated existing snapshot handling to utilize the new transaction model
for memory, flushed, and merged parts.
- Refactored snapshot types and methods for consistency and clarity.
* feat(trace): enhance error handling during flush and merge operations
- Introduced ReleaseFlushedParts and ReleaseNewPart methods to manage
resource cleanup when flush or merge operations fail.
- Updated merge and flush logic to ensure proper cleanup of trace and sidx
parts on errors.
- Improved test coverage for error scenarios to validate cleanup behavior.
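For reviewers, the coordinated-update flow added by this commit looks roughly as
follows. This is a minimal sketch built on the mockSnapshot/mockManager test
doubles from the new snapshot_test.go below; the real callers (the trace
introducer and sidx) wire in their own snapshot types the same way.

    // Prepare one transition per snapshot manager, then commit them all
    // under a single transaction so readers never see a half-applied state.
    txn := snapshot.NewTransaction()
    defer txn.Release()

    transition := snapshot.NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
        next := &mockSnapshot{id: cur.id + 1} // derive the next snapshot from the current one
        next.IncRef()                         // the manager owns this reference after commit
        return next
    })
    defer transition.Release()
    snapshot.AddTransition(txn, transition)

    txn.Commit() // or txn.Rollback() on error, which DecRefs every prepared snapshot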
---
CHANGES.md | 1 +
banyand/internal/sidx/interfaces.go | 19 +
banyand/internal/sidx/introducer.go | 24 +-
banyand/internal/sidx/merge.go | 3 +
banyand/internal/sidx/query.go | 4 +-
banyand/internal/sidx/sidx.go | 77 ++-
banyand/internal/sidx/snapshot.go | 68 ++-
banyand/internal/snapshot/snapshot.go | 219 +++++++
banyand/internal/snapshot/snapshot_test.go | 917 +++++++++++++++++++++++++++++
banyand/internal/storage/segment_test.go | 8 +-
banyand/trace/flusher.go | 13 +
banyand/trace/introducer.go | 266 +++++++--
banyand/trace/merger.go | 42 +-
banyand/trace/merger_test.go | 127 ++++
banyand/trace/snapshot.go | 18 +-
banyand/trace/streaming_pipeline_test.go | 35 ++
16 files changed, 1738 insertions(+), 103 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 8d8facb5d..3ffc1b9be 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -22,6 +22,7 @@ Release Notes.
- **Breaking Change**: Change the data storage path structure for property model:
- From: `<data-dir>/property/data/shard-<id>/...`
- To: `<data-dir>/property/data/<group>/shard-<id>/...`
+- Add a generic snapshot coordination package for atomic snapshot transitions across trace and sidx.
### Bug Fixes
diff --git a/banyand/internal/sidx/interfaces.go b/banyand/internal/sidx/interfaces.go
index 670f76f1a..5c4a88219 100644
--- a/banyand/internal/sidx/interfaces.go
+++ b/banyand/internal/sidx/interfaces.go
@@ -31,11 +31,30 @@ import (
"github.com/apache/skywalking-banyandb/pkg/query/model"
)
+// SnapshotTransactionSupport defines snapshot and transaction-related operations
+// for coordinated snapshot updates. It is embedded by SIDX so that callers
+// (e.g. trace) can participate in atomic snapshot transactions.
+type SnapshotTransactionSupport interface {
+ // PrepareMemPart prepares a transition for introducing a memory part.
+ PrepareMemPart(partID uint64, mp *MemPart) func(cur *Snapshot) *Snapshot
+ // PrepareFlushed prepares a transition for introducing flushed parts.
+ PrepareFlushed(intro *FlusherIntroduction) func(cur *Snapshot) *Snapshot
+ // PrepareMerged prepares a transition for introducing merged parts.
+ PrepareMerged(intro *MergerIntroduction) func(cur *Snapshot) *Snapshot
+ // PrepareSynced prepares a transition for removing synced parts.
+	PrepareSynced(partIDsToSync map[uint64]struct{}) func(cur *Snapshot) *Snapshot
+	// CurrentSnapshot returns the current snapshot with incremented reference count.
+ CurrentSnapshot() *Snapshot
+ // ReplaceSnapshot atomically replaces the current snapshot with next.
+ ReplaceSnapshot(next *Snapshot)
+}
+
// SIDX defines the main secondary index interface with user-controlled ordering.
// The core principle is that int64 keys are provided by users and treated as
// opaque ordering values by sidx - the system only performs numerical comparisons
// without interpreting the semantic meaning of keys.
type SIDX interface {
+ SnapshotTransactionSupport
// IntroduceMemPart introduces a memPart to the SIDX instance.
IntroduceMemPart(partID uint64, mp *MemPart)
// IntroduceFlushed introduces a flushed map to the SIDX instance.
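Each Prepare* method here is a two-phase variant of the corresponding
Introduce* call: it captures its arguments and returns a pure transition
function that the snapshot package later applies to the ref-counted current
snapshot. A hedged sketch of the intended wiring, mirroring what the trace
introducer does later in this diff (sidxInstance stands for any SIDX
implementation):

    prepare := sidxInstance.PrepareMemPart(partID, memPart)     // returns func(cur *Snapshot) *Snapshot
    transition := snapshot.NewTransition(sidxInstance, prepare) // invokes prepare(CurrentSnapshot())
    defer transition.Release()
    transition.Commit() // ReplaceSnapshot(next); Rollback() would DecRef the prepared snapshot instead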
diff --git a/banyand/internal/sidx/introducer.go b/banyand/internal/sidx/introducer.go
index 9bdefde82..d7b0b2ccc 100644
--- a/banyand/internal/sidx/introducer.go
+++ b/banyand/internal/sidx/introducer.go
@@ -34,6 +34,17 @@ func (i *FlusherIntroduction) Release() {
releaseFlusherIntroduction(i)
}
+// ReleaseFlushedParts releases all flushed part wrappers (closes file handles).
+// Call this when a flush is abandoned so the caller can remove part directories from disk before calling Release().
+func (i *FlusherIntroduction) ReleaseFlushedParts() {
+ for _, pw := range i.flushed {
+ pw.release()
+ }
+ for k := range i.flushed {
+ delete(i.flushed, k)
+ }
+}
+
func (i *FlusherIntroduction) reset() {
for k := range i.flushed {
delete(i.flushed, k)
@@ -69,6 +80,15 @@ func (i *MergerIntroduction) Release() {
releaseMergerIntroduction(i)
}
+// ReleaseNewPart releases the newPart from this introduction (closes file handles).
+// Call this when the merge is abandoned so the part can be removed from disk by the caller.
+func (i *MergerIntroduction) ReleaseNewPart() {
+ if i.newPart != nil {
+ i.newPart.release()
+ i.newPart = nil
+ }
+}
+
func (i *MergerIntroduction) reset() {
for k := range i.merged {
delete(i.merged, k)
@@ -100,7 +120,7 @@ func (s *sidx) IntroduceMemPart(partID uint64, memPart *memPart) {
if cur != nil {
defer cur.decRef()
} else {
- cur = &snapshot{}
+ cur = &Snapshot{}
}
nextSnp := cur.copyAllTo()
@@ -174,7 +194,7 @@ func (s *sidx) TakeFileSnapshot(dst string) error {
return nil
}
-func (s *sidx) replaceSnapshot(next *snapshot) {
+func (s *sidx) replaceSnapshot(next *Snapshot) {
s.mu.Lock()
defer s.mu.Unlock()
if s.snapshot != nil {
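Taken together, ReleaseFlushedParts and ReleaseNewPart let a caller unwind an
abandoned flush or merge without leaking file handles or orphaned on-disk
parts. A minimal sketch of the abandon path, under the assumption that doMerge
and partDir are caller-side placeholders and os.RemoveAll stands in for the
repository's fs layer:

    newPartDir := partDir(root, partID) // hypothetical: directory of the just-written part
    if err := doMerge(); err != nil {
        intro.ReleaseNewPart()   // close the new part's file handles first
        os.RemoveAll(newPartDir) // only then can the directory be deleted safely
        intro.Release()          // finally return the introduction to its pool
    }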
diff --git a/banyand/internal/sidx/merge.go b/banyand/internal/sidx/merge.go
index d5b64f14a..a537e3881 100644
--- a/banyand/internal/sidx/merge.go
+++ b/banyand/internal/sidx/merge.go
@@ -45,6 +45,9 @@ func (s *sidx) Merge(closeCh <-chan struct{}, partIDtoMerge map[uint64]struct{},
partsToMerge = append(partsToMerge, pw)
}
}
+ if len(partsToMerge) == 0 {
+ return nil, nil
+ }
if d := s.l.Debug(); d.Enabled() {
if len(partsToMerge) != len(partIDtoMerge) {
d.Int("parts_to_merge_count", len(partsToMerge)).
diff --git a/banyand/internal/sidx/query.go b/banyand/internal/sidx/query.go
index 36d2f80e2..65f55ff91 100644
--- a/banyand/internal/sidx/query.go
+++ b/banyand/internal/sidx/query.go
@@ -149,7 +149,7 @@ func finalizeStreamingSpan(span *query.Span, errPtr *error) {
func (s *sidx) prepareStreamingResources(
ctx context.Context,
req QueryRequest,
- snap *snapshot,
+ snap *Snapshot,
span *query.Span,
) (*streamingQueryResources, bool) {
var prepareSpan *query.Span
@@ -485,7 +485,7 @@ func extractOrdering(req QueryRequest) bool {
}
// selectPartsForQuery selects relevant parts from snapshot based on key range.
-func selectPartsForQuery(snap *snapshot, minKey, maxKey int64) []*part {
+func selectPartsForQuery(snap *Snapshot, minKey, maxKey int64) []*part {
var selectedParts []*part
for _, pw := range snap.parts {
diff --git a/banyand/internal/sidx/sidx.go b/banyand/internal/sidx/sidx.go
index 181497702..7778610e1 100644
--- a/banyand/internal/sidx/sidx.go
+++ b/banyand/internal/sidx/sidx.go
@@ -45,7 +45,7 @@ const (
// sidx implements the SIDX interface with introduction channels for async operations.
type sidx struct {
fileSystem fs.FileSystem
- snapshot *snapshot
+ snapshot *Snapshot
l *logger.Logger
pm protector.Memory
root string
@@ -233,7 +233,7 @@ func (s *sidx) Close() error {
}
// currentSnapshot returns the current snapshot with incremented reference count.
-func (s *sidx) currentSnapshot() *snapshot {
+func (s *sidx) currentSnapshot() *Snapshot {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -248,6 +248,79 @@ func (s *sidx) currentSnapshot() *snapshot {
return nil
}
+// CurrentSnapshot returns the current snapshot with incremented reference count.
+// Implements snapshot.Manager[*Snapshot] interface.
+func (s *sidx) CurrentSnapshot() *Snapshot {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if s.snapshot == nil {
+ return nil
+ }
+
+ s.snapshot.IncRef()
+ return s.snapshot
+}
+
+// ReplaceSnapshot atomically replaces the current snapshot with next.
+// Implements snapshot.Manager[*Snapshot] interface.
+// The old snapshot's DecRef is called automatically per the Manager contract.
+func (s *sidx) ReplaceSnapshot(next *Snapshot) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.snapshot != nil {
+ s.snapshot.DecRef()
+ }
+ s.snapshot = next
+}
+
+// PrepareMemPart prepares a transition for introducing a memory part.
+func (s *sidx) PrepareMemPart(partID uint64, mp *MemPart) func(cur *Snapshot) *Snapshot {
+ return func(cur *Snapshot) *Snapshot {
+ if cur == nil {
+ cur = &Snapshot{}
+ }
+ next := cur.copyAllTo()
+ part := openMemPart(mp)
+ pw := newPartWrapper(mp, part)
+ pw.p.partMetadata.ID = partID
+ next.parts = append(next.parts, pw)
+ return next
+ }
+}
+
+// PrepareFlushed prepares a transition for introducing flushed parts.
+func (s *sidx) PrepareFlushed(intro *FlusherIntroduction) func(cur *Snapshot) *Snapshot {
+ return func(cur *Snapshot) *Snapshot {
+ if cur == nil {
+			s.l.Panic().Msg("current snapshot is nil in PrepareFlushed")
+ }
+ return cur.merge(intro.flushed)
+ }
+}
+
+// PrepareMerged prepares a transition for introducing merged parts.
+func (s *sidx) PrepareMerged(intro *MergerIntroduction) func(cur *Snapshot) *Snapshot {
+ return func(cur *Snapshot) *Snapshot {
+ if cur == nil {
+			s.l.Panic().Msg("current snapshot is nil in PrepareMerged")
+ }
+ next := cur.remove(intro.merged)
+ next.parts = append(next.parts, intro.newPart)
+ return next
+ }
+}
+
+// PrepareSynced prepares a transition for removing synced parts.
+func (s *sidx) PrepareSynced(partIDsToSync map[uint64]struct{}) func(cur *Snapshot) *Snapshot {
+ return func(cur *Snapshot) *Snapshot {
+ if cur == nil {
+			s.l.Panic().Msg("current snapshot is nil in PrepareSynced")
+ }
+ return cur.remove(partIDsToSync)
+ }
+}
+
// blockCursor represents a cursor for iterating through a loaded block, similar to query_by_ts.go.
type blockCursor struct {
p *part
diff --git a/banyand/internal/sidx/snapshot.go b/banyand/internal/sidx/snapshot.go
index c59c3c396..0cce7df83 100644
--- a/banyand/internal/sidx/snapshot.go
+++ b/banyand/internal/sidx/snapshot.go
@@ -25,10 +25,10 @@ import (
"github.com/apache/skywalking-banyandb/pkg/logger"
)
-// snapshot represents an immutable collection of parts at a specific epoch.
+// Snapshot represents an immutable collection of parts at a specific epoch.
// It provides safe concurrent access to parts through reference counting and
// enables queries to work with a consistent view of data.
-type snapshot struct {
+type Snapshot struct {
// parts contains all active parts sorted by epoch (oldest first)
parts []*partWrapper
@@ -38,8 +38,8 @@ type snapshot struct {
// newSnapshot creates a new snapshot with the given parts and epoch.
// The snapshot starts with a reference count of 1.
-func newSnapshot(parts []*partWrapper) *snapshot {
- s := &snapshot{}
+func newSnapshot(parts []*partWrapper) *Snapshot {
+ s := &Snapshot{}
s.parts = append(s.parts[:0], parts...)
s.ref = 1
@@ -56,20 +56,23 @@ func newSnapshot(parts []*partWrapper) *snapshot {
return s
}
+// IncRef increments the snapshot reference count.
+// Implements the snapshot.Snapshot interface.
+func (s *Snapshot) IncRef() {
+ atomic.AddInt32(&s.ref, 1)
+}
+
// acquire increments the snapshot reference count.
// Returns true if successful, false if snapshot has been released.
-func (s *snapshot) acquire() bool {
+// This is an internal method for backward compatibility.
+func (s *Snapshot) acquire() bool {
return atomic.AddInt32(&s.ref, 1) > 0
}
-// decRef decrements the snapshot reference count (helper for snapshot interface).
-func (s *snapshot) decRef() {
- s.release()
-}
-
-// release decrements the snapshot reference count.
+// DecRef decrements the snapshot reference count.
// When the count reaches zero, all part references are released.
-func (s *snapshot) release() {
+// Implements the snapshot.Snapshot interface.
+func (s *Snapshot) DecRef() {
newRef := atomic.AddInt32(&s.ref, -1)
if newRef > 0 {
return
@@ -84,10 +87,21 @@ func (s *snapshot) release() {
s.reset()
}
+// decRef is an internal helper that calls DecRef.
+func (s *Snapshot) decRef() {
+ s.DecRef()
+}
+
+// release is an internal helper that calls DecRef.
+// Kept for backward compatibility.
+func (s *Snapshot) release() {
+ s.DecRef()
+}
+
// getParts returns parts that potentially contain data within the specified key range.
// This method filters parts based on their key ranges to minimize I/O during queries.
// Parts are returned in epoch order (oldest first) for consistent iteration.
-func (s *snapshot) getParts(minKey, maxKey int64) []*partWrapper {
+func (s *Snapshot) getParts(minKey, maxKey int64) []*partWrapper {
var result []*partWrapper
for _, pw := range s.parts {
@@ -112,7 +126,7 @@ func (s *snapshot) getParts(minKey, maxKey int64) []*partWrapper {
// getPartsAll returns all active parts in the snapshot.
// This is used when querying without key range restrictions.
-func (s *snapshot) getPartsAll() []*partWrapper {
+func (s *Snapshot) getPartsAll() []*partWrapper {
var result []*partWrapper
for _, pw := range s.parts {
@@ -125,17 +139,17 @@ func (s *snapshot) getPartsAll() []*partWrapper {
}
// getPartCount returns the number of parts in the snapshot.
-func (s *snapshot) getPartCount() int {
+func (s *Snapshot) getPartCount() int {
return len(s.getPartsAll())
}
// refCount returns the current reference count (for testing/debugging).
-func (s *snapshot) refCount() int32 {
+func (s *Snapshot) refCount() int32 {
return atomic.LoadInt32(&s.ref)
}
// validate checks snapshot consistency and part availability.
-func (s *snapshot) validate() error {
+func (s *Snapshot) validate() error {
if atomic.LoadInt32(&s.ref) <= 0 {
return fmt.Errorf("snapshot has zero or negative reference count")
}
@@ -160,7 +174,7 @@ func (s *snapshot) validate() error {
// addPart adds a new part to the snapshot during construction.
// This should only be called before the snapshot is made available to other goroutines.
// After construction, snapshots should be treated as immutable.
-func (s *snapshot) addPart(pw *partWrapper) {
+func (s *Snapshot) addPart(pw *partWrapper) {
if pw != nil && pw.acquire() {
s.parts = append(s.parts, pw)
}
@@ -168,7 +182,7 @@ func (s *snapshot) addPart(pw *partWrapper) {
// removePart marks a part for removal from future snapshots.
// The part remains accessible in this snapshot until the snapshot is released.
-func (s *snapshot) removePart(partID uint64) {
+func (s *Snapshot) removePart(partID uint64) {
for _, pw := range s.parts {
if pw.ID() == partID {
pw.markForRemoval()
@@ -178,7 +192,7 @@ func (s *snapshot) removePart(partID uint64) {
}
// reset clears the snapshot for reuse.
-func (s *snapshot) reset() {
+func (s *Snapshot) reset() {
// Release all part references
for _, pw := range s.parts {
if pw != nil {
@@ -191,7 +205,7 @@ func (s *snapshot) reset() {
}
// String returns a string representation of the snapshot.
-func (s *snapshot) String() string {
+func (s *Snapshot) String() string {
activeCount := s.getPartCount()
return fmt.Sprintf("snapshot{parts=%d/%d, ref=%d}",
activeCount, len(s.parts), s.refCount())
@@ -202,8 +216,8 @@ func parseEpoch(epochStr string) (uint64, error) {
}
// copyAllTo creates a new snapshot with all parts from current snapshot.
-func (s *snapshot) copyAllTo() *snapshot {
- var result snapshot
+func (s *Snapshot) copyAllTo() *Snapshot {
+ var result Snapshot
result.parts = make([]*partWrapper, len(s.parts))
result.ref = 1
@@ -219,8 +233,8 @@ func (s *snapshot) copyAllTo() *snapshot {
}
// merge creates a new snapshot by merging flushed parts into the current snapshot.
-func (s *snapshot) merge(nextParts map[uint64]*partWrapper) *snapshot {
- var result snapshot
+func (s *Snapshot) merge(nextParts map[uint64]*partWrapper) *Snapshot {
+ var result Snapshot
result.ref = 1
for i := 0; i < len(s.parts); i++ {
if n, ok := nextParts[s.parts[i].ID()]; ok {
@@ -235,8 +249,8 @@ func (s *snapshot) merge(nextParts map[uint64]*partWrapper) *snapshot {
}
// remove creates a new snapshot by removing specified parts.
-func (s *snapshot) remove(toRemove map[uint64]struct{}) *snapshot {
- var result snapshot
+func (s *Snapshot) remove(toRemove map[uint64]struct{}) *Snapshot {
+ var result Snapshot
result.ref = 1
// Copy parts except those being removed
diff --git a/banyand/internal/snapshot/snapshot.go b/banyand/internal/snapshot/snapshot.go
new file mode 100644
index 000000000..1ab8b308a
--- /dev/null
+++ b/banyand/internal/snapshot/snapshot.go
@@ -0,0 +1,219 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Package snapshot provides generic transaction coordination for snapshot-based systems.
+// It enables atomic updates across multiple heterogeneous snapshot managers using Go generics.
+package snapshot
+
+import (
+ "reflect"
+ "sync"
+
+ "github.com/apache/skywalking-banyandb/pkg/pool"
+)
+
+// Snapshot is a type constraint for snapshot types that support reference counting.
+// Any type implementing this interface can participate in atomic transactions.
+type Snapshot interface {
+ // IncRef increments the reference count.
+ IncRef()
+	// DecRef decrements the reference count and releases resources when zero.
+ DecRef()
+}
+
+// Manager is a generic interface for managing snapshots of type S.
+// Both trace and sidx implement this interface for their respective snapshot types.
+type Manager[S Snapshot] interface {
+	// CurrentSnapshot returns the current snapshot with incremented ref count.
+ // Returns nil if no snapshot exists.
+ CurrentSnapshot() S
+ // ReplaceSnapshot atomically replaces the current snapshot with next.
+ // The old snapshot's DecRef is called automatically.
+ ReplaceSnapshot(next S)
+}
+
+// Transition represents a prepared but uncommitted snapshot change.
+// It holds both the current and next snapshots, allowing atomic commit or rollback.
+type Transition[S Snapshot] struct {
+ manager Manager[S]
+ current S
+ next S
+ committed bool
+}
+
+// getTransitionPool returns or creates a pool for the given snapshot type.
+func getTransitionPool[S Snapshot]() *pool.Synced[any] {
+ var zero S
+ typ := reflect.TypeOf(zero)
+
+ if p, ok := transitionPools.Load(typ); ok {
+ return p.(*pool.Synced[any])
+ }
+
+ // Create new pool for this type
+ poolName := "snapshot.Transition[" + typ.String() + "]"
+ p := pool.Register[any](poolName)
+ actual, _ := transitionPools.LoadOrStore(typ, p)
+ return actual.(*pool.Synced[any])
+}
+
+// NewTransition creates a transition by preparing the next snapshot from the pool.
+// The prepareNext function receives the current snapshot and returns the next.
+func NewTransition[S Snapshot](manager Manager[S], prepareNext func(current S) S) *Transition[S] {
+ p := getTransitionPool[S]()
+
+ var t *Transition[S]
+ if pooled := p.Get(); pooled != nil {
+ t = pooled.(*Transition[S])
+ } else {
+ t = &Transition[S]{}
+ }
+
+ current := manager.CurrentSnapshot()
+ next := prepareNext(current)
+
+ t.manager = manager
+ t.current = current
+ t.next = next
+ t.committed = false
+
+ return t
+}
+
+// Commit commits the transition, replacing the current snapshot with next.
+// ReplaceSnapshot satisfies the Manager contract by calling DecRef on the old snapshot.
+func (t *Transition[S]) Commit() {
+ if t.committed {
+ return
+ }
+ t.committed = true
+ t.manager.ReplaceSnapshot(t.next)
+}
+
+// Rollback discards the prepared next snapshot without committing.
+func (t *Transition[S]) Rollback() {
+ if t.committed {
+ return
+ }
+ // Release the prepared next snapshot
+	// Use reflection to check if underlying value is nil (Go interface quirk)
+ if !reflect.ValueOf(t.next).IsNil() {
+ t.next.DecRef()
+ }
+ // Release reference to current (acquired in NewTransition)
+ if !reflect.ValueOf(t.current).IsNil() {
+ t.current.DecRef()
+ }
+}
+
+// Release returns the transition to the pool for reuse.
+// This should be called after Commit or Rollback to reduce allocations.
+func (t *Transition[S]) Release() {
+ t.reset()
+ p := getTransitionPool[S]()
+ p.Put(any(t))
+}
+
+// reset clears the transition state for reuse.
+// When the transition was committed, the ref acquired in NewTransition (t.current) was not
+// decremented by ReplaceSnapshot (which only decrements the manager's ref), so we release it here.
+func (t *Transition[S]) reset() {
+ var zero S
+ if t.committed && !reflect.ValueOf(t.current).IsNil() {
+ t.current.DecRef()
+ }
+ t.manager = nil
+ t.current = zero
+ t.next = zero
+ t.committed = false
+}
+
+var (
+ transactionPool = pool.Register[*Transaction]("snapshot.Transaction")
+ transitionPools sync.Map // map[reflect.Type]*pool.Synced[any]
+)
+
+// Transaction coordinates atomic commits across multiple heterogeneous transitions.
+// It uses type erasure via function closures to handle different snapshot types.
+type Transaction struct {
+ commits []func()
+ rollbacks []func()
+ mu sync.Mutex
+ finalized bool
+}
+
+// NewTransaction creates a new empty transaction from the pool.
+func NewTransaction() *Transaction {
+ txn := transactionPool.Get()
+ if txn == nil {
+ txn = &Transaction{}
+ }
+ return txn
+}
+
+// Release returns the transaction to the pool for reuse.
+// This should be called after Commit or Rollback to reduce allocations.
+func (txn *Transaction) Release() {
+ txn.reset()
+ transactionPool.Put(txn)
+}
+
+// reset clears the transaction state for reuse.
+func (txn *Transaction) reset() {
+ txn.commits = txn.commits[:0]
+ txn.rollbacks = txn.rollbacks[:0]
+ txn.finalized = false
+}
+
+// AddTransition adds a typed transition to the transaction.
+// This is a generic function (not a method) because Go doesn't support generic methods.
+func AddTransition[S Snapshot](txn *Transaction, transition *Transition[S]) {
+ txn.commits = append(txn.commits, transition.Commit)
+ txn.rollbacks = append(txn.rollbacks, transition.Rollback)
+}
+
+// Commit atomically commits all transitions in the transaction.
+// All commits happen under a single lock to ensure atomicity.
+func (txn *Transaction) Commit() {
+ txn.mu.Lock()
+ defer txn.mu.Unlock()
+
+ if txn.finalized {
+ return
+ }
+ txn.finalized = true
+
+ for _, commit := range txn.commits {
+ commit()
+ }
+}
+
+// Rollback discards all prepared transitions without committing.
+// Rollbacks are executed in reverse order (LIFO).
+func (txn *Transaction) Rollback() {
+ txn.mu.Lock()
+ defer txn.mu.Unlock()
+
+ if txn.finalized {
+ return
+ }
+ txn.finalized = true
+
+ for i := len(txn.rollbacks) - 1; i >= 0; i-- {
+ txn.rollbacks[i]()
+ }
+}
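To make the reference-counting contract concrete, the expected lifecycle of a
committed transition is sketched below; the counts assume the manager starts
out holding exactly one reference on the old snapshot, as the tests in the next
file verify:

    // NewTransition: manager.CurrentSnapshot() bumps old.ref 1 -> 2,
    // and prepareNext returns next with ref = 1.
    t := NewTransition(manager, prepareNext)

    // Commit: ReplaceSnapshot(next) installs next and DecRefs the manager's
    // reference, so old.ref drops 2 -> 1; next.ref stays 1, now manager-owned.
    t.Commit()

    // Release: reset() DecRefs the reference NewTransition acquired,
    // so old.ref drops 1 -> 0 and its parts are released.
    t.Release()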
diff --git a/banyand/internal/snapshot/snapshot_test.go b/banyand/internal/snapshot/snapshot_test.go
new file mode 100644
index 000000000..b843c4653
--- /dev/null
+++ b/banyand/internal/snapshot/snapshot_test.go
@@ -0,0 +1,917 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package snapshot
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ "github.com/apache/skywalking-banyandb/pkg/pool"
+)
+
+// mockSnapshot is a test implementation of the Snapshot interface.
+type mockSnapshot struct {
+ id int
+ ref atomic.Int32
+}
+
+func (m *mockSnapshot) IncRef() {
+ m.ref.Add(1)
+}
+
+func (m *mockSnapshot) DecRef() {
+ m.ref.Add(-1)
+}
+
+func (m *mockSnapshot) RefCount() int32 {
+ return m.ref.Load()
+}
+
+// mockManager is a test implementation of the Manager interface.
+type mockManager struct {
+ snapshot *mockSnapshot
+ mu sync.RWMutex
+}
+
+func (m *mockManager) CurrentSnapshot() *mockSnapshot {
+ m.mu.RLock()
+ defer m.mu.RUnlock()
+ if m.snapshot != nil {
+ m.snapshot.IncRef()
+ return m.snapshot
+ }
+ return nil
+}
+
+func (m *mockManager) ReplaceSnapshot(next *mockSnapshot) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ if m.snapshot != nil {
+ m.snapshot.DecRef()
+ }
+ m.snapshot = next
+}
+
+// getSnapshotPoolRefCounts returns ref counts for all pools with names starting with "snapshot.".
+func getSnapshotPoolRefCounts() map[string]int {
+ all := pool.AllRefsCount()
+ result := make(map[string]int, len(all))
+ for name, count := range all {
+ if strings.HasPrefix(name, "snapshot.") {
+ result[name] = count
+ }
+ }
+ return result
+}
+
+// assertSnapshotPoolsNoLeak verifies that snapshot pool ref counts have not increased (no leak).
+func assertSnapshotPoolsNoLeak(t *testing.T, before map[string]int) {
+ t.Helper()
+ after := getSnapshotPoolRefCounts()
+ for name, count := range after {
+ if beforeCount := before[name]; beforeCount != count {
+			t.Errorf("pool %s: ref count changed from %d to %d (possible leak)", name, beforeCount, count)
+ }
+ }
+}
+
+func TestTransition_Commit(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+
+ transition.Commit()
+
+ // Verify snapshot was replaced
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot after commit")
+ }
+ defer current.DecRef()
+
+ if current.id != 2 {
+ t.Errorf("expected snapshot id 2, got %d", current.id)
+ }
+}
+
+func TestTransition_Rollback(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+
+ transition.Rollback()
+
+ // Verify snapshot was NOT replaced
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot after rollback")
+ }
+ defer current.DecRef()
+
+ if current.id != 1 {
+ t.Errorf("expected snapshot id 1, got %d", current.id)
+ }
+}
+
+func TestTransaction_SingleTransition(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ txn := NewTransaction()
+ defer txn.Release()
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+ AddTransition(txn, transition)
+
+ txn.Commit()
+
+ // Verify snapshot was replaced
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot after commit")
+ }
+ defer current.DecRef()
+
+ if current.id != 2 {
+ t.Errorf("expected snapshot id 2, got %d", current.id)
+ }
+}
+
+func TestTransaction_MultipleTransitions(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager1 := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager1.snapshot.IncRef()
+
+ manager2 := &mockManager{
+ snapshot: &mockSnapshot{id: 10},
+ }
+ manager2.snapshot.IncRef()
+
+ txn := NewTransaction()
+ defer txn.Release()
+
+	transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition1.Release()
+ AddTransition(txn, transition1)
+
+	transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition2.Release()
+ AddTransition(txn, transition2)
+
+ txn.Commit()
+
+ // Verify both snapshots were replaced
+ current1 := manager1.CurrentSnapshot()
+ if current1 == nil {
+ t.Fatal("expected non-nil snapshot for manager1")
+ }
+ defer current1.DecRef()
+
+ if current1.id != 2 {
+		t.Errorf("expected snapshot id 2 for manager1, got %d", current1.id)
+ }
+
+ current2 := manager2.CurrentSnapshot()
+ if current2 == nil {
+ t.Fatal("expected non-nil snapshot for manager2")
+ }
+ defer current2.DecRef()
+
+ if current2.id != 11 {
+		t.Errorf("expected snapshot id 11 for manager2, got %d", current2.id)
+ }
+}
+
+func TestTransaction_Rollback(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager1 := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager1.snapshot.IncRef()
+
+ manager2 := &mockManager{
+ snapshot: &mockSnapshot{id: 10},
+ }
+ manager2.snapshot.IncRef()
+
+ txn := NewTransaction()
+ defer txn.Release()
+
+	transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition1.Release()
+ AddTransition(txn, transition1)
+
+	transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition2.Release()
+ AddTransition(txn, transition2)
+
+ txn.Rollback()
+
+ // Verify neither snapshot was replaced
+ current1 := manager1.CurrentSnapshot()
+ if current1 == nil {
+ t.Fatal("expected non-nil snapshot for manager1")
+ }
+ defer current1.DecRef()
+
+ if current1.id != 1 {
+		t.Errorf("expected snapshot id 1 for manager1, got %d", current1.id)
+ }
+
+ current2 := manager2.CurrentSnapshot()
+ if current2 == nil {
+ t.Fatal("expected non-nil snapshot for manager2")
+ }
+ defer current2.DecRef()
+
+ if current2.id != 10 {
+		t.Errorf("expected snapshot id 10 for manager2, got %d", current2.id)
+ }
+}
+
+func TestTransaction_Concurrent(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 0},
+ }
+ manager.snapshot.IncRef()
+
+ const numGoroutines = 10
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func() {
+ defer wg.Done()
+
+ txn := NewTransaction()
+ defer txn.Release()
+			transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+ AddTransition(txn, transition)
+ txn.Commit()
+ }()
+ }
+
+ wg.Wait()
+
+ // Verify final snapshot exists and was updated
+ // Note: With concurrent updates, not all increments may be preserved
+	// because multiple goroutines might read the same snapshot concurrently.
+	// The important thing is that no crashes occurred and the snapshot changed.
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot after concurrent commits")
+ }
+ defer current.DecRef()
+
+ if current.id <= 0 {
+ t.Errorf("expected snapshot id > 0, got %d", current.id)
+ }
+ if current.id > numGoroutines {
+		t.Errorf("expected snapshot id <= %d, got %d", numGoroutines, current.id)
+ }
+}
+
+func TestTransaction_IdempotentCommit(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ txn := NewTransaction()
+ defer txn.Release()
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+ AddTransition(txn, transition)
+
+ // Commit multiple times
+ txn.Commit()
+ txn.Commit()
+ txn.Commit()
+
+ // Verify snapshot was replaced only once
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot")
+ }
+ defer current.DecRef()
+
+ if current.id != 2 {
+ t.Errorf("expected snapshot id 2, got %d", current.id)
+ }
+}
+
+func TestTransaction_IdempotentRollback(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ txn := NewTransaction()
+ defer txn.Release()
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+ AddTransition(txn, transition)
+
+ // Rollback multiple times
+ txn.Rollback()
+ txn.Rollback()
+ txn.Rollback()
+
+ // Verify snapshot was not replaced
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot")
+ }
+ defer current.DecRef()
+
+ if current.id != 1 {
+ t.Errorf("expected snapshot id 1, got %d", current.id)
+ }
+}
+
+func TestTransaction_CommitAfterRollback(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ txn := NewTransaction()
+ defer txn.Release()
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+ AddTransition(txn, transition)
+
+ // Rollback first, then try to commit
+ txn.Rollback()
+ txn.Commit()
+
+ // Verify snapshot was not replaced
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot")
+ }
+ defer current.DecRef()
+
+ if current.id != 1 {
+ t.Errorf("expected snapshot id 1, got %d", current.id)
+ }
+}
+
+func TestTransaction_HeterogeneousTypes(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ // Test with two different snapshot types to verify type safety
+ type alternateSnapshot struct {
+ value string
+ ref atomic.Int32
+ }
+
+ altSnap := &alternateSnapshot{value: "test"}
+ altSnap.ref.Store(1)
+
+ altManager := &struct {
+ snapshot *alternateSnapshot
+ mu sync.RWMutex
+ }{
+ snapshot: altSnap,
+ }
+
+ // Define methods inline for alternateSnapshot
+ incRef := func(s *alternateSnapshot) {
+ s.ref.Add(1)
+ }
+
+ // Create a custom manager for alternateSnapshot
+ currentAlt := func() *alternateSnapshot {
+ altManager.mu.RLock()
+ defer altManager.mu.RUnlock()
+ if altManager.snapshot != nil {
+ incRef(altManager.snapshot)
+ return altManager.snapshot
+ }
+ return nil
+ }
+
+ replaceAlt := func(next *alternateSnapshot) {
+ altManager.mu.Lock()
+ defer altManager.mu.Unlock()
+ altManager.snapshot = next
+ }
+
+	// This test verifies that the generic design compiles with different types
+ // The actual runtime behavior is tested in other tests
+ _ = currentAlt
+ _ = replaceAlt
+}
+
+// BenchmarkTransaction_WithPool benchmarks transaction creation with pooling.
+func BenchmarkTransaction_WithPool(b *testing.B) {
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ txn := NewTransaction()
+		transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ AddTransition(txn, transition)
+ txn.Commit()
+ txn.Release()
+ }
+}
+
+// BenchmarkTransaction_WithoutPool benchmarks transaction creation without pooling.
+func BenchmarkTransaction_WithoutPool(b *testing.B) {
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Simulate non-pooled allocation
+ txn := &Transaction{}
+		transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ AddTransition(txn, transition)
+ txn.Commit()
+ // No Release() - simulates non-pooled version
+ }
+}
+
+// BenchmarkTransition_WithPool benchmarks transition creation with pooling.
+func BenchmarkTransition_WithPool(b *testing.B) {
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+		transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ transition.Commit()
+ transition.Release()
+ }
+}
+
+// BenchmarkTransition_WithoutPool benchmarks transition creation without pooling.
+func BenchmarkTransition_WithoutPool(b *testing.B) {
+ manager := &mockManager{
+ snapshot: &mockSnapshot{id: 1},
+ }
+ manager.snapshot.IncRef()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Directly allocate transition without pool
+ var zero *mockSnapshot
+ transition := &Transition[*mockSnapshot]{
+ manager: manager,
+ current: zero,
+ next: zero,
+ committed: false,
+ }
+ current := manager.CurrentSnapshot()
+ next := &mockSnapshot{id: current.id + 1}
+ next.IncRef()
+ transition.current = current
+ transition.next = next
+ transition.Commit()
+ // No Release() - simulates non-pooled version
+ }
+}
+
+// TestTransition_ReleaseReleasesCurrentRef verifies that Release() decrements the ref
+// held in t.current after a committed transition (fixes leak where that ref was never released).
+func TestTransition_ReleaseReleasesCurrentRef(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ oldSnapshot := &mockSnapshot{id: 1}
+ oldSnapshot.IncRef()
+ manager := &mockManager{snapshot: oldSnapshot}
+
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ transition.Commit()
+ transition.Release()
+
+	// After Release(), the ref we held in t.current must have been released;
+	// ReplaceSnapshot already released the manager's ref, so oldSnapshot should be at 0.
+ if oldSnapshot.RefCount() != 0 {
+		t.Errorf("expected old snapshot refcount 0 after Commit+Release, got %d", oldSnapshot.RefCount())
+ }
+}
+
+// TestTransition_RefCountAfterCommit verifies reference counts are correct after commit.
+func TestTransition_RefCountAfterCommit(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ oldSnapshot := &mockSnapshot{id: 1}
+ oldSnapshot.IncRef()
+
+ manager := &mockManager{
+ snapshot: oldSnapshot,
+ }
+
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+
+ newSnapshot := transition.next
+
+ // Before commit:
+ // - oldSnapshot: refcount = 2 (initial +1, CurrentSnapshot +1)
+ // - newSnapshot: refcount = 1 (created with +1)
+ if oldSnapshot.RefCount() != 2 {
+		t.Errorf("expected old snapshot refcount 2 before commit, got %d", oldSnapshot.RefCount())
+	}
+	if newSnapshot.RefCount() != 1 {
+		t.Errorf("expected new snapshot refcount 1 before commit, got %d", newSnapshot.RefCount())
+ }
+
+ transition.Commit()
+
+ // After commit:
+ // - oldSnapshot: refcount = 1 (Commit decremented it once)
+ // - newSnapshot: refcount = 1 (manager now holds it, no change)
+ if oldSnapshot.RefCount() != 1 {
+		t.Errorf("expected old snapshot refcount 1 after commit, got %d", oldSnapshot.RefCount())
+	}
+	if newSnapshot.RefCount() != 1 {
+		t.Errorf("expected new snapshot refcount 1 after commit, got %d", newSnapshot.RefCount())
+ }
+
+ // Verify manager has the new snapshot
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot from manager")
+ }
+ defer current.DecRef()
+
+ if current.id != 2 {
+ t.Errorf("expected current snapshot id 2, got %d", current.id)
+ }
+ if current.RefCount() != 2 {
+		t.Errorf("expected current snapshot refcount 2 after CurrentSnapshot, got %d", current.RefCount())
+ }
+}
+
+// TestTransition_RefCountAfterRollback verifies reference counts are correct after rollback.
+func TestTransition_RefCountAfterRollback(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ oldSnapshot := &mockSnapshot{id: 1}
+ oldSnapshot.IncRef()
+
+ manager := &mockManager{
+ snapshot: oldSnapshot,
+ }
+
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+
+ newSnapshot := transition.next
+
+ // Before rollback:
+ // - oldSnapshot: refcount = 2 (initial +1, CurrentSnapshot +1)
+ // - newSnapshot: refcount = 1 (created with +1)
+ if oldSnapshot.RefCount() != 2 {
+		t.Errorf("expected old snapshot refcount 2 before rollback, got %d", oldSnapshot.RefCount())
+	}
+	if newSnapshot.RefCount() != 1 {
+		t.Errorf("expected new snapshot refcount 1 before rollback, got %d", newSnapshot.RefCount())
+ }
+
+ transition.Rollback()
+
+ // After rollback:
+	// - oldSnapshot: refcount = 1 (Rollback decremented the current reference)
+ // - newSnapshot: refcount = 0 (Rollback decremented the next reference)
+ if oldSnapshot.RefCount() != 1 {
+		t.Errorf("expected old snapshot refcount 1 after rollback, got %d", oldSnapshot.RefCount())
+	}
+	if newSnapshot.RefCount() != 0 {
+		t.Errorf("expected new snapshot refcount 0 after rollback, got %d", newSnapshot.RefCount())
+ }
+
+ // Verify manager still has the old snapshot
+ current := manager.CurrentSnapshot()
+ if current == nil {
+ t.Fatal("expected non-nil snapshot from manager")
+ }
+ defer current.DecRef()
+
+ if current.id != 1 {
+ t.Errorf("expected current snapshot id 1, got %d", current.id)
+ }
+}
+
+// TestTransaction_RefCountWithMultipleTransitions verifies reference counts
+// are correct when multiple transitions are committed together.
+func TestTransaction_RefCountWithMultipleTransitions(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ oldSnapshot1 := &mockSnapshot{id: 1}
+ oldSnapshot1.IncRef()
+ manager1 := &mockManager{snapshot: oldSnapshot1}
+
+ oldSnapshot2 := &mockSnapshot{id: 10}
+ oldSnapshot2.IncRef()
+ manager2 := &mockManager{snapshot: oldSnapshot2}
+
+ txn := NewTransaction()
+ defer txn.Release()
+
+	transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition1.Release()
+ AddTransition(txn, transition1)
+
+	transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition2.Release()
+ AddTransition(txn, transition2)
+
+ newSnapshot1 := transition1.next
+ newSnapshot2 := transition2.next
+
+ // Before commit:
+ // - oldSnapshot1: refcount = 2 (initial +1, CurrentSnapshot +1)
+ // - newSnapshot1: refcount = 1 (created with +1)
+ // - oldSnapshot2: refcount = 2 (initial +1, CurrentSnapshot +1)
+ // - newSnapshot2: refcount = 1 (created with +1)
+ if oldSnapshot1.RefCount() != 2 {
+		t.Errorf("expected old snapshot1 refcount 2 before commit, got %d", oldSnapshot1.RefCount())
+	}
+	if newSnapshot1.RefCount() != 1 {
+		t.Errorf("expected new snapshot1 refcount 1 before commit, got %d", newSnapshot1.RefCount())
+	}
+	if oldSnapshot2.RefCount() != 2 {
+		t.Errorf("expected old snapshot2 refcount 2 before commit, got %d", oldSnapshot2.RefCount())
+	}
+	if newSnapshot2.RefCount() != 1 {
+		t.Errorf("expected new snapshot2 refcount 1 before commit, got %d", newSnapshot2.RefCount())
+ }
+
+ txn.Commit()
+
+ // After commit:
+ // - oldSnapshot1: refcount = 1 (initial +1, Commit decremented once)
+ // - newSnapshot1: refcount = 1 (manager1 holds it)
+ // - oldSnapshot2: refcount = 1 (initial +1, Commit decremented once)
+ // - newSnapshot2: refcount = 1 (manager2 holds it)
+ if oldSnapshot1.RefCount() != 1 {
+		t.Errorf("expected old snapshot1 refcount 1 after commit, got %d", oldSnapshot1.RefCount())
+	}
+	if newSnapshot1.RefCount() != 1 {
+		t.Errorf("expected new snapshot1 refcount 1 after commit, got %d", newSnapshot1.RefCount())
+	}
+	if oldSnapshot2.RefCount() != 1 {
+		t.Errorf("expected old snapshot2 refcount 1 after commit, got %d", oldSnapshot2.RefCount())
+	}
+	if newSnapshot2.RefCount() != 1 {
+		t.Errorf("expected new snapshot2 refcount 1 after commit, got %d", newSnapshot2.RefCount())
+ }
+}
+
+// TestTransaction_RefCountAfterRollback verifies reference counts are correct
+// after rolling back multiple transitions.
+func TestTransaction_RefCountAfterRollback(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ oldSnapshot1 := &mockSnapshot{id: 1}
+ oldSnapshot1.IncRef()
+ manager1 := &mockManager{snapshot: oldSnapshot1}
+
+ oldSnapshot2 := &mockSnapshot{id: 10}
+ oldSnapshot2.IncRef()
+ manager2 := &mockManager{snapshot: oldSnapshot2}
+
+ txn := NewTransaction()
+ defer txn.Release()
+
+	transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition1.Release()
+ AddTransition(txn, transition1)
+
+	transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot {
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition2.Release()
+ AddTransition(txn, transition2)
+
+ newSnapshot1 := transition1.next
+ newSnapshot2 := transition2.next
+
+ txn.Rollback()
+
+ // After rollback:
+	// - oldSnapshot1: refcount = 1 (initial +1, rollback decremented current ref)
+	// - newSnapshot1: refcount = 0 (rollback decremented next ref)
+	// - oldSnapshot2: refcount = 1 (initial +1, rollback decremented current ref)
+ // - newSnapshot2: refcount = 0 (rollback decremented next ref)
+ if oldSnapshot1.RefCount() != 1 {
+		t.Errorf("expected old snapshot1 refcount 1 after rollback, got %d", oldSnapshot1.RefCount())
+	}
+	if newSnapshot1.RefCount() != 0 {
+		t.Errorf("expected new snapshot1 refcount 0 after rollback, got %d", newSnapshot1.RefCount())
+	}
+	if oldSnapshot2.RefCount() != 1 {
+		t.Errorf("expected old snapshot2 refcount 1 after rollback, got %d", oldSnapshot2.RefCount())
+	}
+	if newSnapshot2.RefCount() != 0 {
+		t.Errorf("expected new snapshot2 refcount 0 after rollback, got %d", newSnapshot2.RefCount())
+ }
+}
+
+// TestTransition_NoDoubleDecrement verifies that the fix prevents the double-decrement bug.
+// This is a regression test for the bug where ReplaceSnapshot and Commit both
+// decremented the same snapshot reference.
+func TestTransition_NoDoubleDecrement(t *testing.T) {
+ before := getSnapshotPoolRefCounts()
+ defer func() { assertSnapshotPoolsNoLeak(t, before) }()
+
+ oldSnapshot := &mockSnapshot{id: 1}
+ oldSnapshot.IncRef()
+
+ manager := &mockManager{
+ snapshot: oldSnapshot,
+ }
+
+ // Record initial refcount
+ initialRefCount := oldSnapshot.RefCount()
+ if initialRefCount != 1 {
+ t.Fatalf("expected initial refcount 1, got %d", initialRefCount)
+ }
+
+	transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot {
+		// Verify cur is the same object as oldSnapshot
+		if cur != oldSnapshot {
+			t.Error("current snapshot should be the same object as old snapshot")
+		}
+		// Verify refcount was incremented by CurrentSnapshot
+		if cur.RefCount() != 2 {
+			t.Errorf("expected cur refcount 2 (after CurrentSnapshot), got %d", cur.RefCount())
+ }
+
+ next := &mockSnapshot{id: cur.id + 1}
+ next.IncRef()
+ return next
+ })
+ defer transition.Release()
+
+ transition.Commit()
+
+ // After commit, the old snapshot should have refcount 1
+ // (not 0 or negative which would indicate double-decrement)
+ finalRefCount := oldSnapshot.RefCount()
+ if finalRefCount != 1 {
+		t.Errorf("expected final refcount 1 (no double-decrement), got %d", finalRefCount)
+ }
+
+	// Verify the old snapshot is still valid and wasn't prematurely cleaned up
+	// In a real scenario, negative refcount would cause a crash or corruption
+	if finalRefCount < 1 {
+		t.Error("double-decrement bug detected: refcount went below expected value")
+ }
+}
diff --git a/banyand/internal/storage/segment_test.go b/banyand/internal/storage/segment_test.go
index b7cd892d5..cb331568c 100644
--- a/banyand/internal/storage/segment_test.go
+++ b/banyand/internal/storage/segment_test.go
@@ -618,11 +618,11 @@ func TestDeleteExpiredSegmentsWithClosedSegments(t *testing.T) {
assert.NotNil(t, segments[5].index, "Segment 5 should remain open")
// Now delete expired segments
- // Get the time range for segments 0, 1, and 2 (the expired ones)
+	// Use the same segment dates (UTC) used when creating segments 0, 1, 2 to avoid timezone mismatch with time.Now()
deletedCount := sc.deleteExpiredSegments([]string{
- time.Now().AddDate(0, 0, -6).Format(dayFormat),
- time.Now().AddDate(0, 0, -5).Format(dayFormat),
- time.Now().AddDate(0, 0, -4).Format(dayFormat),
+ segmentDates[0].Format(dayFormat),
+ segmentDates[1].Format(dayFormat),
+ segmentDates[2].Format(dayFormat),
})
assert.Equal(t, int64(3), deletedCount, "Should have deleted 3 expired segments")
diff --git a/banyand/trace/flusher.go b/banyand/trace/flusher.go
index 51d82dc5e..8ebe7baf0 100644
--- a/banyand/trace/flusher.go
+++ b/banyand/trace/flusher.go
@@ -236,6 +236,19 @@ func (tst *tsTable) flush(snapshot *snapshot, flushCh chan *flusherIntroduction)
sidxFlusherIntroduced, err := sidxInstance.Flush(partIDMap)
if err != nil {
tst.l.Warn().Err(err).Str("sidx", name).Msg("sidx flush failed")
+ for _, sidxIntro := range ind.sidxFlusherIntroduced {
+ sidxIntro.ReleaseFlushedParts()
+ sidxIntro.Release()
+ }
+ for partID := range partIDMap {
+				for sidxName := range ind.sidxFlusherIntroduced {
+					tst.removeSidxPartOnFailure(sidxName, partID)
+ }
+ tst.removeSidxPartOnFailure(name, partID)
+ }
+ for _, pw := range ind.flushed {
+ tst.removeTracePartOnFailure(pw)
+ }
return
}
ind.sidxFlusherIntroduced[name] = sidxFlusherIntroduced
diff --git a/banyand/trace/introducer.go b/banyand/trace/introducer.go
index 1b7e3e9a2..e33bfb52e 100644
--- a/banyand/trace/introducer.go
+++ b/banyand/trace/introducer.go
@@ -19,6 +19,7 @@ package trace
import (
"github.com/apache/skywalking-banyandb/banyand/internal/sidx"
+	snapshotpkg "github.com/apache/skywalking-banyandb/banyand/internal/snapshot"
"github.com/apache/skywalking-banyandb/pkg/pool"
"github.com/apache/skywalking-banyandb/pkg/watcher"
)
@@ -223,61 +224,141 @@ func (tst *tsTable) introducerLoopWithSync(flushCh chan *flusherIntroduction, me
}
func (tst *tsTable) introduceMemPart(nextIntroduction *introduction, epoch uint64) {
- cur := tst.currentSnapshot()
- if cur != nil {
- defer cur.decRef()
- } else {
- cur = new(snapshot)
- }
+ // Create generic transaction
+ txn := snapshotpkg.NewTransaction()
+ defer txn.Release()
+ // Prepare trace snapshot transition
next := nextIntroduction.memPart
tst.addPendingDataCount(-int64(next.mp.partMetadata.TotalCount))
- nextSnp := cur.copyAllTo(epoch)
- nextSnp.parts = append(nextSnp.parts, next)
- nextSnp.creator = snapshotCreatorMemPart
- tst.replaceSnapshot(&nextSnp)
+ partID := next.p.partMetadata.ID
+
+	traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot {
+ if cur == nil {
+ cur = new(snapshot)
+ }
+ nextSnp := cur.copyAllTo(epoch)
+ nextSnp.parts = append(nextSnp.parts, next)
+ nextSnp.creator = snapshotCreatorMemPart
+ return &nextSnp
+ })
+ defer traceTransition.Release()
+ snapshotpkg.AddTransition(txn, traceTransition)
+
+ // Prepare sidx snapshot transitions
+ var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot]
for name, memPart := range nextIntroduction.sidxReqsMap {
-		tst.mustGetOrCreateSidx(name).IntroduceMemPart(next.p.partMetadata.ID, memPart)
+		sidxInstance := tst.mustGetOrCreateSidx(name)
+		prepareFunc := sidxInstance.PrepareMemPart(partID, memPart)
+		sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc)
+ sidxTransitions = append(sidxTransitions, sidxTransition)
+ snapshotpkg.AddTransition(txn, sidxTransition)
}
+ defer func() {
+ for _, t := range sidxTransitions {
+ t.Release()
+ }
+ }()
+
+ // Commit all atomically under single transaction lock
+ txn.Commit()
+
if nextIntroduction.applied != nil {
close(nextIntroduction.applied)
}
}
func (tst *tsTable) introduceFlushed(nextIntroduction *flusherIntroduction, epoch uint64) {
- cur := tst.currentSnapshot()
- if cur == nil {
- tst.l.Panic().Msg("current snapshot is nil")
- }
- defer cur.decRef()
- nextSnp := cur.merge(epoch, nextIntroduction.flushed)
- nextSnp.creator = snapshotCreatorFlusher
- tst.replaceSnapshot(&nextSnp)
- tst.persistSnapshot(&nextSnp)
+ // Create generic transaction
+ txn := snapshotpkg.NewTransaction()
+ defer txn.Release()
+
+ // Prepare trace snapshot transition
+	traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot {
+ if cur == nil {
+ tst.l.Panic().Msg("current snapshot is nil")
+ }
+ nextSnp := cur.merge(epoch, nextIntroduction.flushed)
+ nextSnp.creator = snapshotCreatorFlusher
+ return &nextSnp
+ })
+ defer traceTransition.Release()
+ snapshotpkg.AddTransition(txn, traceTransition)
+
+ // Prepare sidx snapshot transitions
+ var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot]
for name, sidxFlusherIntroduced := range nextIntroduction.sidxFlusherIntroduced {
- tst.mustGetSidx(name).IntroduceFlushed(sidxFlusherIntroduced)
+ sidxInstance := tst.mustGetSidx(name)
+		prepareFunc := sidxInstance.PrepareFlushed(sidxFlusherIntroduced)
+		sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc)
+ sidxTransitions = append(sidxTransitions, sidxTransition)
+ snapshotpkg.AddTransition(txn, sidxTransition)
+ }
+ defer func() {
+ for _, t := range sidxTransitions {
+ t.Release()
+ }
+ }()
+
+ // Commit all atomically under single transaction lock
+ txn.Commit()
+
+ // Persist snapshot after commit
+ cur := tst.currentSnapshot()
+ if cur != nil {
+ defer cur.decRef()
+ tst.persistSnapshot(cur)
}
+
if nextIntroduction.applied != nil {
close(nextIntroduction.applied)
}
}
// introduceFlushedForSync introduces the flushed trace parts for syncing.
-// The SIDX parts are flushed before the trace parts so the syncer can always find
-// the corresponding index on disk once a flushed trace part becomes visible.
+// The snapshots are updated atomically so the syncer can always find
+// the corresponding index once a flushed trace part becomes visible.
func (tst *tsTable) introduceFlushedForSync(nextIntroduction *flusherIntroduction, epoch uint64) {
+ // Create generic transaction
+ txn := snapshotpkg.NewTransaction()
+ defer txn.Release()
+
+ // Prepare sidx snapshot transitions first for visibility ordering
+ var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot]
for name, sidxFlusherIntroduced := range nextIntroduction.sidxFlusherIntroduced {
- tst.mustGetSidx(name).IntroduceFlushed(sidxFlusherIntroduced)
+ sidxInstance := tst.mustGetSidx(name)
+ prepareFunc :=
sidxInstance.PrepareFlushed(sidxFlusherIntroduced)
+ sidxTransition := snapshotpkg.NewTransition(sidxInstance,
prepareFunc)
+ sidxTransitions = append(sidxTransitions, sidxTransition)
+ snapshotpkg.AddTransition(txn, sidxTransition)
}
+ defer func() {
+ for _, t := range sidxTransitions {
+ t.Release()
+ }
+ }()
+
+ // Prepare trace snapshot transition
+ traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot {
+ if cur == nil {
+ tst.l.Panic().Msg("current snapshot is nil")
+ }
+ nextSnp := cur.merge(epoch, nextIntroduction.flushed)
+ nextSnp.creator = snapshotCreatorFlusher
+ return &nextSnp
+ })
+ defer traceTransition.Release()
+ snapshotpkg.AddTransition(txn, traceTransition)
+
+ // Commit all atomically under single transaction lock
+ txn.Commit()
+
+ // Persist snapshot after commit
cur := tst.currentSnapshot()
- if cur == nil {
- tst.l.Panic().Msg("current snapshot is nil")
+ if cur != nil {
+ defer cur.decRef()
+ tst.persistSnapshot(cur)
}
- defer cur.decRef()
- nextSnp := cur.merge(epoch, nextIntroduction.flushed)
- nextSnp.creator = snapshotCreatorFlusher
- tst.replaceSnapshot(&nextSnp)
- tst.persistSnapshot(&nextSnp)
if nextIntroduction.applied != nil {
close(nextIntroduction.applied)
@@ -285,23 +366,48 @@ func (tst *tsTable) introduceFlushedForSync(nextIntroduction *flusherIntroductio
}
func (tst *tsTable) introduceMerged(nextIntroduction *mergerIntroduction, epoch uint64) {
- cur := tst.currentSnapshot()
- if cur == nil {
- tst.l.Panic().Msg("current snapshot is nil")
- return
- }
- defer cur.decRef()
- nextSnp := cur.remove(epoch, nextIntroduction.merged)
- nextSnp.parts = append(nextSnp.parts, nextIntroduction.newPart)
- nextSnp.creator = nextIntroduction.creator
- tst.replaceSnapshot(&nextSnp)
- tst.persistSnapshot(&nextSnp)
+ // Create generic transaction
+ txn := snapshotpkg.NewTransaction()
+ defer txn.Release()
+
+ // Prepare trace snapshot transition
+ traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot {
+ if cur == nil {
+ tst.l.Panic().Msg("current snapshot is nil")
+ }
+ nextSnp := cur.remove(epoch, nextIntroduction.merged)
+ nextSnp.parts = append(nextSnp.parts, nextIntroduction.newPart)
+ nextSnp.creator = nextIntroduction.creator
+ return &nextSnp
+ })
+ defer traceTransition.Release()
+ snapshotpkg.AddTransition(txn, traceTransition)
+
+ // Prepare sidx snapshot transitions
+ var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot]
for name, sidxMergerIntroduced := range nextIntroduction.sidxMergerIntroduced {
- deferFuncs := tst.mustGetSidx(name).IntroduceMerged(sidxMergerIntroduced)
- if deferFuncs != nil {
- defer deferFuncs()
+ sidxInstance := tst.mustGetSidx(name)
+ prepareFunc := sidxInstance.PrepareMerged(sidxMergerIntroduced)
+ sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc)
+ sidxTransitions = append(sidxTransitions, sidxTransition)
+ snapshotpkg.AddTransition(txn, sidxTransition)
+ }
+ defer func() {
+ for _, t := range sidxTransitions {
+ t.Release()
}
+ }()
+
+ // Commit all atomically under single transaction lock
+ txn.Commit()
+
+ // Persist snapshot after commit
+ cur := tst.currentSnapshot()
+ if cur != nil {
+ defer cur.decRef()
+ tst.persistSnapshot(cur)
}
+
if nextIntroduction.applied != nil {
close(nextIntroduction.applied)
}
@@ -309,33 +415,73 @@ func (tst *tsTable) introduceMerged(nextIntroduction *mergerIntroduction, epoch
func (tst *tsTable) introduceSync(nextIntroduction *syncIntroduction, epoch uint64) {
synced := nextIntroduction.synced
+
+ // Create generic transaction
+ txn := snapshotpkg.NewTransaction()
+ defer txn.Release()
+
+ // Prepare sidx snapshot transitions
+ var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot]
allSidx := tst.getAllSidx()
- for _, sidx := range allSidx {
- deferFuncs := sidx.IntroduceSynced(synced)
- if deferFuncs != nil {
- defer deferFuncs()
- }
+ for _, sidxInstance := range allSidx {
+ prepareFunc := sidxInstance.PrepareSynced(synced)
+ sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc)
+ sidxTransitions = append(sidxTransitions, sidxTransition)
+ snapshotpkg.AddTransition(txn, sidxTransition)
}
+ defer func() {
+ for _, t := range sidxTransitions {
+ t.Release()
+ }
+ }()
+
+ // Prepare trace snapshot transition
+ traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot {
+ if cur == nil {
+ tst.l.Panic().Msg("current snapshot is nil")
+ }
+ nextSnp := cur.remove(epoch, synced)
+ nextSnp.creator = snapshotCreatorSyncer
+ return &nextSnp
+ })
+ defer traceTransition.Release()
+ snapshotpkg.AddTransition(txn, traceTransition)
+
+ // Commit all atomically under single transaction lock
+ txn.Commit()
+
+ // Persist snapshot after commit
cur := tst.currentSnapshot()
- if cur == nil {
- tst.l.Panic().Msg("current snapshot is nil")
- return
+ if cur != nil {
+ defer cur.decRef()
+ tst.persistSnapshot(cur)
}
- defer cur.decRef()
- nextSnp := cur.remove(epoch, synced)
- nextSnp.creator = snapshotCreatorSyncer
- tst.replaceSnapshot(&nextSnp)
- tst.persistSnapshot(&nextSnp)
+
if nextIntroduction.applied != nil {
close(nextIntroduction.applied)
}
}
-func (tst *tsTable) replaceSnapshot(next *snapshot) {
+// CurrentSnapshot returns the current snapshot with incremented reference count.
+// Implements snapshot.Manager[*snapshot] interface.
+func (tst *tsTable) CurrentSnapshot() *snapshot {
+ tst.RLock()
+ defer tst.RUnlock()
+ if tst.snapshot == nil {
+ return nil
+ }
+ tst.snapshot.IncRef()
+ return tst.snapshot
+}
+
+// ReplaceSnapshot atomically replaces the current snapshot with next.
+// Implements snapshot.Manager[*snapshot] interface.
+// The old snapshot's DecRef is called automatically per the Manager contract.
+func (tst *tsTable) ReplaceSnapshot(next *snapshot) {
tst.Lock()
defer tst.Unlock()
if tst.snapshot != nil {
- tst.snapshot.decRef()
+ tst.snapshot.DecRef()
}
tst.snapshot = next
}
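
Every introducer above now follows the same two-phase shape: prepare one transition per snapshot manager, register each on a shared transaction, then commit once so the trace and sidx snapshots flip together. The sketch below illustrates that shape with a self-contained generic API; the Transaction/Transition/Manager names mirror the diff, but the bodies are simplified assumptions (reference counting and the Release hooks are elided), not the actual banyand/internal/snapshot implementation.

package snapshotsketch

import "sync"

// Manager is the minimal contract a component implements to take part in a
// coordinated snapshot swap (tsTable and each sidx instance play this role).
type Manager[S any] interface {
	CurrentSnapshot() S     // hand out the live snapshot
	ReplaceSnapshot(next S) // install next, releasing the old snapshot
}

// Transition pairs a manager with a pure function deriving the next snapshot
// from the current one; nothing is applied until the transaction commits.
type Transition[S any] struct {
	mgr     Manager[S]
	prepare func(cur S) S
}

func NewTransition[S any](mgr Manager[S], prepare func(cur S) S) *Transition[S] {
	return &Transition[S]{mgr: mgr, prepare: prepare}
}

// Transaction collects deferred swaps and applies them under one lock so
// readers never observe a half-updated set of snapshots.
type Transaction struct {
	mu    sync.Mutex
	steps []func()
}

func NewTransaction() *Transaction { return &Transaction{} }

// AddTransition queues the swap; it is a package-level function because Go
// methods cannot introduce their own type parameters.
func AddTransition[S any](txn *Transaction, t *Transition[S]) {
	txn.steps = append(txn.steps, func() {
		t.mgr.ReplaceSnapshot(t.prepare(t.mgr.CurrentSnapshot()))
	})
}

// Commit applies every queued swap atomically with respect to this lock.
func (txn *Transaction) Commit() {
	txn.mu.Lock()
	defer txn.mu.Unlock()
	for _, step := range txn.steps {
		step()
	}
}

The package-level AddTransition also explains why the diff calls snapshotpkg.AddTransition(txn, ...) rather than a method on the transaction: each transition carries its own snapshot type.
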
diff --git a/banyand/trace/merger.go b/banyand/trace/merger.go
index 55a9d2c81..875f98fb8 100644
--- a/banyand/trace/merger.go
+++ b/banyand/trace/merger.go
@@ -20,6 +20,7 @@ package trace
import (
"errors"
"fmt"
+ "path/filepath"
"sync/atomic"
"time"
@@ -164,10 +165,20 @@ func (tst *tsTable) mergePartsThenSendIntroduction(creator snapshotCreator, part
mergerIntroductionMap := make(map[string]*sidx.MergerIntroduction)
for sidxName, sidxInstance := range tst.getAllSidx() {
start = time.Now()
- mergerIntroduction, err := sidxInstance.Merge(closeCh, partIDMap, newPartID)
- if err != nil {
- tst.l.Warn().Err(err).Msg("sidx merge mem parts failed")
- return nil, err
+ mergerIntroduction, mergeErr := sidxInstance.Merge(closeCh, partIDMap, newPartID)
+ if mergeErr != nil {
+ tst.l.Warn().Err(mergeErr).Msg("sidx merge mem parts failed")
+ tst.removeSidxPartOnFailure(sidxName, newPartID)
+ tst.removeTracePartOnFailure(newPart)
+ for doneSidxName, intro := range mergerIntroductionMap {
+ intro.ReleaseNewPart()
+ tst.removeSidxPartOnFailure(doneSidxName, newPartID)
+ intro.Release()
+ }
+ return nil, mergeErr
+ }
+ if mergerIntroduction == nil {
+ continue
}
mergerIntroductionMap[sidxName] = mergerIntroduction
elapsed = time.Since(start)
@@ -270,6 +281,29 @@ func (tst *tsTable) reserveSpace(parts []*partWrapper) uint64 {
var errNoPartToMerge = fmt.Errorf("no part to merge")
+// removeTracePartOnFailure closes the part and removes its directory from disk.
+// Used when a merge fails after the trace part was created so the directory is not left as trash.
+func (tst *tsTable) removeTracePartOnFailure(pw *partWrapper) {
+ if pw == nil {
+ return
+ }
+ pathToRemove := pw.p.path
+ pw.decRef()
+ tst.fileSystem.MustRMAll(pathToRemove)
+}
+
+// sidxPartPath returns the on-disk path for a sidx part (same layout as sidx package).
+func sidxPartPath(traceRoot, sidxName string, partID uint64) string {
+ return filepath.Join(traceRoot, sidxDirName, sidxName, fmt.Sprintf("%016x", partID))
+}
+
+// removeSidxPartOnFailure removes a sidx part directory from disk.
+// Used when a merge fails after one or more sidx parts were created.
+func (tst *tsTable) removeSidxPartOnFailure(sidxName string, partID uint64) {
+ pathToRemove := sidxPartPath(tst.root, sidxName, partID)
+ tst.fileSystem.MustRMAll(pathToRemove)
+}
+
func (tst *tsTable) mergeParts(fileSystem fs.FileSystem, closeCh <-chan struct{}, parts []*partWrapper, partID uint64, root string) (*partWrapper, error) {
if len(parts) == 0 {
return nil, errNoPartToMerge
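
The cleanup helpers above implement a simple discipline: once a multi-part merge starts producing directories, any failure must delete everything produced so far, otherwise each failed merge leaks a part directory and eventually exhausts the disk (the bug this commit fixes). Below is a minimal, self-contained sketch of that discipline, using os.RemoveAll in place of the repository's fs.FileSystem.MustRMAll; buildParts and its builders are hypothetical names, not code from the diff.

package main

import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
)

// buildParts creates one directory per named builder under root. If any step
// fails, every directory created so far is removed before the error is
// returned, so a failed merge never leaves orphaned part directories behind.
// The %016x directory name follows the same layout as sidxPartPath above.
func buildParts(root string, partID uint64, builders map[string]func(dir string) error) error {
	var created []string
	cleanup := func() {
		for _, dir := range created {
			_ = os.RemoveAll(dir) // best effort, akin to MustRMAll in the diff
		}
	}
	for name, build := range builders {
		dir := filepath.Join(root, name, fmt.Sprintf("%016x", partID))
		if err := os.MkdirAll(dir, 0o755); err != nil {
			cleanup()
			return fmt.Errorf("create %s: %w", name, err)
		}
		created = append(created, dir)
		if err := build(dir); err != nil {
			cleanup()
			return fmt.Errorf("build %s: %w", name, err)
		}
	}
	return nil
}

func main() {
	err := buildParts(os.TempDir(), 1, map[string]func(string) error{
		"trace": func(string) error { return nil },
		"idx1":  func(string) error { return errors.New("sidx merge failed") },
	})
	fmt.Println(err) // any directory already created is gone again
}

Only directories recorded in created are removed on the error path, so nothing that was never built is touched; map iteration order does not affect correctness.
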
diff --git a/banyand/trace/merger_test.go b/banyand/trace/merger_test.go
index a01f2c470..f162b400e 100644
--- a/banyand/trace/merger_test.go
+++ b/banyand/trace/merger_test.go
@@ -19,6 +19,8 @@ package trace
import (
"errors"
+ "fmt"
+ "path/filepath"
"reflect"
"testing"
@@ -26,11 +28,14 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require"
+ "github.com/apache/skywalking-banyandb/banyand/internal/sidx"
"github.com/apache/skywalking-banyandb/banyand/protector"
"github.com/apache/skywalking-banyandb/pkg/convert"
"github.com/apache/skywalking-banyandb/pkg/encoding"
"github.com/apache/skywalking-banyandb/pkg/fs"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
+ "github.com/apache/skywalking-banyandb/pkg/run"
"github.com/apache/skywalking-banyandb/pkg/test"
)
@@ -1329,3 +1334,125 @@ func Test_mergeParts(t *testing.T) {
})
}
}
+
+// sidxMergeErrFake is a SIDX that returns an error from Merge to test cleanup on failure.
+type sidxMergeErrFake struct {
+ fakeSIDX
+}
+
+func (f *sidxMergeErrFake) Merge(<-chan struct{}, map[uint64]struct{}, uint64) (*sidx.MergerIntroduction, error) {
+ return nil, errors.New("sidx merge failed")
+}
+
+// Test_mergePartsThenSendIntroduction_cleansUpOnSidxMergeError verifies that when sidx.Merge
+// returns an error, the newly created trace part and any sidx parts are removed from disk.
+func Test_mergePartsThenSendIntroduction_cleansUpOnSidxMergeError(t *testing.T) {
+ tmpPath, defFn := test.Space(require.New(t))
+ defer defFn()
+
+ fileSystem := fs.NewLocalFileSystem()
+ // Create two file parts with IDs 10 and 11 so the merged part gets newPartID 1 (curPartID starts at 0).
+ var parts []*partWrapper
+ for i, ts := range []*traces{tsTS1, tsTS2} {
+ partID := uint64(10 + i)
+ mp := generateMemPart()
+ mp.mustInitFromTraces(ts)
+ mp.mustFlush(fileSystem, partPath(tmpPath, partID))
+ p := mustOpenFilePart(partID, tmpPath, fileSystem)
+ p.partMetadata.ID = partID
+ parts = append(parts, newPartWrapper(nil, p))
+ releaseMemPart(mp)
+ }
+ defer func() {
+ for _, pw := range parts {
+ pw.decRef()
+ }
+ }()
+
+ closer := run.NewCloser(1)
+ defer closer.Done()
+ l := logger.GetLogger("trace-merger-test")
+ tst := &tsTable{
+ pm: protector.Nop{},
+ fileSystem: fileSystem,
+ root: tmpPath,
+ loopCloser: closer,
+ l: l,
+ curPartID: 0,
+ sidxMap: map[string]sidx.SIDX{
+ "idx1": &sidxMergeErrFake{fakeSIDX{}},
+ },
+ }
+
+ merged := make(map[uint64]struct{})
+ for _, pw := range parts {
+ merged[pw.ID()] = struct{}{}
+ }
+ merges := make(chan *mergerIntroduction, 1)
+ closeCh := make(chan struct{})
+
+ _, err := tst.mergePartsThenSendIntroduction(snapshotCreatorMerger, parts, merged, merges, closeCh, mergeTypeFile)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "sidx merge failed")
+
+ // New part ID is 1 (curPartID was 0, then AddUint64(..., 1)).
+ newPartID := uint64(1)
+ tracePartPath := partPath(tmpPath, newPartID)
+ require.False(t, fileSystem.IsExist(tracePartPath), "trace part directory should be removed on sidx merge failure")
+
+ sidxPartPath := filepath.Join(tmpPath, sidxDirName, "idx1", fmt.Sprintf("%016x", newPartID))
+ require.False(t, fileSystem.IsExist(sidxPartPath), "sidx part directory should not exist (no sidx part created before first Merge)")
+}
+
+// sidxFlushErrFake is a SIDX that returns an error from Flush to test cleanup on failure.
+type sidxFlushErrFake struct {
+ fakeSIDX
+}
+
+func (f *sidxFlushErrFake) Flush(map[uint64]struct{}) (*sidx.FlusherIntroduction, error) {
+ return nil, errors.New("sidx flush failed")
+}
+
+// Test_flush_cleansUpOnSidxFlushError verifies that when any sidx.Flush returns an error,
+// the newly flushed trace parts and any sidx parts are removed from disk.
+func Test_flush_cleansUpOnSidxFlushError(t *testing.T) {
+ tmpPath, defFn := test.Space(require.New(t))
+ defer defFn()
+
+ fileSystem := fs.NewLocalFileSystem()
+ partID := uint64(100)
+ mp := generateMemPart()
+ mp.mustInitFromTraces(tsTS1)
+ mp.partMetadata.ID = partID
+ pw := newPartWrapper(mp, openMemPart(mp))
+ pw.p.partMetadata.ID = partID
+ pw.incRef()
+ snp := &snapshot{
+ parts: []*partWrapper{pw},
+ epoch: 1,
+ ref: 1,
+ }
+ defer snp.decRef()
+
+ closer := run.NewCloser(1)
+ defer closer.Done()
+ l := logger.GetLogger("trace-flusher-test")
+ tst := &tsTable{
+ pm: protector.Nop{},
+ fileSystem: fileSystem,
+ root: tmpPath,
+ loopCloser: closer,
+ l: l,
+ sidxMap: map[string]sidx.SIDX{
+ "idx1": &sidxFlushErrFake{fakeSIDX{}},
+ },
+ }
+
+ flushCh := make(chan *flusherIntroduction, 1)
+ tst.flush(snp, flushCh)
+
+ tracePartPath := partPath(tmpPath, partID)
+ require.False(t, fileSystem.IsExist(tracePartPath), "trace part directory should be removed on sidx flush failure")
+ sidxPartPath := filepath.Join(tmpPath, sidxDirName, "idx1", fmt.Sprintf("%016x", partID))
+ require.False(t, fileSystem.IsExist(sidxPartPath), "sidx part directory should be removed on sidx flush failure")
+}
diff --git a/banyand/trace/snapshot.go b/banyand/trace/snapshot.go
index 479cc2146..431587b07 100644
--- a/banyand/trace/snapshot.go
+++ b/banyand/trace/snapshot.go
@@ -88,11 +88,15 @@ func (s *snapshot) getParts(dst []*part, minTimestamp int64, maxTimestamp int64,
return dst, count
}
-func (s *snapshot) incRef() {
+// IncRef increments the snapshot reference count.
+// Implements the snapshot.Snapshot interface.
+func (s *snapshot) IncRef() {
atomic.AddInt32(&s.ref, 1)
}
-func (s *snapshot) decRef() {
+// DecRef decrements the snapshot reference count.
+// Implements the snapshot.Snapshot interface.
+func (s *snapshot) DecRef() {
n := atomic.AddInt32(&s.ref, -1)
if n > 0 {
return
@@ -103,6 +107,16 @@ func (s *snapshot) decRef() {
s.parts = s.parts[:0]
}
+// incRef is an internal helper for backward compatibility.
+func (s *snapshot) incRef() {
+ s.IncRef()
+}
+
+// decRef is an internal helper for backward compatibility.
+func (s *snapshot) decRef() {
+ s.DecRef()
+}
+
func (s *snapshot) copyAllTo(nextEpoch uint64) snapshot {
var result snapshot
result.epoch = nextEpoch
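
The exported IncRef/DecRef pair is what makes the Manager contract safe: ReplaceSnapshot drops the table's reference to the old snapshot, but a reader that obtained a reference via CurrentSnapshot keeps the parts alive until it calls DecRef. A runnable toy of that handoff; countedSnap and table are illustrative stand-ins, not the tsTable types.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// countedSnap mimics the reference-counted snapshot: its resources are
// released exactly once, when the count drops to zero.
type countedSnap struct {
	ref   int32
	epoch uint64
}

func (s *countedSnap) IncRef() { atomic.AddInt32(&s.ref, 1) }
func (s *countedSnap) DecRef() {
	if atomic.AddInt32(&s.ref, -1) == 0 {
		fmt.Printf("epoch %d released\n", s.epoch) // free parts here
	}
}

// table mirrors the CurrentSnapshot/ReplaceSnapshot bodies shown in the diff:
// hand out a referenced snapshot under RLock; swap and DecRef under Lock.
type table struct {
	sync.RWMutex
	snap *countedSnap
}

func (t *table) CurrentSnapshot() *countedSnap {
	t.RLock()
	defer t.RUnlock()
	if t.snap == nil {
		return nil
	}
	t.snap.IncRef()
	return t.snap
}

func (t *table) ReplaceSnapshot(next *countedSnap) {
	t.Lock()
	defer t.Unlock()
	if t.snap != nil {
		t.snap.DecRef()
	}
	t.snap = next
}

func main() {
	t := &table{snap: &countedSnap{ref: 1, epoch: 1}}
	cur := t.CurrentSnapshot()                      // reader holds epoch 1 alive
	t.ReplaceSnapshot(&countedSnap{ref: 1, epoch: 2}) // table drops its reference
	cur.DecRef()                                    // last reference: epoch 1 released only now
}

The old snapshot's parts are freed exactly once, on the last DecRef, regardless of whether the table or a reader releases last.
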
diff --git a/banyand/trace/streaming_pipeline_test.go b/banyand/trace/streaming_pipeline_test.go
index 339208d2e..863c32206 100644
--- a/banyand/trace/streaming_pipeline_test.go
+++ b/banyand/trace/streaming_pipeline_test.go
@@ -86,6 +86,24 @@ func (f *fakeSIDX) ScanQuery(context.Context, sidx.ScanQueryRequest) ([]*sidx.Qu
return nil, nil
}
+func (f *fakeSIDX) PrepareMemPart(uint64, *sidx.MemPart) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+
+func (f *fakeSIDX) PrepareFlushed(*sidx.FlusherIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+
+func (f *fakeSIDX) PrepareMerged(*sidx.MergerIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+
+func (f *fakeSIDX) PrepareSynced(map[uint64]struct{}) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+func (f *fakeSIDX) CurrentSnapshot() *sidx.Snapshot { return nil }
+func (f *fakeSIDX) ReplaceSnapshot(*sidx.Snapshot) {}
+
type fakeSIDXWithErr struct {
*fakeSIDX
err error
@@ -662,6 +680,23 @@ func (f *fakeSIDXInfinite) PartPaths(map[uint64]struct{}) map[uint64]string {
}
func (f *fakeSIDXInfinite) IntroduceSynced(map[uint64]struct{}) func() { return func() {} }
func (f *fakeSIDXInfinite) TakeFileSnapshot(_ string) error { return nil }
+func (f *fakeSIDXInfinite) PrepareMemPart(uint64, *sidx.MemPart) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+
+func (f *fakeSIDXInfinite) PrepareFlushed(*sidx.FlusherIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+
+func (f *fakeSIDXInfinite) PrepareMerged(*sidx.MergerIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+
+func (f *fakeSIDXInfinite) PrepareSynced(map[uint64]struct{}) func(cur *sidx.Snapshot) *sidx.Snapshot {
+ return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur }
+}
+func (f *fakeSIDXInfinite) CurrentSnapshot() *sidx.Snapshot { return nil }
+func (f *fakeSIDXInfinite) ReplaceSnapshot(*sidx.Snapshot) {}
// TestStreamSIDXTraceBatches_InfiniteChannelContinuesUntilCanceled verifies that
// the streaming pipeline continues streaming from an infinite channel until context is canceled.