This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch storage-column
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/storage-column by this push:
     new b6594cea Add TSTable
b6594cea is described below

commit b6594ceabf5becfc35257d546cd9f13b3f60361e
Author: Gao Hongtao <[email protected]>
AuthorDate: Tue Nov 21 08:46:39 2023 +0000

    Add TSTable
    
    Signed-off-by: Gao Hongtao <[email protected]>
---
 banyand/internal/storage/retention.go |  58 ++++++
 banyand/internal/storage/segment.go   | 373 ++++++++++++++++++++++++++++++++++
 banyand/internal/storage/shard.go     |  84 ++++++++
 banyand/internal/storage/storage.go   |  27 +--
 banyand/internal/storage/tsdb.go      | 217 ++++++++++----------
 pkg/fs/file_system.go                 |  19 +-
 pkg/fs/local_file_system.go           |  28 ++-
 7 files changed, 682 insertions(+), 124 deletions(-)

diff --git a/banyand/internal/storage/retention.go 
b/banyand/internal/storage/retention.go
new file mode 100644
index 00000000..9d9e8c65
--- /dev/null
+++ b/banyand/internal/storage/retention.go
@@ -0,0 +1,58 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package storage
+
+import (
+       "time"
+
+       "github.com/robfig/cron/v3"
+
+       "github.com/apache/skywalking-banyandb/pkg/logger"
+)
+
+type retentionTask[T TSTable[T]] struct {
+       segment  *segmentController[T]
+       expr     string
+       option   cron.ParseOption
+       duration time.Duration
+}
+
+func newRetentionTask[T TSTable[T]](segment *segmentController[T], ttl 
IntervalRule) *retentionTask[T] {
+       var expr string
+       switch ttl.Unit {
+       case HOUR:
+               // Every hour on the 5th minute
+               expr = "5 *"
+       case DAY:
+               // Every day on 00:05
+               expr = "5 0"
+       }
+       return &retentionTask[T]{
+               segment:  segment,
+               option:   cron.Minute | cron.Hour,
+               expr:     expr,
+               duration: ttl.estimatedDuration(),
+       }
+}
+
+func (rc *retentionTask[T]) run(now time.Time, l *logger.Logger) bool {
+       if err := rc.segment.remove(now.Add(-rc.duration)); err != nil {
+               l.Error().Err(err)
+       }
+       return true
+}
diff --git a/banyand/internal/storage/segment.go 
b/banyand/internal/storage/segment.go
new file mode 100644
index 00000000..639b2f4a
--- /dev/null
+++ b/banyand/internal/storage/segment.go
@@ -0,0 +1,373 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package storage
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "path"
+       "sort"
+       "strconv"
+       "sync"
+       "time"
+
+       "go.uber.org/multierr"
+
+       "github.com/apache/skywalking-banyandb/api/common"
+       "github.com/apache/skywalking-banyandb/banyand/tsdb/bucket"
+       "github.com/apache/skywalking-banyandb/pkg/index"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/run"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
+)
+
+var errEndOfSegment = errors.New("reached the end of the segment")
+
+type segment[T TSTable[T]] struct {
+       bucket.Reporter
+       invertedIndex index.Store
+       lsmIndex      index.Store
+       tsTable       TSTable[T]
+       l             *logger.Logger
+       closer        *run.Closer
+       position      common.Position
+       timestamp.TimeRange
+       path      string
+       suffix    string
+       closeOnce sync.Once
+       id        SegID
+}
+
+func openSegment[T TSTable[T]](ctx context.Context, startTime, endTime 
time.Time, path, suffix string,
+       segmentSize IntervalRule, scheduler *timestamp.Scheduler, tsTable 
TSTable[T],
+) (s *segment[T], err error) {
+       suffixInteger, err := strconv.Atoi(suffix)
+       if err != nil {
+               return nil, err
+       }
+       id := GenerateSegID(segmentSize.Unit, suffixInteger)
+       timeRange := timestamp.NewSectionTimeRange(startTime, endTime)
+       s = &segment[T]{
+               id:        id,
+               path:      path,
+               suffix:    suffix,
+               TimeRange: timeRange,
+               position:  common.GetPosition(ctx),
+               tsTable:   tsTable,
+               closer:    run.NewCloser(1),
+       }
+       l := logger.Fetch(ctx, s.String())
+       s.l = l
+       clock, _ := timestamp.GetClock(ctx)
+       s.Reporter = bucket.NewTimeBasedReporter(s.String(), timeRange, clock, 
scheduler)
+       return s, nil
+}
+
+func (s *segment[T]) close() {
+       s.closeOnce.Do(func() {
+               s.closer.Done()
+               s.closer.CloseThenWait()
+               _ = s.tsTable.Close()
+       })
+}
+
+func (s *segment[T]) delete() error {
+       s.close()
+       return lfs.DeleteFile(s.path)
+}
+
+func (s *segment[T]) String() string {
+       return "SegID-" + s.suffix
+}
+
+type segmentController[T TSTable[T]] struct {
+       clock          timestamp.Clock
+       scheduler      *timestamp.Scheduler
+       l              *logger.Logger
+       tsTableCreator TSTableCreator[T]
+       position       common.Position
+       location       string
+       lst            []*segment[T]
+       segmentSize    IntervalRule
+       sync.RWMutex
+}
+
+func newSegmentController[T TSTable[T]](ctx context.Context, location string,
+       segmentSize IntervalRule, l *logger.Logger, scheduler 
*timestamp.Scheduler,
+       tsTableCreator TSTableCreator[T],
+) *segmentController[T] {
+       clock, _ := timestamp.GetClock(ctx)
+       return &segmentController[T]{
+               location:       location,
+               segmentSize:    segmentSize,
+               l:              l,
+               clock:          clock,
+               scheduler:      scheduler,
+               position:       common.GetPosition(ctx),
+               tsTableCreator: tsTableCreator,
+       }
+}
+
+func (sc *segmentController[T]) selectTSTables(timeRange timestamp.TimeRange) 
(tt []TSTable[T]) {
+       lst := sc.segments()
+       last := len(lst) - 1
+       for i := range lst {
+               s := lst[last-i]
+               if s.Overlapping(timeRange) && s.closer.AddRunning() {
+                       tt = append(tt, s.tsTable)
+               }
+       }
+       return tt
+}
+
+func (sc *segmentController[T]) createTSTable(ts time.Time) (TSTable[T], 
error) {
+       s, err := sc.create(ts)
+       if err != nil {
+               return nil, err
+       }
+       if s.closer.AddRunning() {
+               return s.tsTable, nil
+       }
+       return nil, errors.New("segmentController is closed")
+}
+
+func (sc *segmentController[T]) put(tsTables ...TSTable[T]) {
+       lst := sc.segments()
+       for _, t := range tsTables {
+               for _, s := range lst {
+                       if s.tsTable == t {
+                               s.closer.Done()
+                       }
+               }
+       }
+}
+
+func (sc *segmentController[T]) segments() (ss []*segment[T]) {
+       sc.RLock()
+       defer sc.RUnlock()
+       r := make([]*segment[T], len(sc.lst))
+       copy(r, sc.lst)
+       return r
+}
+
+func (sc *segmentController[T]) Current() (bucket.Reporter, error) {
+       now := sc.Standard(sc.clock.Now())
+       ns := uint64(now.UnixNano())
+       if b := func() bucket.Reporter {
+               sc.RLock()
+               defer sc.RUnlock()
+               for _, s := range sc.lst {
+                       if s.Contains(ns) {
+                               return s
+                       }
+               }
+               return nil
+       }(); b != nil {
+               return b, nil
+       }
+       return sc.create(now)
+}
+
+func (sc *segmentController[T]) Next() (bucket.Reporter, error) {
+       c, err := sc.Current()
+       if err != nil {
+               return nil, err
+       }
+       seg := c.(*segment[T])
+       reporter, err := sc.create(sc.segmentSize.nextTime(seg.Start))
+       if errors.Is(err, errEndOfSegment) {
+               return nil, bucket.ErrNoMoreBucket
+       }
+       return reporter, err
+}
+
+func (sc *segmentController[T]) OnMove(prev bucket.Reporter, next 
bucket.Reporter) {
+       event := sc.l.Info()
+       if prev != nil {
+               event.Stringer("prev", prev)
+       }
+       if next != nil {
+               event.Stringer("next", next)
+       }
+       event.Msg("move to the next segment")
+}
+
+func (sc *segmentController[T]) Standard(t time.Time) time.Time {
+       switch sc.segmentSize.Unit {
+       case HOUR:
+               return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), 0, 0, 
0, t.Location())
+       case DAY:
+               return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, 
t.Location())
+       }
+       panic("invalid interval unit")
+}
+
+func (sc *segmentController[T]) Format(tm time.Time) string {
+       switch sc.segmentSize.Unit {
+       case HOUR:
+               return tm.Format(hourFormat)
+       case DAY:
+               return tm.Format(dayFormat)
+       }
+       panic("invalid interval unit")
+}
+
+func (sc *segmentController[T]) Parse(value string) (time.Time, error) {
+       switch sc.segmentSize.Unit {
+       case HOUR:
+               return time.ParseInLocation(hourFormat, value, time.Local)
+       case DAY:
+               return time.ParseInLocation(dayFormat, value, time.Local)
+       }
+       panic("invalid interval unit")
+}
+
+func (sc *segmentController[T]) open() error {
+       sc.Lock()
+       defer sc.Unlock()
+       return loadSegments(sc.location, segPathPrefix, sc, sc.segmentSize, 
func(start, end time.Time) error {
+               _, err := sc.load(start, end, sc.location)
+               if errors.Is(err, errEndOfSegment) {
+                       return nil
+               }
+               return err
+       })
+}
+
+func (sc *segmentController[T]) create(start time.Time) (*segment[T], error) {
+       sc.Lock()
+       defer sc.Unlock()
+       start = sc.Standard(start)
+       var next *segment[T]
+       for _, s := range sc.lst {
+               if s.Contains(uint64(start.UnixNano())) {
+                       return s, nil
+               }
+               if next == nil && s.Start.After(start) {
+                       next = s
+               }
+       }
+       stdEnd := sc.segmentSize.nextTime(start)
+       var end time.Time
+       if next != nil && next.Start.Before(stdEnd) {
+               end = next.Start
+       } else {
+               end = stdEnd
+       }
+       lfs.MkdirPanicIfExist(path.Join(sc.location, fmt.Sprintf(segTemplate, 
sc.Format(start))), dirPerm)
+       return sc.load(start, end, sc.location)
+}
+
+func (sc *segmentController[T]) sortLst() {
+       sort.Slice(sc.lst, func(i, j int) bool {
+               return sc.lst[i].id < sc.lst[j].id
+       })
+}
+
+func (sc *segmentController[T]) load(start, end time.Time, root string) (seg 
*segment[T], err error) {
+       var tsTable TSTable[T]
+       if tsTable, err = sc.tsTableCreator(sc.location, sc.position, sc.l, 
timestamp.NewSectionTimeRange(start, end)); err != nil {
+               return nil, err
+       }
+       suffix := sc.Format(start)
+       ctx := context.WithValue(context.Background(), logger.ContextKey, sc.l)
+       seg, err = openSegment[T](common.SetPosition(ctx, func(p 
common.Position) common.Position {
+               p.Segment = suffix
+               return p
+       }), start, end, path.Join(root, fmt.Sprintf(segTemplate, suffix)), 
suffix, sc.segmentSize, sc.scheduler, tsTable)
+       if err != nil {
+               return nil, err
+       }
+       sc.lst = append(sc.lst, seg)
+       sc.sortLst()
+       return seg, nil
+}
+
// remove deletes every segment that ends before deadline, accumulating
// deletion failures into a single multierr result. It iterates over a
// snapshot of the segment list (sc.segments copies under RLock), then takes
// the write lock only around each individual delete + list removal.
func (sc *segmentController[T]) remove(deadline time.Time) (err error) {
	sc.l.Info().Time("deadline", deadline).Msg("start to remove before deadline")
	for _, s := range sc.segments() {
		if s.End.Before(deadline) || s.Contains(uint64(deadline.UnixNano())) {
			if e := sc.l.Debug(); e.Enabled() {
				e.Stringer("segment", s).Msg("start to remove data in a segment")
			}
			// Only fully expired segments are deleted. A segment that merely
			// contains the deadline is logged but left intact — presumably a
			// placeholder for partial in-segment removal; TODO confirm.
			if s.End.Before(deadline) {
				sc.Lock()
				if errDel := s.delete(); errDel != nil {
					err = multierr.Append(err, errDel)
				} else {
					// Drop it from sc.lst only after the on-disk delete
					// succeeded, so a failed delete stays visible.
					sc.removeSeg(s.id)
				}
				sc.Unlock()
			}
		}
	}
	return err
}
+
+func (sc *segmentController[T]) removeSeg(segID SegID) {
+       for i, b := range sc.lst {
+               if b.id == segID {
+                       sc.lst = append(sc.lst[:i], sc.lst[i+1:]...)
+                       break
+               }
+       }
+}
+
+func (sc *segmentController[T]) close() {
+       sc.Lock()
+       defer sc.Unlock()
+       for _, s := range sc.lst {
+               s.close()
+       }
+       sc.lst = sc.lst[:0]
+}
+
+type parser interface {
+       Parse(value string) (time.Time, error)
+}
+
+func loadSegments(root, prefix string, parser parser, intervalRule 
IntervalRule, loadFn func(start, end time.Time) error) error {
+       var startTimeLst []time.Time
+       if err := walkDir(
+               root,
+               prefix,
+               func(suffix string) error {
+                       startTime, err := parser.Parse(suffix)
+                       if err != nil {
+                               return err
+                       }
+                       startTimeLst = append(startTimeLst, startTime)
+                       return nil
+               }); err != nil {
+               return err
+       }
+       sort.Slice(startTimeLst, func(i, j int) bool { return i < j })
+       for i, start := range startTimeLst {
+               var end time.Time
+               if i < len(startTimeLst)-1 {
+                       end = startTimeLst[i+1]
+               } else {
+                       end = intervalRule.nextTime(start)
+               }
+               if err := loadFn(start, end); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
diff --git a/banyand/internal/storage/shard.go 
b/banyand/internal/storage/shard.go
new file mode 100644
index 00000000..5cd59920
--- /dev/null
+++ b/banyand/internal/storage/shard.go
@@ -0,0 +1,84 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package storage
+
+import (
+       "context"
+       "fmt"
+       "path"
+       "strconv"
+       "sync"
+
+       "github.com/apache/skywalking-banyandb/api/common"
+       "github.com/apache/skywalking-banyandb/banyand/tsdb/bucket"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
+)
+
+type shard[T TSTable[T]] struct {
+       l                     *logger.Logger
+       segmentController     *segmentController[T]
+       segmentManageStrategy *bucket.Strategy
+       scheduler             *timestamp.Scheduler
+       position              common.Position
+       closeOnce             sync.Once
+       id                    common.ShardID
+}
+
+func (d *database[T]) openShard(ctx context.Context, id common.ShardID) 
(*shard[T], error) {
+       location := path.Join(d.location, fmt.Sprintf(shardTemplate, int(id)))
+       lfs.MkdirIfNotExist(location, dirPerm)
+       l := logger.Fetch(ctx, "shard"+strconv.Itoa(int(id)))
+       l.Info().Int("shard_id", int(id)).Str("path", location).Msg("creating a 
shard")
+       shardCtx := context.WithValue(ctx, logger.ContextKey, l)
+       shardCtx = common.SetPosition(shardCtx, func(p common.Position) 
common.Position {
+               p.Shard = strconv.Itoa(int(id))
+               return p
+       })
+       clock, _ := timestamp.GetClock(shardCtx)
+
+       scheduler := timestamp.NewScheduler(l, clock)
+       s := &shard[T]{
+               id:                id,
+               l:                 l,
+               scheduler:         scheduler,
+               position:          common.GetPosition(shardCtx),
+               segmentController: newSegmentController[T](shardCtx, location, 
d.opts.SegmentInterval, l, scheduler, d.opts.TSTableCreator),
+       }
+       var err error
+       if err = s.segmentController.open(); err != nil {
+               return nil, err
+       }
+       if s.segmentManageStrategy, err = 
bucket.NewStrategy(s.segmentController, bucket.WithLogger(s.l)); err != nil {
+               return nil, err
+       }
+       s.segmentManageStrategy.Run()
+       retentionTask := newRetentionTask(s.segmentController, d.opts.TTL)
+       if err := scheduler.Register("retention", retentionTask.option, 
retentionTask.expr, retentionTask.run); err != nil {
+               return nil, err
+       }
+       return s, nil
+}
+
+func (s *shard[T]) closer() {
+       s.closeOnce.Do(func() {
+               s.scheduler.Close()
+               s.segmentManageStrategy.Close()
+               s.segmentController.close()
+       })
+}
diff --git a/banyand/internal/storage/storage.go 
b/banyand/internal/storage/storage.go
index ff0c151a..5a47be6b 100644
--- a/banyand/internal/storage/storage.go
+++ b/banyand/internal/storage/storage.go
@@ -30,6 +30,9 @@ import (
        "github.com/pkg/errors"
 
        "github.com/apache/skywalking-banyandb/api/common"
+       "github.com/apache/skywalking-banyandb/pkg/fs"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
 )
 
 const (
@@ -37,6 +40,7 @@ const (
        shardTemplate   = shardPathPrefix + "-%d"
        metadataPath    = "metadata"
        segTemplate     = "seg-%s"
+       segPathPrefix   = "seg"
 
        hourFormat = "2006010215"
        dayFormat  = "20060102"
@@ -48,28 +52,27 @@ var (
        // ErrUnknownShard indicates that the shard is not found.
        ErrUnknownShard = errors.New("unknown shard")
        errOpenDatabase = errors.New("fails to open the database")
+
+       lfs = fs.NewLocalFileSystemWithLogger(logger.GetLogger("storage"))
 )
 
 // Supplier allows getting a tsdb's runtime.
-type Supplier interface {
-       SupplyTSDB() Database
-}
+type SupplyTSDB[T TSTable[T]] func() TSDB[T]
 
-// Database allows listing and getting shard details.
-type Database interface {
+// TSDB allows listing and getting shard details.
+type TSDB[T TSTable[T]] interface {
        io.Closer
-       CreateShardsAndGetByID(id common.ShardID) (Shard, error)
-       Shards() []Shard
-       Shard(id common.ShardID) (Shard, error)
 }
 
-// Shard allows accessing data of tsdb.
-type Shard interface {
+// TSTable is time series table.
+type TSTable[T any] interface {
        io.Closer
-       ID() common.ShardID
-       // Series() SeriesDatabase
 }
 
+// TSTableCreator creates a TSTable.
+type TSTableCreator[T TSTable[T]] func(root string, position common.Position,
+       l *logger.Logger, timeRange timestamp.TimeRange) (T, error)
+
 // IntervalUnit denotes the unit of a time point.
 type IntervalUnit int
 
diff --git a/banyand/internal/storage/tsdb.go b/banyand/internal/storage/tsdb.go
index aab9c8c1..6bb8daf9 100644
--- a/banyand/internal/storage/tsdb.go
+++ b/banyand/internal/storage/tsdb.go
@@ -26,16 +26,17 @@ package storage
 import (
        "context"
        "path/filepath"
+       "strconv"
+       "strings"
        "sync"
        "sync/atomic"
+       "time"
 
        "github.com/pkg/errors"
-       "go.uber.org/multierr"
 
        "github.com/apache/skywalking-banyandb/api/common"
-       "github.com/apache/skywalking-banyandb/pkg/convert"
-       "github.com/apache/skywalking-banyandb/pkg/fs"
        "github.com/apache/skywalking-banyandb/pkg/logger"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
 )
 
 // IndexGranularity denotes the granularity of the local index.
@@ -47,15 +48,13 @@ const (
        IndexGranularitySeries
 )
 
-var _ Database = (*database)(nil)
-
-// DatabaseOpts wraps options to create a tsdb.
-type DatabaseOpts struct {
-       Location         string
-       SegmentInterval  IntervalRule
-       TTL              IntervalRule
-       IndexGranularity IndexGranularity
-       ShardNum         uint32
+// TSDBOpts wraps options to create a tsdb.
+type TSDBOpts[T TSTable[T]] struct {
+       TSTableCreator  TSTableCreator[T]
+       Location        string
+       SegmentInterval IntervalRule
+       TTL             IntervalRule
+       ShardNum        uint32
 }
 
 type (
@@ -67,87 +66,28 @@ func GenerateSegID(unit IntervalUnit, suffix int) SegID {
        return SegID(unit)<<31 | ((SegID(suffix) << 1) >> 1)
 }
 
-func parseSuffix(id SegID) int {
-       return int((id << 1) >> 1)
-}
-
-func segIDToBytes(id SegID) []byte {
-       return convert.Uint32ToBytes(uint32(id))
-}
-
-func readSegID(data []byte, offset int) (SegID, int) {
-       end := offset + 4
-       return SegID(convert.BytesToUint32(data[offset:end])), end
-}
-
-type database struct {
-       fileSystem  fs.FileSystem
-       logger      *logger.Logger
-       index       *seriesIndex
-       location    string
-       sLst        []Shard
-       segmentSize IntervalRule
-       ttl         IntervalRule
+type database[T TSTable[T]] struct {
+       logger   *logger.Logger
+       index    *seriesIndex
+       location string
+       sLst     []*shard[T]
+       opts     TSDBOpts[T]
        sync.RWMutex
-       shardNum           uint32
-       shardCreationState uint32
-}
-
-func (d *database) CreateShardsAndGetByID(id common.ShardID) (Shard, error) {
-       if atomic.LoadUint32(&d.shardCreationState) != 0 {
-               return d.shard(id)
-       }
-       d.Lock()
-       defer d.Unlock()
-       if atomic.LoadUint32(&d.shardCreationState) != 0 {
-               return d.shard(id)
-       }
-       loadedShardsNum := len(d.sLst)
-       if loadedShardsNum < int(d.shardNum) {
-               _, err := createDatabase(d, loadedShardsNum)
-               if err != nil {
-                       return nil, errors.WithMessage(err, "create the 
database failed")
-               }
-       }
-       atomic.StoreUint32(&d.shardCreationState, 1)
-       return d.shard(id)
-}
-
-func (d *database) Shards() []Shard {
-       d.RLock()
-       defer d.RUnlock()
-       return d.sLst
-}
-
-func (d *database) Shard(id common.ShardID) (Shard, error) {
-       d.RLock()
-       defer d.RUnlock()
-       return d.shard(id)
-}
-
-func (d *database) shard(id common.ShardID) (Shard, error) {
-       if uint(id) >= uint(len(d.sLst)) {
-               return nil, ErrUnknownShard
-       }
-       return d.sLst[id], nil
+       sLen uint32
 }
 
-func (d *database) Close() error {
+func (d *database[T]) Close() error {
        d.Lock()
        defer d.Unlock()
-       var err error
        for _, s := range d.sLst {
-               innerErr := s.Close()
-               if innerErr != nil {
-                       err = multierr.Append(err, innerErr)
-               }
+               s.closer()
        }
-       return err
+       return nil
 }
 
-// OpenDatabase returns a new tsdb runtime. This constructor will create a new 
database if it's absent,
+// OpenTSDB returns a new tsdb runtime. This constructor will create a new 
database if it's absent,
 // or load an existing one.
-func OpenDatabase(ctx context.Context, opts DatabaseOpts) (Database, error) {
+func OpenTSDB[T TSTable[T]](ctx context.Context, opts TSDBOpts[T]) (TSDB[T], 
error) {
        if opts.SegmentInterval.Num == 0 {
                return nil, errors.Wrap(errOpenDatabase, "segment interval is 
absent")
        }
@@ -155,38 +95,101 @@ func OpenDatabase(ctx context.Context, opts DatabaseOpts) 
(Database, error) {
                return nil, errors.Wrap(errOpenDatabase, "ttl is absent")
        }
        p := common.GetPosition(ctx)
-       l := logger.Fetch(ctx, p.Database)
-       fileSystem := fs.NewLocalFileSystemWithLogger(l)
-       path := filepath.Clean(opts.Location)
-       fileSystem.Mkdir(path, dirPerm)
-       si, err := newSeriesIndex(ctx, path)
+       location := filepath.Clean(opts.Location)
+       lfs.MkdirIfNotExist(location, dirPerm)
+       si, err := newSeriesIndex(ctx, location)
        if err != nil {
                return nil, errors.Wrap(errOpenDatabase, 
errors.WithMessage(err, "create series index failed").Error())
        }
-       db := &database{
-               location:    path,
-               shardNum:    opts.ShardNum,
-               logger:      logger.Fetch(ctx, p.Database),
-               segmentSize: opts.SegmentInterval,
-               ttl:         opts.TTL,
-               fileSystem:  fileSystem,
-               index:       si,
+       db := &database[T]{
+               location: location,
+               logger:   logger.Fetch(ctx, p.Database),
+               index:    si,
+               opts:     opts,
        }
        db.logger.Info().Str("path", opts.Location).Msg("initialized")
+       if err = db.loadDatabase(); err != nil {
+               return nil, errors.Wrap(errOpenDatabase, 
errors.WithMessage(err, "load database failed").Error())
+       }
        return db, nil
 }
 
-func createDatabase(db *database, startID int) (Database, error) {
+func (d *database[T]) Register(shardID common.ShardID, series *Series) 
(*Series, error) {
        var err error
-       for i := startID; i < int(db.shardNum); i++ {
-               db.logger.Info().Int("shard_id", i).Msg("creating a shard")
-               // so, errNewShard := OpenShard(common.ShardID(i),
-               //      db.location, db.segmentSize, db.blockSize, db.ttl, 
defaultBlockQueueSize, defaultMaxBlockQueueSize, db.enableWAL)
-               // if errNewShard != nil {
-               //      err = multierr.Append(err, errNewShard)
-               //      continue
-               // }
-               // db.sLst = append(db.sLst, so)
+       if series, err = d.index.createPrimary(series); err != nil {
+               return nil, err
+       }
+       id := int(shardID)
+       if id < int(atomic.LoadUint32(&d.sLen)) {
+               return series, nil
+       }
+       d.Lock()
+       defer d.Unlock()
+       if id < len(d.sLst) {
+               return series, nil
+       }
+       d.logger.Info().Int("shard_id", id).Msg("creating a shard")
+       if err = d.registerShard(id); err != nil {
+               return nil, err
+       }
+       return series, nil
+}
+
+func (d *database[T]) CreateTSTableIfNotExist(shardID common.ShardID, ts 
time.Time) (TSTable[T], error) {
+       timeRange := timestamp.NewInclusiveTimeRange(ts, ts)
+       ss := d.sLst[shardID].segmentController.selectTSTables(timeRange)
+       if len(ss) > 0 {
+               return ss[0], nil
+       }
+       return d.sLst[shardID].segmentController.createTSTable(timeRange.Start)
+}
+
+func (d *database[T]) SelectTSTables(shardID common.ShardID, timeRange 
timestamp.TimeRange) ([]TSTable[T], error) {
+       if int(shardID) >= int(atomic.LoadUint32(&d.sLen)) {
+               return nil, ErrUnknownShard
+       }
+       return d.sLst[shardID].segmentController.selectTSTables(timeRange), nil
+}
+
+func (d *database[T]) registerShard(id int) error {
+       ctx := context.WithValue(context.Background(), logger.ContextKey, 
d.logger)
+       so, err := d.openShard(ctx, common.ShardID(id))
+       if err != nil {
+               return err
+       }
+       d.sLst = append(d.sLst, so)
+       d.sLen++
+       return nil
+}
+
+func (d *database[T]) loadDatabase() error {
+       d.Lock()
+       defer d.Unlock()
+       return walkDir(d.location, shardPathPrefix, func(suffix string) error {
+               shardID, err := strconv.Atoi(suffix)
+               if err != nil {
+                       return err
+               }
+               if shardID >= int(d.opts.ShardNum) {
+                       return nil
+               }
+               d.logger.Info().Int("shard_id", shardID).Msg("opening a existed 
shard")
+               return d.registerShard(shardID)
+       })
+}
+
+type walkFn func(suffix string) error
+
+func walkDir(root, prefix string, wf walkFn) error {
+       for _, f := range lfs.ReadDir(root) {
+               if !f.IsDir() || !strings.HasPrefix(f.Name(), prefix) {
+                       continue
+               }
+               segs := strings.Split(f.Name(), "-")
+               errWalk := wf(segs[len(segs)-1])
+               if errWalk != nil {
+                       return errors.WithMessagef(errWalk, "failed to load: 
%s", f.Name())
+               }
        }
-       return db, err
+       return nil
 }
diff --git a/pkg/fs/file_system.go b/pkg/fs/file_system.go
index 8d028fbf..ee1853ef 100644
--- a/pkg/fs/file_system.go
+++ b/pkg/fs/file_system.go
@@ -54,8 +54,14 @@ type File interface {
 
 // FileSystem operation interface.
 type FileSystem interface {
-       // Mkdir creates a new directory with the specified name and permission.
-       Mkdir(path string, permission Mode)
+       // MkdirIfNotExist creates a new directory with the specified name and 
permission if it does not exist.
+       // If the directory exists, it will do nothing.
+       MkdirIfNotExist(path string, permission Mode)
+       // MkdirPanicIfExist creates a new directory with the specified name 
and permission if it does not exist.
+       // If the directory exists, it will panic.
+       MkdirPanicIfExist(path string, permission Mode)
+       // ReadDir reads the directory named by dirname and returns a list of 
directory entries sorted by filename.
+       ReadDir(dirname string) []DirEntry
        // Create and open the file by specified name and mode.
        CreateFile(name string, permission Mode) (File, error)
        // Flush mode, which flushes all data to one file.
@@ -63,3 +69,12 @@ type FileSystem interface {
        // Delete the file.
        DeleteFile(name string) error
 }
+
+// DirEntry is the interface that wraps the basic information about a file or 
directory.
+type DirEntry interface {
+       // Name returns the name of the file or directory.
+       Name() string
+
+       // IsDir reports whether the entry describes a directory.
+       IsDir() bool
+}
diff --git a/pkg/fs/local_file_system.go b/pkg/fs/local_file_system.go
index fdba4e88..0144f8a8 100644
--- a/pkg/fs/local_file_system.go
+++ b/pkg/fs/local_file_system.go
@@ -74,12 +74,22 @@ func readErrorHandle(operation string, err error, name 
string, size int) (int, e
        }
 }
 
-// Mkdir implements FileSystem.
-func (fs *localFileSystem) Mkdir(path string, permission Mode) {
+func (fs *localFileSystem) MkdirIfNotExist(path string, permission Mode) {
        if fs.pathExist(path) {
                return
        }
-       if err := os.MkdirAll(path, 0o755); err != nil {
+       fs.mkdir(path, permission)
+}
+
+func (fs *localFileSystem) MkdirPanicIfExist(path string, permission Mode) {
+       if fs.pathExist(path) {
+               fs.logger.Panic().Str("path", path).Msg("directory is exist")
+       }
+       fs.mkdir(path, permission)
+}
+
+func (fs *localFileSystem) mkdir(path string, permission Mode) {
+       if err := os.MkdirAll(path, os.FileMode(permission)); err != nil {
                fs.logger.Panic().Str("path", path).Err(err).Msg("failed to 
create directory")
        }
        parentDirPath := filepath.Dir(path)
@@ -110,6 +120,18 @@ func (fs *localFileSystem) syncPath(path string) {
        }
 }
 
+func (fs *localFileSystem) ReadDir(dirname string) []DirEntry {
+       des, err := os.ReadDir(dirname)
+       if err != nil {
+               fs.logger.Panic().Str("dirname", dirname).Err(err).Msg("failed 
to read directory")
+       }
+       result := make([]DirEntry, len(des))
+       for i, de := range des {
+               result[i] = DirEntry(de)
+       }
+       return result
+}
+
 // CreateFile is used to create and open the file by specified name and mode.
 func (fs *localFileSystem) CreateFile(name string, permission Mode) (File, 
error) {
        file, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 
os.FileMode(permission))

Reply via email to