This is an automated email from the ASF dual-hosted git repository. hanahmily pushed a commit to branch feature/snapshot-refactor in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
commit bce893dd6614bb3938b3a67154c3a032859ab9a0 Author: Hongtao Gao <[email protected]> AuthorDate: Fri Feb 6 17:06:00 2026 +0800 feat(snapshot): introduce generic snapshot coordination for atomic transitions - Added a new package for managing atomic snapshot transitions across trace and sidx components. - Enhanced the SIDX interface with snapshot transaction support, allowing for coordinated updates. - Updated existing snapshot handling to utilize the new transaction model for memory, flushed, and merged parts. - Refactored snapshot types and methods for consistency and clarity. --- CHANGES.md | 1 + banyand/internal/sidx/interfaces.go | 19 + banyand/internal/sidx/introducer.go | 4 +- banyand/internal/sidx/query.go | 4 +- banyand/internal/sidx/sidx.go | 77 ++- banyand/internal/sidx/snapshot.go | 68 ++- banyand/internal/snapshot/snapshot.go | 219 +++++++ banyand/internal/snapshot/snapshot_test.go | 917 +++++++++++++++++++++++++++++ banyand/trace/introducer.go | 266 +++++++-- banyand/trace/snapshot.go | 18 +- banyand/trace/streaming_pipeline_test.go | 35 ++ 11 files changed, 1533 insertions(+), 95 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 8d8facb5d..3ffc1b9be 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -22,6 +22,7 @@ Release Notes. - **Breaking Change**: Change the data storage path structure for property model: - From: `<data-dir>/property/data/shard-<id>/...` - To: `<data-dir>/property/data/<group>/shard-<id>/...` +- Add a generic snapshot coordination package for atomic snapshot transitions across trace and sidx. ### Bug Fixes diff --git a/banyand/internal/sidx/interfaces.go b/banyand/internal/sidx/interfaces.go index 670f76f1a..5c4a88219 100644 --- a/banyand/internal/sidx/interfaces.go +++ b/banyand/internal/sidx/interfaces.go @@ -31,11 +31,30 @@ import ( "github.com/apache/skywalking-banyandb/pkg/query/model" ) +// SnapshotTransactionSupport defines snapshot and transaction-related operations +// for coordinated snapshot updates. It is embedded by SIDX so that callers +// (e.g. trace) can participate in atomic snapshot transactions. +type SnapshotTransactionSupport interface { + // PrepareMemPart prepares a transition for introducing a memory part. + PrepareMemPart(partID uint64, mp *MemPart) func(cur *Snapshot) *Snapshot + // PrepareFlushed prepares a transition for introducing flushed parts. + PrepareFlushed(intro *FlusherIntroduction) func(cur *Snapshot) *Snapshot + // PrepareMerged prepares a transition for introducing merged parts. + PrepareMerged(intro *MergerIntroduction) func(cur *Snapshot) *Snapshot + // PrepareSynced prepares a transition for removing synced parts. + PrepareSynced(partIDsToSync map[uint64]struct{}) func(cur *Snapshot) *Snapshot + // CurrentSnapshot returns the current snapshot with incremented reference count. + CurrentSnapshot() *Snapshot + // ReplaceSnapshot atomically replaces the current snapshot with next. + ReplaceSnapshot(next *Snapshot) +} + // SIDX defines the main secondary index interface with user-controlled ordering. // The core principle is that int64 keys are provided by users and treated as // opaque ordering values by sidx - the system only performs numerical comparisons // without interpreting the semantic meaning of keys. type SIDX interface { + SnapshotTransactionSupport // IntroduceMemPart introduces a memPart to the SIDX instance. IntroduceMemPart(partID uint64, mp *MemPart) // IntroduceFlushed introduces a flushed map to the SIDX instance. diff --git a/banyand/internal/sidx/introducer.go b/banyand/internal/sidx/introducer.go index 9bdefde82..5856828f9 100644 --- a/banyand/internal/sidx/introducer.go +++ b/banyand/internal/sidx/introducer.go @@ -100,7 +100,7 @@ func (s *sidx) IntroduceMemPart(partID uint64, memPart *memPart) { if cur != nil { defer cur.decRef() } else { - cur = &snapshot{} + cur = &Snapshot{} } nextSnp := cur.copyAllTo() @@ -174,7 +174,7 @@ func (s *sidx) TakeFileSnapshot(dst string) error { return nil } -func (s *sidx) replaceSnapshot(next *snapshot) { +func (s *sidx) replaceSnapshot(next *Snapshot) { s.mu.Lock() defer s.mu.Unlock() if s.snapshot != nil { diff --git a/banyand/internal/sidx/query.go b/banyand/internal/sidx/query.go index 36d2f80e2..65f55ff91 100644 --- a/banyand/internal/sidx/query.go +++ b/banyand/internal/sidx/query.go @@ -149,7 +149,7 @@ func finalizeStreamingSpan(span *query.Span, errPtr *error) { func (s *sidx) prepareStreamingResources( ctx context.Context, req QueryRequest, - snap *snapshot, + snap *Snapshot, span *query.Span, ) (*streamingQueryResources, bool) { var prepareSpan *query.Span @@ -485,7 +485,7 @@ func extractOrdering(req QueryRequest) bool { } // selectPartsForQuery selects relevant parts from snapshot based on key range. -func selectPartsForQuery(snap *snapshot, minKey, maxKey int64) []*part { +func selectPartsForQuery(snap *Snapshot, minKey, maxKey int64) []*part { var selectedParts []*part for _, pw := range snap.parts { diff --git a/banyand/internal/sidx/sidx.go b/banyand/internal/sidx/sidx.go index 181497702..4e0b814b6 100644 --- a/banyand/internal/sidx/sidx.go +++ b/banyand/internal/sidx/sidx.go @@ -45,7 +45,7 @@ const ( // sidx implements the SIDX interface with introduction channels for async operations. type sidx struct { fileSystem fs.FileSystem - snapshot *snapshot + snapshot *Snapshot l *logger.Logger pm protector.Memory root string @@ -233,7 +233,7 @@ func (s *sidx) Close() error { } // currentSnapshot returns the current snapshot with incremented reference count. -func (s *sidx) currentSnapshot() *snapshot { +func (s *sidx) currentSnapshot() *Snapshot { s.mu.RLock() defer s.mu.RUnlock() @@ -248,6 +248,79 @@ func (s *sidx) currentSnapshot() *snapshot { return nil } +// CurrentSnapshot returns the current snapshot with incremented reference count. +// Implements snapshot.Manager[*snapshot] interface. +func (s *sidx) CurrentSnapshot() *Snapshot { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.snapshot == nil { + return nil + } + + s.snapshot.IncRef() + return s.snapshot +} + +// ReplaceSnapshot atomically replaces the current snapshot with next. +// Implements snapshot.Manager[*Snapshot] interface. +// The old snapshot's DecRef is called automatically per the Manager contract. +func (s *sidx) ReplaceSnapshot(next *Snapshot) { + s.mu.Lock() + defer s.mu.Unlock() + if s.snapshot != nil { + s.snapshot.DecRef() + } + s.snapshot = next +} + +// PrepareMemPart prepares a transition for introducing a memory part. +func (s *sidx) PrepareMemPart(partID uint64, mp *MemPart) func(cur *Snapshot) *Snapshot { + return func(cur *Snapshot) *Snapshot { + if cur == nil { + cur = &Snapshot{} + } + next := cur.copyAllTo() + part := openMemPart(mp) + pw := newPartWrapper(mp, part) + pw.p.partMetadata.ID = partID + next.parts = append(next.parts, pw) + return next + } +} + +// PrepareFlushed prepares a transition for introducing flushed parts. +func (s *sidx) PrepareFlushed(intro *FlusherIntroduction) func(cur *Snapshot) *Snapshot { + return func(cur *Snapshot) *Snapshot { + if cur == nil { + s.l.Panic().Msg("current snapshot is nil in PrepareFlushed") + } + return cur.merge(intro.flushed) + } +} + +// PrepareMerged prepares a transition for introducing merged parts. +func (s *sidx) PrepareMerged(intro *MergerIntroduction) func(cur *Snapshot) *Snapshot { + return func(cur *Snapshot) *Snapshot { + if cur == nil { + s.l.Panic().Msg("current snapshot is nil in PrepareMerged") + } + next := cur.remove(intro.merged) + next.parts = append(next.parts, intro.newPart) + return next + } +} + +// PrepareSynced prepares a transition for removing synced parts. +func (s *sidx) PrepareSynced(partIDsToSync map[uint64]struct{}) func(cur *Snapshot) *Snapshot { + return func(cur *Snapshot) *Snapshot { + if cur == nil { + s.l.Panic().Msg("current snapshot is nil in PrepareSynced") + } + return cur.remove(partIDsToSync) + } +} + // blockCursor represents a cursor for iterating through a loaded block, similar to query_by_ts.go. type blockCursor struct { p *part diff --git a/banyand/internal/sidx/snapshot.go b/banyand/internal/sidx/snapshot.go index c59c3c396..0cce7df83 100644 --- a/banyand/internal/sidx/snapshot.go +++ b/banyand/internal/sidx/snapshot.go @@ -25,10 +25,10 @@ import ( "github.com/apache/skywalking-banyandb/pkg/logger" ) -// snapshot represents an immutable collection of parts at a specific epoch. +// Snapshot represents an immutable collection of parts at a specific epoch. // It provides safe concurrent access to parts through reference counting and // enables queries to work with a consistent view of data. -type snapshot struct { +type Snapshot struct { // parts contains all active parts sorted by epoch (oldest first) parts []*partWrapper @@ -38,8 +38,8 @@ type snapshot struct { // newSnapshot creates a new snapshot with the given parts and epoch. // The snapshot starts with a reference count of 1. -func newSnapshot(parts []*partWrapper) *snapshot { - s := &snapshot{} +func newSnapshot(parts []*partWrapper) *Snapshot { + s := &Snapshot{} s.parts = append(s.parts[:0], parts...) s.ref = 1 @@ -56,20 +56,23 @@ func newSnapshot(parts []*partWrapper) *snapshot { return s } +// IncRef increments the snapshot reference count. +// Implements the snapshot.Snapshot interface. +func (s *Snapshot) IncRef() { + atomic.AddInt32(&s.ref, 1) +} + // acquire increments the snapshot reference count. // Returns true if successful, false if snapshot has been released. -func (s *snapshot) acquire() bool { +// This is an internal method for backward compatibility. +func (s *Snapshot) acquire() bool { return atomic.AddInt32(&s.ref, 1) > 0 } -// decRef decrements the snapshot reference count (helper for snapshot interface). -func (s *snapshot) decRef() { - s.release() -} - -// release decrements the snapshot reference count. +// DecRef decrements the snapshot reference count. // When the count reaches zero, all part references are released. -func (s *snapshot) release() { +// Implements the snapshot.Snapshot interface. +func (s *Snapshot) DecRef() { newRef := atomic.AddInt32(&s.ref, -1) if newRef > 0 { return @@ -84,10 +87,21 @@ func (s *snapshot) release() { s.reset() } +// decRef is an internal helper that calls DecRef. +func (s *Snapshot) decRef() { + s.DecRef() +} + +// release is an internal helper that calls DecRef. +// Kept for backward compatibility. +func (s *Snapshot) release() { + s.DecRef() +} + // getParts returns parts that potentially contain data within the specified key range. // This method filters parts based on their key ranges to minimize I/O during queries. // Parts are returned in epoch order (oldest first) for consistent iteration. -func (s *snapshot) getParts(minKey, maxKey int64) []*partWrapper { +func (s *Snapshot) getParts(minKey, maxKey int64) []*partWrapper { var result []*partWrapper for _, pw := range s.parts { @@ -112,7 +126,7 @@ func (s *snapshot) getParts(minKey, maxKey int64) []*partWrapper { // getPartsAll returns all active parts in the snapshot. // This is used when querying without key range restrictions. -func (s *snapshot) getPartsAll() []*partWrapper { +func (s *Snapshot) getPartsAll() []*partWrapper { var result []*partWrapper for _, pw := range s.parts { @@ -125,17 +139,17 @@ func (s *snapshot) getPartsAll() []*partWrapper { } // getPartCount returns the number of parts in the snapshot. -func (s *snapshot) getPartCount() int { +func (s *Snapshot) getPartCount() int { return len(s.getPartsAll()) } // refCount returns the current reference count (for testing/debugging). -func (s *snapshot) refCount() int32 { +func (s *Snapshot) refCount() int32 { return atomic.LoadInt32(&s.ref) } // validate checks snapshot consistency and part availability. -func (s *snapshot) validate() error { +func (s *Snapshot) validate() error { if atomic.LoadInt32(&s.ref) <= 0 { return fmt.Errorf("snapshot has zero or negative reference count") } @@ -160,7 +174,7 @@ func (s *snapshot) validate() error { // addPart adds a new part to the snapshot during construction. // This should only be called before the snapshot is made available to other goroutines. // After construction, snapshots should be treated as immutable. -func (s *snapshot) addPart(pw *partWrapper) { +func (s *Snapshot) addPart(pw *partWrapper) { if pw != nil && pw.acquire() { s.parts = append(s.parts, pw) } @@ -168,7 +182,7 @@ func (s *snapshot) addPart(pw *partWrapper) { // removePart marks a part for removal from future snapshots. // The part remains accessible in this snapshot until the snapshot is released. -func (s *snapshot) removePart(partID uint64) { +func (s *Snapshot) removePart(partID uint64) { for _, pw := range s.parts { if pw.ID() == partID { pw.markForRemoval() @@ -178,7 +192,7 @@ func (s *snapshot) removePart(partID uint64) { } // reset clears the snapshot for reuse. -func (s *snapshot) reset() { +func (s *Snapshot) reset() { // Release all part references for _, pw := range s.parts { if pw != nil { @@ -191,7 +205,7 @@ func (s *snapshot) reset() { } // String returns a string representation of the snapshot. -func (s *snapshot) String() string { +func (s *Snapshot) String() string { activeCount := s.getPartCount() return fmt.Sprintf("snapshot{parts=%d/%d, ref=%d}", activeCount, len(s.parts), s.refCount()) @@ -202,8 +216,8 @@ func parseEpoch(epochStr string) (uint64, error) { } // copyAllTo creates a new snapshot with all parts from current snapshot. -func (s *snapshot) copyAllTo() *snapshot { - var result snapshot +func (s *Snapshot) copyAllTo() *Snapshot { + var result Snapshot result.parts = make([]*partWrapper, len(s.parts)) result.ref = 1 @@ -219,8 +233,8 @@ func (s *snapshot) copyAllTo() *snapshot { } // merge creates a new snapshot by merging flushed parts into the current snapshot. -func (s *snapshot) merge(nextParts map[uint64]*partWrapper) *snapshot { - var result snapshot +func (s *Snapshot) merge(nextParts map[uint64]*partWrapper) *Snapshot { + var result Snapshot result.ref = 1 for i := 0; i < len(s.parts); i++ { if n, ok := nextParts[s.parts[i].ID()]; ok { @@ -235,8 +249,8 @@ func (s *snapshot) merge(nextParts map[uint64]*partWrapper) *snapshot { } // remove creates a new snapshot by removing specified parts. -func (s *snapshot) remove(toRemove map[uint64]struct{}) *snapshot { - var result snapshot +func (s *Snapshot) remove(toRemove map[uint64]struct{}) *Snapshot { + var result Snapshot result.ref = 1 // Copy parts except those being removed diff --git a/banyand/internal/snapshot/snapshot.go b/banyand/internal/snapshot/snapshot.go new file mode 100644 index 000000000..1ab8b308a --- /dev/null +++ b/banyand/internal/snapshot/snapshot.go @@ -0,0 +1,219 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Package snapshot provides generic transaction coordination for snapshot-based systems. +// It enables atomic updates across multiple heterogeneous snapshot managers using Go generics. +package snapshot + +import ( + "reflect" + "sync" + + "github.com/apache/skywalking-banyandb/pkg/pool" +) + +// Snapshot is a type constraint for snapshot types that support reference counting. +// Any type implementing this interface can participate in atomic transactions. +type Snapshot interface { + // IncRef increments the reference count. + IncRef() + // DecRef decrements the reference count and releases resources when zero. + DecRef() +} + +// Manager is a generic interface for managing snapshots of type S. +// Both trace and sidx implement this interface for their respective snapshot types. +type Manager[S Snapshot] interface { + // CurrentSnapshot returns the current snapshot with incremented ref count. + // Returns nil if no snapshot exists. + CurrentSnapshot() S + // ReplaceSnapshot atomically replaces the current snapshot with next. + // The old snapshot's DecRef is called automatically. + ReplaceSnapshot(next S) +} + +// Transition represents a prepared but uncommitted snapshot change. +// It holds both the current and next snapshots, allowing atomic commit or rollback. +type Transition[S Snapshot] struct { + manager Manager[S] + current S + next S + committed bool +} + +// getTransitionPool returns or creates a pool for the given snapshot type. +func getTransitionPool[S Snapshot]() *pool.Synced[any] { + var zero S + typ := reflect.TypeOf(zero) + + if p, ok := transitionPools.Load(typ); ok { + return p.(*pool.Synced[any]) + } + + // Create new pool for this type + poolName := "snapshot.Transition[" + typ.String() + "]" + p := pool.Register[any](poolName) + actual, _ := transitionPools.LoadOrStore(typ, p) + return actual.(*pool.Synced[any]) +} + +// NewTransition creates a transition by preparing the next snapshot from the pool. +// The prepareNext function receives the current snapshot and returns the next. +func NewTransition[S Snapshot](manager Manager[S], prepareNext func(current S) S) *Transition[S] { + p := getTransitionPool[S]() + + var t *Transition[S] + if pooled := p.Get(); pooled != nil { + t = pooled.(*Transition[S]) + } else { + t = &Transition[S]{} + } + + current := manager.CurrentSnapshot() + next := prepareNext(current) + + t.manager = manager + t.current = current + t.next = next + t.committed = false + + return t +} + +// Commit commits the transition, replacing the current snapshot with next. +// ReplaceSnapshot satisfies the Manager contract by calling DecRef on the old snapshot. +func (t *Transition[S]) Commit() { + if t.committed { + return + } + t.committed = true + t.manager.ReplaceSnapshot(t.next) +} + +// Rollback discards the prepared next snapshot without committing. +func (t *Transition[S]) Rollback() { + if t.committed { + return + } + // Release the prepared next snapshot + // Use reflection to check if underlying value is nil (Go interface quirk) + if !reflect.ValueOf(t.next).IsNil() { + t.next.DecRef() + } + // Release reference to current (acquired in NewTransition) + if !reflect.ValueOf(t.current).IsNil() { + t.current.DecRef() + } +} + +// Release returns the transition to the pool for reuse. +// This should be called after Commit or Rollback to reduce allocations. +func (t *Transition[S]) Release() { + t.reset() + p := getTransitionPool[S]() + p.Put(any(t)) +} + +// reset clears the transition state for reuse. +// When the transition was committed, the ref acquired in NewTransition (t.current) was not +// decremented by ReplaceSnapshot (which only decrements the manager's ref), so we release it here. +func (t *Transition[S]) reset() { + var zero S + if t.committed && !reflect.ValueOf(t.current).IsNil() { + t.current.DecRef() + } + t.manager = nil + t.current = zero + t.next = zero + t.committed = false +} + +var ( + transactionPool = pool.Register[*Transaction]("snapshot.Transaction") + transitionPools sync.Map // map[reflect.Type]*pool.Synced[any] +) + +// Transaction coordinates atomic commits across multiple heterogeneous transitions. +// It uses type erasure via function closures to handle different snapshot types. +type Transaction struct { + commits []func() + rollbacks []func() + mu sync.Mutex + finalized bool +} + +// NewTransaction creates a new empty transaction from the pool. +func NewTransaction() *Transaction { + txn := transactionPool.Get() + if txn == nil { + txn = &Transaction{} + } + return txn +} + +// Release returns the transaction to the pool for reuse. +// This should be called after Commit or Rollback to reduce allocations. +func (txn *Transaction) Release() { + txn.reset() + transactionPool.Put(txn) +} + +// reset clears the transaction state for reuse. +func (txn *Transaction) reset() { + txn.commits = txn.commits[:0] + txn.rollbacks = txn.rollbacks[:0] + txn.finalized = false +} + +// AddTransition adds a typed transition to the transaction. +// This is a generic function (not a method) because Go doesn't support generic methods. +func AddTransition[S Snapshot](txn *Transaction, transition *Transition[S]) { + txn.commits = append(txn.commits, transition.Commit) + txn.rollbacks = append(txn.rollbacks, transition.Rollback) +} + +// Commit atomically commits all transitions in the transaction. +// All commits happen under a single lock to ensure atomicity. +func (txn *Transaction) Commit() { + txn.mu.Lock() + defer txn.mu.Unlock() + + if txn.finalized { + return + } + txn.finalized = true + + for _, commit := range txn.commits { + commit() + } +} + +// Rollback discards all prepared transitions without committing. +// Rollbacks are executed in reverse order (LIFO). +func (txn *Transaction) Rollback() { + txn.mu.Lock() + defer txn.mu.Unlock() + + if txn.finalized { + return + } + txn.finalized = true + + for i := len(txn.rollbacks) - 1; i >= 0; i-- { + txn.rollbacks[i]() + } +} diff --git a/banyand/internal/snapshot/snapshot_test.go b/banyand/internal/snapshot/snapshot_test.go new file mode 100644 index 000000000..b843c4653 --- /dev/null +++ b/banyand/internal/snapshot/snapshot_test.go @@ -0,0 +1,917 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package snapshot + +import ( + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/apache/skywalking-banyandb/pkg/pool" +) + +// mockSnapshot is a test implementation of the Snapshot interface. +type mockSnapshot struct { + id int + ref atomic.Int32 +} + +func (m *mockSnapshot) IncRef() { + m.ref.Add(1) +} + +func (m *mockSnapshot) DecRef() { + m.ref.Add(-1) +} + +func (m *mockSnapshot) RefCount() int32 { + return m.ref.Load() +} + +// mockManager is a test implementation of the Manager interface. +type mockManager struct { + snapshot *mockSnapshot + mu sync.RWMutex +} + +func (m *mockManager) CurrentSnapshot() *mockSnapshot { + m.mu.RLock() + defer m.mu.RUnlock() + if m.snapshot != nil { + m.snapshot.IncRef() + return m.snapshot + } + return nil +} + +func (m *mockManager) ReplaceSnapshot(next *mockSnapshot) { + m.mu.Lock() + defer m.mu.Unlock() + if m.snapshot != nil { + m.snapshot.DecRef() + } + m.snapshot = next +} + +// getSnapshotPoolRefCounts returns ref counts for all pools with names starting with "snapshot.". +func getSnapshotPoolRefCounts() map[string]int { + all := pool.AllRefsCount() + result := make(map[string]int, len(all)) + for name, count := range all { + if strings.HasPrefix(name, "snapshot.") { + result[name] = count + } + } + return result +} + +// assertSnapshotPoolsNoLeak verifies that snapshot pool ref counts have not increased (no leak). +func assertSnapshotPoolsNoLeak(t *testing.T, before map[string]int) { + t.Helper() + after := getSnapshotPoolRefCounts() + for name, count := range after { + if beforeCount := before[name]; beforeCount != count { + t.Errorf("pool %s: ref count changed from %d to %d (possible leak)", name, beforeCount, count) + } + } +} + +func TestTransition_Commit(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + + transition.Commit() + + // Verify snapshot was replaced + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot after commit") + } + defer current.DecRef() + + if current.id != 2 { + t.Errorf("expected snapshot id 2, got %d", current.id) + } +} + +func TestTransition_Rollback(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + + transition.Rollback() + + // Verify snapshot was NOT replaced + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot after rollback") + } + defer current.DecRef() + + if current.id != 1 { + t.Errorf("expected snapshot id 1, got %d", current.id) + } +} + +func TestTransaction_SingleTransition(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + txn := NewTransaction() + defer txn.Release() + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + AddTransition(txn, transition) + + txn.Commit() + + // Verify snapshot was replaced + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot after commit") + } + defer current.DecRef() + + if current.id != 2 { + t.Errorf("expected snapshot id 2, got %d", current.id) + } +} + +func TestTransaction_MultipleTransitions(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager1 := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager1.snapshot.IncRef() + + manager2 := &mockManager{ + snapshot: &mockSnapshot{id: 10}, + } + manager2.snapshot.IncRef() + + txn := NewTransaction() + defer txn.Release() + + transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition1.Release() + AddTransition(txn, transition1) + + transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition2.Release() + AddTransition(txn, transition2) + + txn.Commit() + + // Verify both snapshots were replaced + current1 := manager1.CurrentSnapshot() + if current1 == nil { + t.Fatal("expected non-nil snapshot for manager1") + } + defer current1.DecRef() + + if current1.id != 2 { + t.Errorf("expected snapshot id 2 for manager1, got %d", current1.id) + } + + current2 := manager2.CurrentSnapshot() + if current2 == nil { + t.Fatal("expected non-nil snapshot for manager2") + } + defer current2.DecRef() + + if current2.id != 11 { + t.Errorf("expected snapshot id 11 for manager2, got %d", current2.id) + } +} + +func TestTransaction_Rollback(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager1 := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager1.snapshot.IncRef() + + manager2 := &mockManager{ + snapshot: &mockSnapshot{id: 10}, + } + manager2.snapshot.IncRef() + + txn := NewTransaction() + defer txn.Release() + + transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition1.Release() + AddTransition(txn, transition1) + + transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition2.Release() + AddTransition(txn, transition2) + + txn.Rollback() + + // Verify neither snapshot was replaced + current1 := manager1.CurrentSnapshot() + if current1 == nil { + t.Fatal("expected non-nil snapshot for manager1") + } + defer current1.DecRef() + + if current1.id != 1 { + t.Errorf("expected snapshot id 1 for manager1, got %d", current1.id) + } + + current2 := manager2.CurrentSnapshot() + if current2 == nil { + t.Fatal("expected non-nil snapshot for manager2") + } + defer current2.DecRef() + + if current2.id != 10 { + t.Errorf("expected snapshot id 10 for manager2, got %d", current2.id) + } +} + +func TestTransaction_Concurrent(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 0}, + } + manager.snapshot.IncRef() + + const numGoroutines = 10 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + + txn := NewTransaction() + defer txn.Release() + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + AddTransition(txn, transition) + txn.Commit() + }() + } + + wg.Wait() + + // Verify final snapshot exists and was updated + // Note: With concurrent updates, not all increments may be preserved + // because multiple goroutines might read the same snapshot concurrently. + // The important thing is that no crashes occurred and the snapshot changed. + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot after concurrent commits") + } + defer current.DecRef() + + if current.id <= 0 { + t.Errorf("expected snapshot id > 0, got %d", current.id) + } + if current.id > numGoroutines { + t.Errorf("expected snapshot id <= %d, got %d", numGoroutines, current.id) + } +} + +func TestTransaction_IdempotentCommit(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + txn := NewTransaction() + defer txn.Release() + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + AddTransition(txn, transition) + + // Commit multiple times + txn.Commit() + txn.Commit() + txn.Commit() + + // Verify snapshot was replaced only once + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot") + } + defer current.DecRef() + + if current.id != 2 { + t.Errorf("expected snapshot id 2, got %d", current.id) + } +} + +func TestTransaction_IdempotentRollback(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + txn := NewTransaction() + defer txn.Release() + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + AddTransition(txn, transition) + + // Rollback multiple times + txn.Rollback() + txn.Rollback() + txn.Rollback() + + // Verify snapshot was not replaced + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot") + } + defer current.DecRef() + + if current.id != 1 { + t.Errorf("expected snapshot id 1, got %d", current.id) + } +} + +func TestTransaction_CommitAfterRollback(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + txn := NewTransaction() + defer txn.Release() + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + AddTransition(txn, transition) + + // Rollback first, then try to commit + txn.Rollback() + txn.Commit() + + // Verify snapshot was not replaced + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot") + } + defer current.DecRef() + + if current.id != 1 { + t.Errorf("expected snapshot id 1, got %d", current.id) + } +} + +func TestTransaction_HeterogeneousTypes(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + // Test with two different snapshot types to verify type safety + type alternateSnapshot struct { + value string + ref atomic.Int32 + } + + altSnap := &alternateSnapshot{value: "test"} + altSnap.ref.Store(1) + + altManager := &struct { + snapshot *alternateSnapshot + mu sync.RWMutex + }{ + snapshot: altSnap, + } + + // Define methods inline for alternateSnapshot + incRef := func(s *alternateSnapshot) { + s.ref.Add(1) + } + + // Create a custom manager for alternateSnapshot + currentAlt := func() *alternateSnapshot { + altManager.mu.RLock() + defer altManager.mu.RUnlock() + if altManager.snapshot != nil { + incRef(altManager.snapshot) + return altManager.snapshot + } + return nil + } + + replaceAlt := func(next *alternateSnapshot) { + altManager.mu.Lock() + defer altManager.mu.Unlock() + altManager.snapshot = next + } + + // This test verifies that the generic design compiles with different types + // The actual runtime behavior is tested in other tests + _ = currentAlt + _ = replaceAlt +} + +// BenchmarkTransaction_WithPool benchmarks transaction creation with pooling. +func BenchmarkTransaction_WithPool(b *testing.B) { + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + txn := NewTransaction() + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + AddTransition(txn, transition) + txn.Commit() + txn.Release() + } +} + +// BenchmarkTransaction_WithoutPool benchmarks transaction creation without pooling. +func BenchmarkTransaction_WithoutPool(b *testing.B) { + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Simulate non-pooled allocation + txn := &Transaction{} + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + AddTransition(txn, transition) + txn.Commit() + // No Release() - simulates non-pooled version + } +} + +// BenchmarkTransition_WithPool benchmarks transition creation with pooling. +func BenchmarkTransition_WithPool(b *testing.B) { + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + transition.Commit() + transition.Release() + } +} + +// BenchmarkTransition_WithoutPool benchmarks transition creation without pooling. +func BenchmarkTransition_WithoutPool(b *testing.B) { + manager := &mockManager{ + snapshot: &mockSnapshot{id: 1}, + } + manager.snapshot.IncRef() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Directly allocate transition without pool + var zero *mockSnapshot + transition := &Transition[*mockSnapshot]{ + manager: manager, + current: zero, + next: zero, + committed: false, + } + current := manager.CurrentSnapshot() + next := &mockSnapshot{id: current.id + 1} + next.IncRef() + transition.current = current + transition.next = next + transition.Commit() + // No Release() - simulates non-pooled version + } +} + +// TestTransition_ReleaseReleasesCurrentRef verifies that Release() decrements the ref +// held in t.current after a committed transition (fixes leak where that ref was never released). +func TestTransition_ReleaseReleasesCurrentRef(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + oldSnapshot := &mockSnapshot{id: 1} + oldSnapshot.IncRef() + manager := &mockManager{snapshot: oldSnapshot} + + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + transition.Commit() + transition.Release() + + // After Release(), the ref we held in t.current must have been released; + // ReplaceSnapshot already released the manager's ref, so oldSnapshot should be at 0. + if oldSnapshot.RefCount() != 0 { + t.Errorf("expected old snapshot refcount 0 after Commit+Release, got %d", oldSnapshot.RefCount()) + } +} + +// TestTransition_RefCountAfterCommit verifies reference counts are correct after commit. +func TestTransition_RefCountAfterCommit(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + oldSnapshot := &mockSnapshot{id: 1} + oldSnapshot.IncRef() + + manager := &mockManager{ + snapshot: oldSnapshot, + } + + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + + newSnapshot := transition.next + + // Before commit: + // - oldSnapshot: refcount = 2 (initial +1, CurrentSnapshot +1) + // - newSnapshot: refcount = 1 (created with +1) + if oldSnapshot.RefCount() != 2 { + t.Errorf("expected old snapshot refcount 2 before commit, got %d", oldSnapshot.RefCount()) + } + if newSnapshot.RefCount() != 1 { + t.Errorf("expected new snapshot refcount 1 before commit, got %d", newSnapshot.RefCount()) + } + + transition.Commit() + + // After commit: + // - oldSnapshot: refcount = 1 (Commit decremented it once) + // - newSnapshot: refcount = 1 (manager now holds it, no change) + if oldSnapshot.RefCount() != 1 { + t.Errorf("expected old snapshot refcount 1 after commit, got %d", oldSnapshot.RefCount()) + } + if newSnapshot.RefCount() != 1 { + t.Errorf("expected new snapshot refcount 1 after commit, got %d", newSnapshot.RefCount()) + } + + // Verify manager has the new snapshot + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot from manager") + } + defer current.DecRef() + + if current.id != 2 { + t.Errorf("expected current snapshot id 2, got %d", current.id) + } + if current.RefCount() != 2 { + t.Errorf("expected current snapshot refcount 2 after CurrentSnapshot, got %d", current.RefCount()) + } +} + +// TestTransition_RefCountAfterRollback verifies reference counts are correct after rollback. +func TestTransition_RefCountAfterRollback(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + oldSnapshot := &mockSnapshot{id: 1} + oldSnapshot.IncRef() + + manager := &mockManager{ + snapshot: oldSnapshot, + } + + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + + newSnapshot := transition.next + + // Before rollback: + // - oldSnapshot: refcount = 2 (initial +1, CurrentSnapshot +1) + // - newSnapshot: refcount = 1 (created with +1) + if oldSnapshot.RefCount() != 2 { + t.Errorf("expected old snapshot refcount 2 before rollback, got %d", oldSnapshot.RefCount()) + } + if newSnapshot.RefCount() != 1 { + t.Errorf("expected new snapshot refcount 1 before rollback, got %d", newSnapshot.RefCount()) + } + + transition.Rollback() + + // After rollback: + // - oldSnapshot: refcount = 1 (Rollback decremented the current reference) + // - newSnapshot: refcount = 0 (Rollback decremented the next reference) + if oldSnapshot.RefCount() != 1 { + t.Errorf("expected old snapshot refcount 1 after rollback, got %d", oldSnapshot.RefCount()) + } + if newSnapshot.RefCount() != 0 { + t.Errorf("expected new snapshot refcount 0 after rollback, got %d", newSnapshot.RefCount()) + } + + // Verify manager still has the old snapshot + current := manager.CurrentSnapshot() + if current == nil { + t.Fatal("expected non-nil snapshot from manager") + } + defer current.DecRef() + + if current.id != 1 { + t.Errorf("expected current snapshot id 1, got %d", current.id) + } +} + +// TestTransaction_RefCountWithMultipleTransitions verifies reference counts +// are correct when multiple transitions are committed together. +func TestTransaction_RefCountWithMultipleTransitions(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + oldSnapshot1 := &mockSnapshot{id: 1} + oldSnapshot1.IncRef() + manager1 := &mockManager{snapshot: oldSnapshot1} + + oldSnapshot2 := &mockSnapshot{id: 10} + oldSnapshot2.IncRef() + manager2 := &mockManager{snapshot: oldSnapshot2} + + txn := NewTransaction() + defer txn.Release() + + transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition1.Release() + AddTransition(txn, transition1) + + transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition2.Release() + AddTransition(txn, transition2) + + newSnapshot1 := transition1.next + newSnapshot2 := transition2.next + + // Before commit: + // - oldSnapshot1: refcount = 2 (initial +1, CurrentSnapshot +1) + // - newSnapshot1: refcount = 1 (created with +1) + // - oldSnapshot2: refcount = 2 (initial +1, CurrentSnapshot +1) + // - newSnapshot2: refcount = 1 (created with +1) + if oldSnapshot1.RefCount() != 2 { + t.Errorf("expected old snapshot1 refcount 2 before commit, got %d", oldSnapshot1.RefCount()) + } + if newSnapshot1.RefCount() != 1 { + t.Errorf("expected new snapshot1 refcount 1 before commit, got %d", newSnapshot1.RefCount()) + } + if oldSnapshot2.RefCount() != 2 { + t.Errorf("expected old snapshot2 refcount 2 before commit, got %d", oldSnapshot2.RefCount()) + } + if newSnapshot2.RefCount() != 1 { + t.Errorf("expected new snapshot2 refcount 1 before commit, got %d", newSnapshot2.RefCount()) + } + + txn.Commit() + + // After commit: + // - oldSnapshot1: refcount = 1 (initial +1, Commit decremented once) + // - newSnapshot1: refcount = 1 (manager1 holds it) + // - oldSnapshot2: refcount = 1 (initial +1, Commit decremented once) + // - newSnapshot2: refcount = 1 (manager2 holds it) + if oldSnapshot1.RefCount() != 1 { + t.Errorf("expected old snapshot1 refcount 1 after commit, got %d", oldSnapshot1.RefCount()) + } + if newSnapshot1.RefCount() != 1 { + t.Errorf("expected new snapshot1 refcount 1 after commit, got %d", newSnapshot1.RefCount()) + } + if oldSnapshot2.RefCount() != 1 { + t.Errorf("expected old snapshot2 refcount 1 after commit, got %d", oldSnapshot2.RefCount()) + } + if newSnapshot2.RefCount() != 1 { + t.Errorf("expected new snapshot2 refcount 1 after commit, got %d", newSnapshot2.RefCount()) + } +} + +// TestTransaction_RefCountAfterRollback verifies reference counts are correct +// after rolling back multiple transitions. +func TestTransaction_RefCountAfterRollback(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + oldSnapshot1 := &mockSnapshot{id: 1} + oldSnapshot1.IncRef() + manager1 := &mockManager{snapshot: oldSnapshot1} + + oldSnapshot2 := &mockSnapshot{id: 10} + oldSnapshot2.IncRef() + manager2 := &mockManager{snapshot: oldSnapshot2} + + txn := NewTransaction() + defer txn.Release() + + transition1 := NewTransition(manager1, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition1.Release() + AddTransition(txn, transition1) + + transition2 := NewTransition(manager2, func(cur *mockSnapshot) *mockSnapshot { + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition2.Release() + AddTransition(txn, transition2) + + newSnapshot1 := transition1.next + newSnapshot2 := transition2.next + + txn.Rollback() + + // After rollback: + // - oldSnapshot1: refcount = 1 (initial +1, rollback decremented current ref) + // - newSnapshot1: refcount = 0 (rollback decremented next ref) + // - oldSnapshot2: refcount = 1 (initial +1, rollback decremented current ref) + // - newSnapshot2: refcount = 0 (rollback decremented next ref) + if oldSnapshot1.RefCount() != 1 { + t.Errorf("expected old snapshot1 refcount 1 after rollback, got %d", oldSnapshot1.RefCount()) + } + if newSnapshot1.RefCount() != 0 { + t.Errorf("expected new snapshot1 refcount 0 after rollback, got %d", newSnapshot1.RefCount()) + } + if oldSnapshot2.RefCount() != 1 { + t.Errorf("expected old snapshot2 refcount 1 after rollback, got %d", oldSnapshot2.RefCount()) + } + if newSnapshot2.RefCount() != 0 { + t.Errorf("expected new snapshot2 refcount 0 after rollback, got %d", newSnapshot2.RefCount()) + } +} + +// TestTransition_NoDoubleDecrement verifies that the fix prevents double-decrement bug. +// This is a regression test for the bug where ReplaceSnapshot and Commit both +// decremented the same snapshot reference. +func TestTransition_NoDoubleDecrement(t *testing.T) { + before := getSnapshotPoolRefCounts() + defer func() { assertSnapshotPoolsNoLeak(t, before) }() + + oldSnapshot := &mockSnapshot{id: 1} + oldSnapshot.IncRef() + + manager := &mockManager{ + snapshot: oldSnapshot, + } + + // Record initial refcount + initialRefCount := oldSnapshot.RefCount() + if initialRefCount != 1 { + t.Fatalf("expected initial refcount 1, got %d", initialRefCount) + } + + transition := NewTransition(manager, func(cur *mockSnapshot) *mockSnapshot { + // Verify cur is the same object as oldSnapshot + if cur != oldSnapshot { + t.Error("current snapshot should be the same object as old snapshot") + } + // Verify refcount was incremented by CurrentSnapshot + if cur.RefCount() != 2 { + t.Errorf("expected cur refcount 2 (after CurrentSnapshot), got %d", cur.RefCount()) + } + + next := &mockSnapshot{id: cur.id + 1} + next.IncRef() + return next + }) + defer transition.Release() + + transition.Commit() + + // After commit, the old snapshot should have refcount 1 + // (not 0 or negative which would indicate double-decrement) + finalRefCount := oldSnapshot.RefCount() + if finalRefCount != 1 { + t.Errorf("expected final refcount 1 (no double-decrement), got %d", finalRefCount) + } + + // Verify the old snapshot is still valid and wasn't prematurely cleaned up + // In a real scenario, negative refcount would cause a crash or corruption + if finalRefCount < 1 { + t.Error("double-decrement bug detected: refcount went below expected value") + } +} diff --git a/banyand/trace/introducer.go b/banyand/trace/introducer.go index 1b7e3e9a2..e33bfb52e 100644 --- a/banyand/trace/introducer.go +++ b/banyand/trace/introducer.go @@ -19,6 +19,7 @@ package trace import ( "github.com/apache/skywalking-banyandb/banyand/internal/sidx" + snapshotpkg "github.com/apache/skywalking-banyandb/banyand/internal/snapshot" "github.com/apache/skywalking-banyandb/pkg/pool" "github.com/apache/skywalking-banyandb/pkg/watcher" ) @@ -223,61 +224,141 @@ func (tst *tsTable) introducerLoopWithSync(flushCh chan *flusherIntroduction, me } func (tst *tsTable) introduceMemPart(nextIntroduction *introduction, epoch uint64) { - cur := tst.currentSnapshot() - if cur != nil { - defer cur.decRef() - } else { - cur = new(snapshot) - } + // Create generic transaction + txn := snapshotpkg.NewTransaction() + defer txn.Release() + // Prepare trace snapshot transition next := nextIntroduction.memPart tst.addPendingDataCount(-int64(next.mp.partMetadata.TotalCount)) - nextSnp := cur.copyAllTo(epoch) - nextSnp.parts = append(nextSnp.parts, next) - nextSnp.creator = snapshotCreatorMemPart - tst.replaceSnapshot(&nextSnp) + partID := next.p.partMetadata.ID + + traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot { + if cur == nil { + cur = new(snapshot) + } + nextSnp := cur.copyAllTo(epoch) + nextSnp.parts = append(nextSnp.parts, next) + nextSnp.creator = snapshotCreatorMemPart + return &nextSnp + }) + defer traceTransition.Release() + snapshotpkg.AddTransition(txn, traceTransition) + + // Prepare sidx snapshot transitions + var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot] for name, memPart := range nextIntroduction.sidxReqsMap { - tst.mustGetOrCreateSidx(name).IntroduceMemPart(next.p.partMetadata.ID, memPart) + sidxInstance := tst.mustGetOrCreateSidx(name) + prepareFunc := sidxInstance.PrepareMemPart(partID, memPart) + sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc) + sidxTransitions = append(sidxTransitions, sidxTransition) + snapshotpkg.AddTransition(txn, sidxTransition) } + defer func() { + for _, t := range sidxTransitions { + t.Release() + } + }() + + // Commit all atomically under single transaction lock + txn.Commit() + if nextIntroduction.applied != nil { close(nextIntroduction.applied) } } func (tst *tsTable) introduceFlushed(nextIntroduction *flusherIntroduction, epoch uint64) { - cur := tst.currentSnapshot() - if cur == nil { - tst.l.Panic().Msg("current snapshot is nil") - } - defer cur.decRef() - nextSnp := cur.merge(epoch, nextIntroduction.flushed) - nextSnp.creator = snapshotCreatorFlusher - tst.replaceSnapshot(&nextSnp) - tst.persistSnapshot(&nextSnp) + // Create generic transaction + txn := snapshotpkg.NewTransaction() + defer txn.Release() + + // Prepare trace snapshot transition + traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot { + if cur == nil { + tst.l.Panic().Msg("current snapshot is nil") + } + nextSnp := cur.merge(epoch, nextIntroduction.flushed) + nextSnp.creator = snapshotCreatorFlusher + return &nextSnp + }) + defer traceTransition.Release() + snapshotpkg.AddTransition(txn, traceTransition) + + // Prepare sidx snapshot transitions + var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot] for name, sidxFlusherIntroduced := range nextIntroduction.sidxFlusherIntroduced { - tst.mustGetSidx(name).IntroduceFlushed(sidxFlusherIntroduced) + sidxInstance := tst.mustGetSidx(name) + prepareFunc := sidxInstance.PrepareFlushed(sidxFlusherIntroduced) + sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc) + sidxTransitions = append(sidxTransitions, sidxTransition) + snapshotpkg.AddTransition(txn, sidxTransition) + } + defer func() { + for _, t := range sidxTransitions { + t.Release() + } + }() + + // Commit all atomically under single transaction lock + txn.Commit() + + // Persist snapshot after commit + cur := tst.currentSnapshot() + if cur != nil { + defer cur.decRef() + tst.persistSnapshot(cur) } + if nextIntroduction.applied != nil { close(nextIntroduction.applied) } } // introduceFlushedForSync introduces the flushed trace parts for syncing. -// The SIDX parts are flushed before the trace parts so the syncer can always find -// the corresponding index on disk once a flushed trace part becomes visible. +// The snapshots are updated atomically so the syncer can always find +// the corresponding index once a flushed trace part becomes visible. func (tst *tsTable) introduceFlushedForSync(nextIntroduction *flusherIntroduction, epoch uint64) { + // Create generic transaction + txn := snapshotpkg.NewTransaction() + defer txn.Release() + + // Prepare sidx snapshot transitions first for visibility ordering + var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot] for name, sidxFlusherIntroduced := range nextIntroduction.sidxFlusherIntroduced { - tst.mustGetSidx(name).IntroduceFlushed(sidxFlusherIntroduced) + sidxInstance := tst.mustGetSidx(name) + prepareFunc := sidxInstance.PrepareFlushed(sidxFlusherIntroduced) + sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc) + sidxTransitions = append(sidxTransitions, sidxTransition) + snapshotpkg.AddTransition(txn, sidxTransition) } + defer func() { + for _, t := range sidxTransitions { + t.Release() + } + }() + + // Prepare trace snapshot transition + traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot { + if cur == nil { + tst.l.Panic().Msg("current snapshot is nil") + } + nextSnp := cur.merge(epoch, nextIntroduction.flushed) + nextSnp.creator = snapshotCreatorFlusher + return &nextSnp + }) + defer traceTransition.Release() + snapshotpkg.AddTransition(txn, traceTransition) + + // Commit all atomically under single transaction lock + txn.Commit() + + // Persist snapshot after commit cur := tst.currentSnapshot() - if cur == nil { - tst.l.Panic().Msg("current snapshot is nil") + if cur != nil { + defer cur.decRef() + tst.persistSnapshot(cur) } - defer cur.decRef() - nextSnp := cur.merge(epoch, nextIntroduction.flushed) - nextSnp.creator = snapshotCreatorFlusher - tst.replaceSnapshot(&nextSnp) - tst.persistSnapshot(&nextSnp) if nextIntroduction.applied != nil { close(nextIntroduction.applied) @@ -285,23 +366,48 @@ func (tst *tsTable) introduceFlushedForSync(nextIntroduction *flusherIntroductio } func (tst *tsTable) introduceMerged(nextIntroduction *mergerIntroduction, epoch uint64) { - cur := tst.currentSnapshot() - if cur == nil { - tst.l.Panic().Msg("current snapshot is nil") - return - } - defer cur.decRef() - nextSnp := cur.remove(epoch, nextIntroduction.merged) - nextSnp.parts = append(nextSnp.parts, nextIntroduction.newPart) - nextSnp.creator = nextIntroduction.creator - tst.replaceSnapshot(&nextSnp) - tst.persistSnapshot(&nextSnp) + // Create generic transaction + txn := snapshotpkg.NewTransaction() + defer txn.Release() + + // Prepare trace snapshot transition + traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot { + if cur == nil { + tst.l.Panic().Msg("current snapshot is nil") + } + nextSnp := cur.remove(epoch, nextIntroduction.merged) + nextSnp.parts = append(nextSnp.parts, nextIntroduction.newPart) + nextSnp.creator = nextIntroduction.creator + return &nextSnp + }) + defer traceTransition.Release() + snapshotpkg.AddTransition(txn, traceTransition) + + // Prepare sidx snapshot transitions + var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot] for name, sidxMergerIntroduced := range nextIntroduction.sidxMergerIntroduced { - deferFuncs := tst.mustGetSidx(name).IntroduceMerged(sidxMergerIntroduced) - if deferFuncs != nil { - defer deferFuncs() + sidxInstance := tst.mustGetSidx(name) + prepareFunc := sidxInstance.PrepareMerged(sidxMergerIntroduced) + sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc) + sidxTransitions = append(sidxTransitions, sidxTransition) + snapshotpkg.AddTransition(txn, sidxTransition) + } + defer func() { + for _, t := range sidxTransitions { + t.Release() } + }() + + // Commit all atomically under single transaction lock + txn.Commit() + + // Persist snapshot after commit + cur := tst.currentSnapshot() + if cur != nil { + defer cur.decRef() + tst.persistSnapshot(cur) } + if nextIntroduction.applied != nil { close(nextIntroduction.applied) } @@ -309,33 +415,73 @@ func (tst *tsTable) introduceMerged(nextIntroduction *mergerIntroduction, epoch func (tst *tsTable) introduceSync(nextIntroduction *syncIntroduction, epoch uint64) { synced := nextIntroduction.synced + + // Create generic transaction + txn := snapshotpkg.NewTransaction() + defer txn.Release() + + // Prepare sidx snapshot transitions + var sidxTransitions []*snapshotpkg.Transition[*sidx.Snapshot] allSidx := tst.getAllSidx() - for _, sidx := range allSidx { - deferFuncs := sidx.IntroduceSynced(synced) - if deferFuncs != nil { - defer deferFuncs() - } + for _, sidxInstance := range allSidx { + prepareFunc := sidxInstance.PrepareSynced(synced) + sidxTransition := snapshotpkg.NewTransition(sidxInstance, prepareFunc) + sidxTransitions = append(sidxTransitions, sidxTransition) + snapshotpkg.AddTransition(txn, sidxTransition) } + defer func() { + for _, t := range sidxTransitions { + t.Release() + } + }() + + // Prepare trace snapshot transition + traceTransition := snapshotpkg.NewTransition(tst, func(cur *snapshot) *snapshot { + if cur == nil { + tst.l.Panic().Msg("current snapshot is nil") + } + nextSnp := cur.remove(epoch, synced) + nextSnp.creator = snapshotCreatorSyncer + return &nextSnp + }) + defer traceTransition.Release() + snapshotpkg.AddTransition(txn, traceTransition) + + // Commit all atomically under single transaction lock + txn.Commit() + + // Persist snapshot after commit cur := tst.currentSnapshot() - if cur == nil { - tst.l.Panic().Msg("current snapshot is nil") - return + if cur != nil { + defer cur.decRef() + tst.persistSnapshot(cur) } - defer cur.decRef() - nextSnp := cur.remove(epoch, synced) - nextSnp.creator = snapshotCreatorSyncer - tst.replaceSnapshot(&nextSnp) - tst.persistSnapshot(&nextSnp) + if nextIntroduction.applied != nil { close(nextIntroduction.applied) } } -func (tst *tsTable) replaceSnapshot(next *snapshot) { +// CurrentSnapshot returns the current snapshot with incremented reference count. +// Implements snapshot.Manager[*snapshot] interface. +func (tst *tsTable) CurrentSnapshot() *snapshot { + tst.RLock() + defer tst.RUnlock() + if tst.snapshot == nil { + return nil + } + tst.snapshot.IncRef() + return tst.snapshot +} + +// ReplaceSnapshot atomically replaces the current snapshot with next. +// Implements snapshot.Manager[*snapshot] interface. +// The old snapshot's DecRef is called automatically per the Manager contract. +func (tst *tsTable) ReplaceSnapshot(next *snapshot) { tst.Lock() defer tst.Unlock() if tst.snapshot != nil { - tst.snapshot.decRef() + tst.snapshot.DecRef() } tst.snapshot = next } diff --git a/banyand/trace/snapshot.go b/banyand/trace/snapshot.go index 479cc2146..431587b07 100644 --- a/banyand/trace/snapshot.go +++ b/banyand/trace/snapshot.go @@ -88,11 +88,15 @@ func (s *snapshot) getParts(dst []*part, minTimestamp int64, maxTimestamp int64, return dst, count } -func (s *snapshot) incRef() { +// IncRef increments the snapshot reference count. +// Implements the snapshot.Snapshot interface. +func (s *snapshot) IncRef() { atomic.AddInt32(&s.ref, 1) } -func (s *snapshot) decRef() { +// DecRef decrements the snapshot reference count. +// Implements the snapshot.Snapshot interface. +func (s *snapshot) DecRef() { n := atomic.AddInt32(&s.ref, -1) if n > 0 { return @@ -103,6 +107,16 @@ func (s *snapshot) decRef() { s.parts = s.parts[:0] } +// incRef is an internal helper for backward compatibility. +func (s *snapshot) incRef() { + s.IncRef() +} + +// decRef is an internal helper for backward compatibility. +func (s *snapshot) decRef() { + s.DecRef() +} + func (s *snapshot) copyAllTo(nextEpoch uint64) snapshot { var result snapshot result.epoch = nextEpoch diff --git a/banyand/trace/streaming_pipeline_test.go b/banyand/trace/streaming_pipeline_test.go index 339208d2e..863c32206 100644 --- a/banyand/trace/streaming_pipeline_test.go +++ b/banyand/trace/streaming_pipeline_test.go @@ -86,6 +86,24 @@ func (f *fakeSIDX) ScanQuery(context.Context, sidx.ScanQueryRequest) ([]*sidx.Qu return nil, nil } +func (f *fakeSIDX) PrepareMemPart(uint64, *sidx.MemPart) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} + +func (f *fakeSIDX) PrepareFlushed(*sidx.FlusherIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} + +func (f *fakeSIDX) PrepareMerged(*sidx.MergerIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} + +func (f *fakeSIDX) PrepareSynced(map[uint64]struct{}) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} +func (f *fakeSIDX) CurrentSnapshot() *sidx.Snapshot { return nil } +func (f *fakeSIDX) ReplaceSnapshot(*sidx.Snapshot) {} + type fakeSIDXWithErr struct { *fakeSIDX err error @@ -662,6 +680,23 @@ func (f *fakeSIDXInfinite) PartPaths(map[uint64]struct{}) map[uint64]string { } func (f *fakeSIDXInfinite) IntroduceSynced(map[uint64]struct{}) func() { return func() {} } func (f *fakeSIDXInfinite) TakeFileSnapshot(_ string) error { return nil } +func (f *fakeSIDXInfinite) PrepareMemPart(uint64, *sidx.MemPart) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} + +func (f *fakeSIDXInfinite) PrepareFlushed(*sidx.FlusherIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} + +func (f *fakeSIDXInfinite) PrepareMerged(*sidx.MergerIntroduction) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} + +func (f *fakeSIDXInfinite) PrepareSynced(map[uint64]struct{}) func(cur *sidx.Snapshot) *sidx.Snapshot { + return func(cur *sidx.Snapshot) *sidx.Snapshot { return cur } +} +func (f *fakeSIDXInfinite) CurrentSnapshot() *sidx.Snapshot { return nil } +func (f *fakeSIDXInfinite) ReplaceSnapshot(*sidx.Snapshot) {} // TestStreamSIDXTraceBatches_InfiniteChannelContinuesUntilCanceled verifies that // the streaming pipeline continues streaming from an infinite channel until context is canceled.
