This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch sidx/interface
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git

commit 2bda889a85a90738e1c40fb96a117e887df21f5d
Author: Gao Hongtao <hanahm...@gmail.com>
AuthorDate: Tue Aug 19 11:18:43 2025 +0700

    Refactor cache implementation in segment and shard management
    
    - Updated `groupCache` to use the more generic `Cache` interface instead
      of the specific `serviceCache`, enhancing flexibility and abstraction.
    - Modified the `newSegmentController` and `NewShardCache` functions to
      accommodate the new cache structure.
    - Improved `groupCache` methods for better handling of cache operations,
      including `Get`, `Put`, and metrics retrieval methods.
---
 banyand/internal/encoding/tag_encoder_test.go | 188 ++++++++++++++++++++++++++
 banyand/internal/storage/segment.go           |   4 +-
 banyand/internal/storage/shard.go             |   6 +-
 banyand/internal/storage/tsdb.go              |  58 +++++++-
 4 files changed, 246 insertions(+), 10 deletions(-)

diff --git a/banyand/internal/encoding/tag_encoder_test.go b/banyand/internal/encoding/tag_encoder_test.go
new file mode 100644
index 00000000..a0b1a54d
--- /dev/null
+++ b/banyand/internal/encoding/tag_encoder_test.go
@@ -0,0 +1,188 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package encoding
+
+import (
+       "testing"
+
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+
+       "github.com/apache/skywalking-banyandb/pkg/convert"
+       pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
+)
+
+func TestEncodeDecodeTagValues_Int64_WithNilValues(t *testing.T) {
+       tests := []struct {
+               name   string
+               values [][]byte
+       }{
+               {
+                       name:   "single nil value",
+                       values: [][]byte{nil},
+               },
+               {
+                       name:   "mixed nil and valid int64 values",
+                       values: [][]byte{convert.Int64ToBytes(42), nil, 
convert.Int64ToBytes(100)},
+               },
+               {
+                       name:   "all nil values",
+                       values: [][]byte{nil, nil, nil},
+               },
+               {
+                       name:   "nil at beginning",
+                       values: [][]byte{nil, convert.Int64ToBytes(1), 
convert.Int64ToBytes(2)},
+               },
+               {
+                       name:   "nil at end",
+                       values: [][]byte{convert.Int64ToBytes(1), 
convert.Int64ToBytes(2), nil},
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       encoded, err := EncodeTagValues(tt.values, 
pbv1.ValueTypeInt64)
+                       require.NoError(t, err)
+                       require.NotNil(t, encoded)
+
+                       decoded, err := DecodeTagValues(encoded, 
pbv1.ValueTypeInt64, len(tt.values))
+                       require.NoError(t, err)
+                       require.Len(t, decoded, len(tt.values))
+
+                       for i, original := range tt.values {
+                               if original == nil {
+                                       assert.Nil(t, decoded[i], "nil value 
should be decoded as nil")
+                               } else {
+                                       assert.Equal(t, original, decoded[i], 
"non-nil value should remain unchanged")
+                               }
+                       }
+               })
+       }
+}
+
+func TestEncodeDecodeTagValues_Int64_WithNullStringValues(t *testing.T) {
+       tests := []struct {
+               name   string
+               values [][]byte
+       }{
+               {
+                       name:   "single null string value",
+                       values: [][]byte{[]byte("null")},
+               },
+               {
+                       name:   "mixed null string and valid int64 values",
+                       values: [][]byte{convert.Int64ToBytes(42), 
[]byte("null"), convert.Int64ToBytes(100)},
+               },
+               {
+                       name:   "all null string values",
+                       values: [][]byte{[]byte("null"), []byte("null"), 
[]byte("null")},
+               },
+               {
+                       name:   "null string at beginning",
+                       values: [][]byte{[]byte("null"), 
convert.Int64ToBytes(1), convert.Int64ToBytes(2)},
+               },
+               {
+                       name:   "null string at end",
+                       values: [][]byte{convert.Int64ToBytes(1), 
convert.Int64ToBytes(2), []byte("null")},
+               },
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       encoded, err := EncodeTagValues(tt.values, 
pbv1.ValueTypeInt64)
+                       require.NoError(t, err)
+                       require.NotNil(t, encoded)
+
+                       decoded, err := DecodeTagValues(encoded, 
pbv1.ValueTypeInt64, len(tt.values))
+                       require.NoError(t, err)
+                       require.Len(t, decoded, len(tt.values))
+
+                       for i, original := range tt.values {
+                               if string(original) == "null" {
+                                       assert.Equal(t, []byte("null"), 
decoded[i], "null string value should remain as 'null' string")
+                               } else {
+                                       assert.Equal(t, original, decoded[i], 
"non-null value should remain unchanged")
+                               }
+                       }
+               })
+       }
+}
+
+func TestEncodeDecodeTagValues_Int64_MixedNilAndNullString(t *testing.T) {
+       values := [][]byte{
+               convert.Int64ToBytes(42),
+               nil,
+               []byte("null"),
+               convert.Int64ToBytes(100),
+               nil,
+               []byte("null"),
+       }
+
+       encoded, err := EncodeTagValues(values, pbv1.ValueTypeInt64)
+       require.NoError(t, err)
+       require.NotNil(t, encoded)
+
+       decoded, err := DecodeTagValues(encoded, pbv1.ValueTypeInt64, 
len(values))
+       require.NoError(t, err)
+       require.Len(t, decoded, len(values))
+
+       expected := [][]byte{
+               convert.Int64ToBytes(42),
+               nil,
+               []byte("null"),
+               convert.Int64ToBytes(100),
+               nil,
+               []byte("null"),
+       }
+
+       for i, expectedValue := range expected {
+               assert.Equal(t, expectedValue, decoded[i], "value at index %d 
should match expected", i)
+       }
+}
+
+func TestEncodeDecodeTagValues_Int64_ValidValues(t *testing.T) {
+       values := [][]byte{
+               convert.Int64ToBytes(42),
+               convert.Int64ToBytes(-100),
+               convert.Int64ToBytes(0),
+               convert.Int64ToBytes(9223372036854775807),  // max int64
+               convert.Int64ToBytes(-9223372036854775808), // min int64
+       }
+
+       encoded, err := EncodeTagValues(values, pbv1.ValueTypeInt64)
+       require.NoError(t, err)
+       require.NotNil(t, encoded)
+
+       decoded, err := DecodeTagValues(encoded, pbv1.ValueTypeInt64, 
len(values))
+       require.NoError(t, err)
+       require.Len(t, decoded, len(values))
+
+       for i, original := range values {
+               assert.Equal(t, original, decoded[i], "valid int64 value should 
remain unchanged")
+       }
+}
+
+func TestEncodeDecodeTagValues_Int64_EmptyInput(t *testing.T) {
+       encoded, err := EncodeTagValues(nil, pbv1.ValueTypeInt64)
+       require.NoError(t, err)
+       assert.Nil(t, encoded)
+
+       decoded, err := DecodeTagValues(nil, pbv1.ValueTypeInt64, 0)
+       require.NoError(t, err)
+       assert.Nil(t, decoded)
+}
diff --git a/banyand/internal/storage/segment.go b/banyand/internal/storage/segment.go
index 914eb7ed..26ca1e00 100644
--- a/banyand/internal/storage/segment.go
+++ b/banyand/internal/storage/segment.go
@@ -319,7 +319,7 @@ type segmentController[T TSTable, O any] struct {
 
 func newSegmentController[T TSTable, O any](ctx context.Context, location 
string,
        l *logger.Logger, opts TSDBOpts[T, O], indexMetrics *inverted.Metrics, 
metrics Metrics,
-       idleTimeout time.Duration, lfs banyanfs.FileSystem, serviceCache 
*serviceCache, group string,
+       idleTimeout time.Duration, lfs banyanfs.FileSystem, cache Cache, group 
string,
 ) *segmentController[T, O] {
        clock, _ := timestamp.GetClock(ctx)
        p := common.GetPosition(ctx)
@@ -335,7 +335,7 @@ func newSegmentController[T TSTable, O any](ctx 
context.Context, location string
                db:           p.Database,
                idleTimeout:  idleTimeout,
                lfs:          lfs,
-               groupCache:   &groupCache{serviceCache, group},
+               groupCache:   &groupCache{cache, group},
        }
 }
 
diff --git a/banyand/internal/storage/shard.go b/banyand/internal/storage/shard.go
index 6b97f302..983a1fe0 100644
--- a/banyand/internal/storage/shard.go
+++ b/banyand/internal/storage/shard.go
@@ -37,10 +37,10 @@ type shardCache struct {
 
 // NewShardCache creates a new shard cache.
 func NewShardCache(group string, segmentID segmentID, shardID common.ShardID) 
Cache {
-       serviceCache := NewServiceCache().(*serviceCache)
+       serviceCache := NewServiceCache()
        groupCache := &groupCache{
-               serviceCache: serviceCache,
-               group:        group,
+               cache: serviceCache,
+               group: group,
        }
        segmentCache := &segmentCache{
                groupCache: groupCache,
diff --git a/banyand/internal/storage/tsdb.go b/banyand/internal/storage/tsdb.go
index 474bb255..6334f10f 100644
--- a/banyand/internal/storage/tsdb.go
+++ b/banyand/internal/storage/tsdb.go
@@ -77,17 +77,65 @@ func generateSegID(unit IntervalUnit, suffix int) segmentID 
{
 var _ Cache = (*groupCache)(nil)
 
 type groupCache struct {
-       *serviceCache
+       cache Cache
        group string
 }
 
-func (gc *groupCache) get(key EntryKey) Sizable {
+func (gc *groupCache) Get(key EntryKey) Sizable {
+       if gc.cache == nil {
+               return nil
+       }
+       key.group = gc.group
+       return gc.cache.Get(key)
+}
+
+func (gc *groupCache) Put(key EntryKey, value Sizable) {
+       if gc.cache == nil {
+               return
+       }
        key.group = gc.group
+       gc.cache.Put(key, value)
+}
+
+func (gc *groupCache) Close() {
+       if gc.cache != nil {
+               gc.cache.Close()
+       }
+}
+
+func (gc *groupCache) Requests() uint64 {
+       if gc.cache == nil {
+               return 0
+       }
+       return gc.cache.Requests()
+}
+
+func (gc *groupCache) Misses() uint64 {
+       if gc.cache == nil {
+               return 0
+       }
+       return gc.cache.Misses()
+}
+
+func (gc *groupCache) Entries() uint64 {
+       if gc.cache == nil {
+               return 0
+       }
+       return gc.cache.Entries()
+}
+
+func (gc *groupCache) Size() uint64 {
+       if gc.cache == nil {
+               return 0
+       }
+       return gc.cache.Size()
+}
+
+func (gc *groupCache) get(key EntryKey) Sizable {
        return gc.Get(key)
 }
 
 func (gc *groupCache) put(key EntryKey, value Sizable) {
-       key.group = gc.group
        gc.Put(key, value)
 }
 
@@ -146,9 +194,9 @@ func OpenTSDB[T TSTable, O any](ctx context.Context, opts 
TSDBOpts[T, O], cache
        if opts.StorageMetricsFactory != nil {
                indexMetrics = inverted.NewMetrics(opts.StorageMetricsFactory, 
common.SegLabelNames()...)
        }
-       var sc *serviceCache
+       var sc Cache
        if cache != nil {
-               sc = cache.(*serviceCache)
+               sc = cache
        }
        db := &database[T, O]{
                location:  location,

Reply via email to