This is an automated email from the ASF dual-hosted git repository. hanahmily pushed a commit to branch tsdb-closer in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
commit c2e6cdf40a7a9e095e0b70a4759dcd7ec74efafb Author: Gao Hongtao <[email protected]> AuthorDate: Thu Nov 10 12:44:50 2022 +0000 Add a Closer to manage the closing phase Signed-off-by: Gao Hongtao <[email protected]> --- .github/workflows/ci.yml | 10 +---- .github/workflows/dependency-review.yml | 34 ++++++++++++++ .gitignore | 6 +++ banyand/tsdb/block.go | 51 ++++----------------- banyand/tsdb/bucket/bucket.go | 32 ++++++++----- banyand/tsdb/bucket/queue.go | 10 +++-- banyand/tsdb/bucket/queue_test.go | 15 +++++-- banyand/tsdb/bucket/strategy.go | 36 +++++++++++---- banyand/tsdb/bucket/strategy_test.go | 4 +- banyand/tsdb/metric.go | 3 +- banyand/tsdb/retention.go | 27 ++++------- banyand/tsdb/segment.go | 40 +++++++++-------- banyand/tsdb/shard.go | 25 +++++++---- pkg/index/index.go | 2 - pkg/index/inverted/inverted.go | 5 --- pkg/index/lsm/lsm.go | 4 -- pkg/run/closer.go | 79 +++++++++++++++++++++++++++++++++ 17 files changed, 245 insertions(+), 138 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1b75422..70d0a8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -161,17 +161,9 @@ jobs: run: make test-ci - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 - dependency-review: - runs-on: ubuntu-latest - steps: - - name: 'Checkout Repository' - uses: actions/checkout@v3 - - name: 'Dependency Review' - uses: actions/dependency-review-action@v2 result: name: Continuous Integration runs-on: ubuntu-20.04 - needs: [check, build, test, dependency-review] + needs: [check, build, test] steps: - run: echo 'success' - diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml new file mode 100644 index 0000000..09f5be9 --- /dev/null +++ b/.github/workflows/dependency-review.yml @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dependency Review Action +# +# This Action will scan dependency manifest files that change as part of a Pull Request, surfacing known-vulnerable versions of the packages declared or updated in the PR. Once installed, if the workflow run is marked as required, PRs introducing known-vulnerable packages will be blocked from merging. 
+# +name: 'Dependency Review' +on: [pull_request] + +permissions: + contents: read + +jobs: + dependency-review: + runs-on: ubuntu-latest + steps: + - name: 'Checkout Repository' + uses: actions/checkout@v3 + - name: 'Dependency Review' + uses: actions/dependency-review-action@v2 diff --git a/.gitignore b/.gitignore index 32e3104..14ba169 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,9 @@ target # Test binary, build with `go test -c` *.test +# Ginkgo test report +*.report + # Output of the go coverage tool, specifically when used with LiteIDE *.out @@ -47,3 +50,6 @@ target # mock files *mock.go *mock_test.go + +# snky cache +.dccache diff --git a/banyand/tsdb/block.go b/banyand/tsdb/block.go index 8d050f5..9a813c0 100644 --- a/banyand/tsdb/block.go +++ b/banyand/tsdb/block.go @@ -72,13 +72,12 @@ type block struct { invertedIndex index.Store lsmIndex index.Store closableLst []io.Closer + clock timestamp.Clock timestamp.TimeRange bucket.Reporter segID uint16 blockID uint16 encodingMethod EncodingMethod - flushCh chan struct{} - stopCh chan struct{} } type blockOpts struct { @@ -104,7 +103,7 @@ func newBlock(ctx context.Context, opts blockOpts) (b *block, err error) { l: logger.Fetch(ctx, "block"), TimeRange: opts.timeRange, Reporter: bucket.NewTimeBasedReporter(opts.timeRange, clock), - flushCh: make(chan struct{}, 1), + clock: clock, ref: &atomic.Int32{}, closed: &atomic.Bool{}, deleted: &atomic.Bool{}, @@ -154,9 +153,6 @@ func (b *block) open() (err error) { kv.TSSWithEncoding(b.encodingMethod.EncoderPool, b.encodingMethod.DecoderPool), kv.TSSWithLogger(b.l.Named(componentMain)), kv.TSSWithMemTableSize(b.memSize), - kv.TSSWithFlushCallback(func() { - b.flushCh <- struct{}{} - }), ); err != nil { return err } @@ -174,20 +170,9 @@ func (b *block) open() (err error) { }); err != nil { return err } + b.Reporter = bucket.NewTimeBasedReporter(b.TimeRange, b.clock) b.closableLst = append(b.closableLst, b.invertedIndex, b.lsmIndex) b.ref.Store(0) - stopCh := 
make(chan struct{}) - b.stopCh = stopCh - go func() { - for { - select { - case <-b.flushCh: - b.flush() - case <-stopCh: - return - } - } - }() b.closed.Store(false) return nil } @@ -268,55 +253,35 @@ func (b *block) waitDone(stopped *atomic.Bool) <-chan struct{} { return ch } -func (b *block) flush() { - for i := 0; i < 10; i++ { - err := b.invertedIndex.Flush() - if err == nil { - break - } - time.Sleep(time.Second) - b.l.Warn().Err(err).Int("retried", i).Msg("failed to flush inverted index") - } -} - func (b *block) close(ctx context.Context) (err error) { b.lock.Lock() defer b.lock.Unlock() if b.closed.Load() { return nil } - b.closed.Store(true) stopWaiting := &atomic.Bool{} ch := b.waitDone(stopWaiting) select { case <-ctx.Done(): - b.closed.Store(false) stopWaiting.Store(true) return errors.Wrapf(ErrBlockClosingInterrupted, "block:%s", b) case <-ch: } + b.closed.Store(true) + if b.Reporter != nil { + b.Stop() + } for _, closer := range b.closableLst { err = multierr.Append(err, closer.Close()) } - close(b.stopCh) return err } -func (b *block) stopThenClose(ctx context.Context) error { - if b.Reporter != nil { - b.Stop() - } - return b.close(ctx) -} - func (b *block) delete(ctx context.Context) error { if b.deleted.Load() { return nil } b.deleted.Store(true) - if b.Reporter != nil { - b.Stop() - } b.close(ctx) return os.RemoveAll(b.path) } @@ -326,7 +291,7 @@ func (b *block) Closed() bool { } func (b *block) String() string { - return fmt.Sprintf("BlockID-%d-%d", parseSuffix(b.segID), parseSuffix(b.blockID)) + return fmt.Sprintf("BlockID-%d-%s", parseSuffix(b.segID), b.suffix) } func (b *block) stats() (names []string, stats []observability.Statistics) { diff --git a/banyand/tsdb/bucket/bucket.go b/banyand/tsdb/bucket/bucket.go index 5838c26..d0fe08a 100644 --- a/banyand/tsdb/bucket/bucket.go +++ b/banyand/tsdb/bucket/bucket.go @@ -18,11 +18,15 @@ package bucket import ( + "errors" "time" + "github.com/apache/skywalking-banyandb/pkg/run" 
"github.com/apache/skywalking-banyandb/pkg/timestamp" ) +var ErrReporterClosed = errors.New("reporter is closed") + type Controller interface { Current() (Reporter, error) Next() (Reporter, error) @@ -37,7 +41,7 @@ type Status struct { type Channel chan Status type Reporter interface { - Report() Channel + Report() (Channel, error) Stop() String() string } @@ -46,28 +50,34 @@ var _ Reporter = (*timeBasedReporter)(nil) type timeBasedReporter struct { timestamp.TimeRange - reporterStopCh chan struct{} - clock timestamp.Clock + clock timestamp.Clock + closer *run.Closer } func NewTimeBasedReporter(timeRange timestamp.TimeRange, clock timestamp.Clock) Reporter { if timeRange.End.Before(clock.Now()) { return nil } - return &timeBasedReporter{ - TimeRange: timeRange, - reporterStopCh: make(chan struct{}), - clock: clock, + t := &timeBasedReporter{ + TimeRange: timeRange, + clock: clock, + closer: run.NewCloser(0), } + return t } -func (tr *timeBasedReporter) Report() Channel { +func (tr *timeBasedReporter) Report() (Channel, error) { + if tr.closer.Closed() { + return nil, ErrReporterClosed + } ch := make(Channel, 1) interval := tr.Duration() >> 4 if interval < 100*time.Millisecond { interval = 100 * time.Millisecond } go func() { + tr.closer.AddRunning() + defer tr.closer.Done() defer close(ch) ticker := tr.clock.Ticker(interval) defer ticker.Stop() @@ -82,14 +92,14 @@ func (tr *timeBasedReporter) Report() Channel { if status.Volume >= status.Capacity { return } - case <-tr.reporterStopCh: + case <-tr.closer.CloseNotify(): return } } }() - return ch + return ch, nil } func (tr *timeBasedReporter) Stop() { - close(tr.reporterStopCh) + tr.closer.CloseThenWait() } diff --git a/banyand/tsdb/bucket/queue.go b/banyand/tsdb/bucket/queue.go index 0bcf49a..04f782b 100644 --- a/banyand/tsdb/bucket/queue.go +++ b/banyand/tsdb/bucket/queue.go @@ -28,6 +28,7 @@ import ( "github.com/robfig/cron/v3" "github.com/apache/skywalking-banyandb/pkg/logger" + 
"github.com/apache/skywalking-banyandb/pkg/run" "github.com/apache/skywalking-banyandb/pkg/timestamp" ) @@ -66,7 +67,7 @@ type lruQueue struct { recentEvict simplelru.LRUCache lock sync.RWMutex - stopCh chan struct{} + closer *run.Closer } func NewQueue(logger *logger.Logger, size int, maxSize int, clock timestamp.Clock, evictFn EvictFn) (Queue, error) { @@ -98,7 +99,7 @@ func NewQueue(logger *logger.Logger, size int, maxSize int, clock timestamp.Cloc evictSize: evictSize, evictFn: evictFn, l: logger, - stopCh: make(chan struct{}), + closer: run.NewCloser(1), } parser := cron.NewParser(cron.Second) // every 60 seconds to clean up recentEvict @@ -107,6 +108,7 @@ func NewQueue(logger *logger.Logger, size int, maxSize int, clock timestamp.Cloc return nil, err } go func() { + defer c.closer.Done() now := clock.Now() for { next := scheduler.Next(now) @@ -133,7 +135,7 @@ func NewQueue(logger *logger.Logger, size int, maxSize int, clock timestamp.Cloc cancel() c.lock.Unlock() } - case <-c.stopCh: + case <-c.closer.CloseNotify(): c.l.Info().Msg("stop") timer.Stop() return @@ -281,6 +283,6 @@ func (q *lruQueue) removeOldest(ctx context.Context, lst simplelru.LRUCache) err } func (q *lruQueue) Close() error { - close(q.stopCh) + q.closer.CloseThenWait() return nil } diff --git a/banyand/tsdb/bucket/queue_test.go b/banyand/tsdb/bucket/queue_test.go index f256358..8e16306 100644 --- a/banyand/tsdb/bucket/queue_test.go +++ b/banyand/tsdb/bucket/queue_test.go @@ -19,6 +19,7 @@ package bucket_test import ( "context" "strconv" + "sync" "time" . 
"github.com/onsi/ginkgo/v2" @@ -27,6 +28,7 @@ import ( "github.com/apache/skywalking-banyandb/banyand/tsdb/bucket" "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/test/flags" "github.com/apache/skywalking-banyandb/pkg/timestamp" ) @@ -47,6 +49,7 @@ func entryID(id uint16) queueEntryID { } var _ = Describe("Queue", func() { + var lock sync.Mutex var evictLst []queueEntryID var l bucket.Queue var clock timestamp.MockClock @@ -60,6 +63,8 @@ var _ = Describe("Queue", func() { clock.Set(time.Date(1970, 0o1, 0o1, 0, 0, 0, 0, time.Local)) var err error l, err = bucket.NewQueue(logger.GetLogger("test"), 128, 192, clock, func(_ context.Context, id interface{}) error { + lock.Lock() + defer lock.Unlock() evictLst = append(evictLst, id.(queueEntryID)) return nil }) @@ -126,8 +131,12 @@ var _ = Describe("Queue", func() { Expect(enRecentSize).To(Equal(192)) Expect(l.Len()).To(Equal(128)) Expect(len(evictLst)).To(Equal(0)) - clock.Add(time.Minute) - GinkgoWriter.Printf("evicted size:%d \n", len(evictLst)) - Expect(len(evictLst)).To(BeNumerically(">", 1)) + Eventually(func() int { + clock.Add(time.Minute) + clock.TriggerTimer() + lock.Lock() + defer lock.Unlock() + return len(evictLst) + }).WithTimeout(flags.EventuallyTimeout).Should(BeNumerically(">", 1)) }) }) diff --git a/banyand/tsdb/bucket/strategy.go b/banyand/tsdb/bucket/strategy.go index 6bf51f2..bc09583 100644 --- a/banyand/tsdb/bucket/strategy.go +++ b/banyand/tsdb/bucket/strategy.go @@ -26,6 +26,7 @@ import ( "go.uber.org/multierr" "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/run" ) var ( @@ -42,7 +43,8 @@ type Strategy struct { current atomic.Value currentRatio uint64 logger *logger.Logger - stopCh chan struct{} + + closer *run.Closer } type StrategyOptions func(*Strategy) @@ -71,7 +73,7 @@ func NewStrategy(ctrl Controller, options ...StrategyOptions) (*Strategy, error) strategy := &Strategy{ ctrl: ctrl, ratio: 0.8, - 
stopCh: make(chan struct{}), + closer: run.NewCloser(1), } for _, opt := range options { opt(strategy) @@ -82,18 +84,36 @@ func NewStrategy(ctrl Controller, options ...StrategyOptions) (*Strategy, error) if strategy.logger == nil { strategy.logger = logger.GetLogger("bucket-strategy") } - c, err := ctrl.Current() - if err != nil { + if err := strategy.resetCurrent(); err != nil { return nil, err } - strategy.current.Store(c) return strategy, nil } +func (s *Strategy) resetCurrent() error { + c, err := s.ctrl.Current() + if err != nil { + return err + } + s.current.Store(c) + return nil +} + func (s *Strategy) Run() { go func(s *Strategy) { + defer s.closer.Done() for { - c := s.current.Load().(Reporter).Report() + c, err := s.current.Load().(Reporter).Report() + if errors.Is(err, ErrReporterClosed) { + return + } + if err != nil { + s.logger.Error().Err(err).Msg("failed to get reporter") + if err := s.resetCurrent(); err != nil { + panic(err) + } + continue + } if !s.observe(c) { return } @@ -138,12 +158,12 @@ func (s *Strategy) observe(c Channel) bool { } return moreBucket } - case <-s.stopCh: + case <-s.closer.CloseNotify(): return false } } } func (s *Strategy) Close() { - close(s.stopCh) + s.closer.CloseThenWait() } diff --git a/banyand/tsdb/bucket/strategy_test.go b/banyand/tsdb/bucket/strategy_test.go index 00f6ef5..64adbda 100644 --- a/banyand/tsdb/bucket/strategy_test.go +++ b/banyand/tsdb/bucket/strategy_test.go @@ -133,7 +133,7 @@ type reporter struct { step int } -func (r *reporter) Report() bucket.Channel { +func (r *reporter) Report() (bucket.Channel, error) { ch := make(bucket.Channel, r.capacity) go func() { var volume int @@ -146,7 +146,7 @@ func (r *reporter) Report() bucket.Channel { } close(ch) }() - return ch + return ch, nil } func (r *reporter) Stop() { diff --git a/banyand/tsdb/metric.go b/banyand/tsdb/metric.go index a97cddf..696e019 100644 --- a/banyand/tsdb/metric.go +++ b/banyand/tsdb/metric.go @@ -52,13 +52,14 @@ func init() { func (s 
*shard) runStat() { go func() { + defer s.closer.Done() ticker := time.NewTicker(statInterval) defer ticker.Stop() for { select { case <-ticker.C: s.stat() - case <-s.stopCh: + case <-s.closer.CloseNotify(): return } } diff --git a/banyand/tsdb/retention.go b/banyand/tsdb/retention.go index 755bf41..78f7747 100644 --- a/banyand/tsdb/retention.go +++ b/banyand/tsdb/retention.go @@ -19,22 +19,21 @@ package tsdb import ( "context" - "sync" "time" "github.com/robfig/cron/v3" "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/run" ) type retentionController struct { segment *segmentController scheduler cron.Schedule - stopped bool - stopMux sync.Mutex - stopCh chan struct{} duration time.Duration - l *logger.Logger + + closer *run.Closer + l *logger.Logger } func newRetentionController(segment *segmentController, ttl IntervalRule) (*retentionController, error) { @@ -56,22 +55,18 @@ func newRetentionController(segment *segmentController, ttl IntervalRule) (*rete return &retentionController{ segment: segment, scheduler: scheduler, - stopCh: make(chan struct{}), l: segment.l.Named("retention-controller"), duration: ttl.EstimatedDuration(), + closer: run.NewCloser(1), }, nil } func (rc *retentionController) start() { - rc.stopMux.Lock() - if rc.stopped { - return - } - rc.stopMux.Unlock() go rc.run() } func (rc *retentionController) run() { + defer rc.closer.Done() rc.l.Info().Msg("start") now := rc.segment.clock.Now() for { @@ -85,7 +80,7 @@ func (rc *retentionController) run() { rc.l.Error().Err(err) } cancel() - case <-rc.stopCh: + case <-rc.closer.CloseNotify(): timer.Stop() rc.l.Info().Msg("stop") return @@ -94,11 +89,5 @@ func (rc *retentionController) run() { } func (rc *retentionController) stop() { - rc.stopMux.Lock() - defer rc.stopMux.Unlock() - if rc.stopped { - return - } - rc.stopped = true - close(rc.stopCh) + rc.closer.CloseThenWait() } diff --git a/banyand/tsdb/segment.go b/banyand/tsdb/segment.go index 
6d08e2f..abfcd28 100644 --- a/banyand/tsdb/segment.go +++ b/banyand/tsdb/segment.go @@ -20,7 +20,6 @@ package tsdb import ( "context" "errors" - "fmt" "os" "sort" "strconv" @@ -50,6 +49,7 @@ type segment struct { bucket.Reporter blockController *blockController blockManageStrategy *bucket.Strategy + closeOnce sync.Once } func openSegment(ctx context.Context, startTime time.Time, path, suffix string, @@ -111,21 +111,23 @@ func openSegment(ctx context.Context, startTime time.Time, path, suffix string, return s, nil } -func (s *segment) close(ctx context.Context) error { - if err := s.blockController.close(ctx); err != nil { - return err - } - if s.globalIndex != nil { - if err := s.globalIndex.Close(); err != nil { - return err +func (s *segment) close(ctx context.Context) (err error) { + s.closeOnce.Do(func() { + if err = s.blockController.close(ctx); err != nil { + return } - } - if s.blockManageStrategy != nil { - s.blockManageStrategy.Close() - } - if s.Reporter != nil { - s.Stop() - } + if s.globalIndex != nil { + if err = s.globalIndex.Close(); err != nil { + return + } + } + if s.blockManageStrategy != nil { + s.blockManageStrategy.Close() + } + if s.Reporter != nil { + s.Stop() + } + }) return nil } @@ -140,8 +142,8 @@ func (s *segment) delete(ctx context.Context) error { return os.RemoveAll(s.path) } -func (s segment) String() string { - return fmt.Sprintf("SegID-%d", parseSuffix(s.id)) +func (s *segment) String() string { + return "SegID-" + s.suffix } func (s *segment) Stats() observability.Statistics { @@ -431,8 +433,10 @@ func (bc *blockController) sortLst() { } func (bc *blockController) close(ctx context.Context) (err error) { + bc.Lock() + defer bc.Unlock() for _, s := range bc.lst { - err = multierr.Append(err, s.stopThenClose(ctx)) + err = multierr.Append(err, s.close(ctx)) } return err } diff --git a/banyand/tsdb/shard.go b/banyand/tsdb/shard.go index b77a232..93c9824 100644 --- a/banyand/tsdb/shard.go +++ b/banyand/tsdb/shard.go @@ -30,6 +30,7 @@ 
import ( "github.com/apache/skywalking-banyandb/api/common" "github.com/apache/skywalking-banyandb/banyand/tsdb/bucket" "github.com/apache/skywalking-banyandb/pkg/logger" + "github.com/apache/skywalking-banyandb/pkg/run" "github.com/apache/skywalking-banyandb/pkg/timestamp" ) @@ -51,7 +52,9 @@ type shard struct { segmentController *segmentController segmentManageStrategy *bucket.Strategy retentionController *retentionController - stopCh chan struct{} + + closeOnce sync.Once + closer *run.Closer } func OpenShard(ctx context.Context, id common.ShardID, @@ -79,7 +82,7 @@ func OpenShard(ctx context.Context, id common.ShardID, id: id, segmentController: sc, l: l, - stopCh: make(chan struct{}), + closer: run.NewCloser(1), } err = s.segmentController.open() if err != nil { @@ -162,13 +165,15 @@ func (s *shard) State() (shardState ShardState) { return shardState } -func (s *shard) Close() error { - s.retentionController.stop() - s.segmentManageStrategy.Close() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - err := multierr.Combine(s.segmentController.close(ctx), s.seriesDatabase.Close()) - close(s.stopCh) +func (s *shard) Close() (err error) { + s.closeOnce.Do(func() { + s.closer.CloseThenWait() + s.retentionController.stop() + s.segmentManageStrategy.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + err = multierr.Combine(s.segmentController.close(ctx), s.seriesDatabase.Close()) + }) return err } @@ -445,6 +450,8 @@ func (sc *segmentController) removeSeg(segID uint16) { } func (sc *segmentController) close(ctx context.Context) (err error) { + sc.Lock() + defer sc.Unlock() for _, s := range sc.lst { err = multierr.Append(err, s.close(ctx)) } diff --git a/pkg/index/index.go b/pkg/index/index.go index 6f490cd..6d06790 100644 --- a/pkg/index/index.go +++ b/pkg/index/index.go @@ -166,8 +166,6 @@ type Store interface { io.Closer Writer Searcher - // Flush flushed memory data to disk 
- Flush() error } type GetSearcher func(location databasev1.IndexRule_Type) (Searcher, error) diff --git a/pkg/index/inverted/inverted.go b/pkg/index/inverted/inverted.go index 5ce8878..4f680b8 100644 --- a/pkg/index/inverted/inverted.go +++ b/pkg/index/inverted/inverted.go @@ -204,11 +204,6 @@ func (s *store) Range(fieldKey index.FieldKey, opts index.RangeOpts) (list posti return } -// Flush flushed memory data to disk -func (s *store) Flush() error { - return nil -} - type blugeMatchIterator struct { delegated search.DocumentMatchIterator fieldKey string diff --git a/pkg/index/lsm/lsm.go b/pkg/index/lsm/lsm.go index 8c366ad..0a26c26 100644 --- a/pkg/index/lsm/lsm.go +++ b/pkg/index/lsm/lsm.go @@ -35,10 +35,6 @@ type store struct { l *logger.Logger } -func (*store) Flush() error { - panic("do not call flush here. LSM index is using its own controller to flush memory data") -} - func (s *store) Stats() observability.Statistics { return s.lsm.Stats() } diff --git a/pkg/run/closer.go b/pkg/run/closer.go new file mode 100644 index 0000000..cc70e53 --- /dev/null +++ b/pkg/run/closer.go @@ -0,0 +1,79 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package run + +import ( + "context" + "sync" + "sync/atomic" +) + +// Closer signals goroutines to stop, then waits for them to finish. +type Closer struct { + waiting sync.WaitGroup + closed *atomic.Bool + + ctx context.Context + cancel context.CancelFunc +} + +// NewCloser creates a Closer that starts with initial running tasks already registered. +func NewCloser(initial int) *Closer { + c := &Closer{} + c.ctx, c.cancel = context.WithCancel(context.Background()) + c.closed = &atomic.Bool{} + c.waiting.Add(initial) + return c +} + +// AddRunning registers one more running task; call it before starting the task's goroutine (not from inside it) so a concurrent Wait cannot miss the task. +func (c *Closer) AddRunning() { + c.waiting.Add(1) +} + +// Close marks the Closer as closed and closes the channel returned by CloseNotify. +func (c *Closer) Close() { + c.closed.Store(true) + c.cancel() +} + +// CloseNotify returns a channel that is closed once Close has been called. +func (c *Closer) CloseNotify() <-chan struct{} { + return c.ctx.Done() +} + +// Done marks one running task as finished. +func (c *Closer) Done() { + c.waiting.Done() +} + +// Wait blocks until every registered task has called Done. +func (c *Closer) Wait() { + c.waiting.Wait() +} + +// CloseThenWait calls Close, then Wait. +func (c *Closer) CloseThenWait() { + c.Close() + c.Wait() +} + +// Closed reports whether Close has been called. +func (c *Closer) Closed() bool { + return c.closed.Load() +}
