This is an automated email from the ASF dual-hosted git repository.

hanahmily pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new dd101064 Apply traceIDFilter in the query process to perform filtering (#813)
dd101064 is described below

commit dd101064b9326522300e8d679e7a4e841292c608
Author: Huang Youliang <[email protected]>
AuthorDate: Tue Oct 21 21:32:04 2025 +0800

    Apply traceIDFilter in the query process to perform filtering (#813)
    
    * Apply traceIDFilter in the query process to perform filtering
    
    ---------
    
    Co-authored-by: Gao Hongtao <[email protected]>
    Co-authored-by: 吴晟 Wu Sheng <[email protected]>
---
 banyand/trace/block_writer.go  |   3 +-
 banyand/trace/query.go         |   5 ++-
 banyand/trace/query_test.go    |   2 +-
 banyand/trace/snapshot.go      |  19 +++++++-
 banyand/trace/snapshot_test.go | 100 ++++++++++++++++++++++++++++++++++++++++-
 banyand/trace/tstable_test.go  |   2 +-
 6 files changed, 122 insertions(+), 9 deletions(-)

diff --git a/banyand/trace/block_writer.go b/banyand/trace/block_writer.go
index 8bfb06f8..cb2ebdcc 100644
--- a/banyand/trace/block_writer.go
+++ b/banyand/trace/block_writer.go
@@ -22,6 +22,7 @@ import (
 
        "github.com/apache/skywalking-banyandb/banyand/internal/storage"
        "github.com/apache/skywalking-banyandb/pkg/compress/zstd"
+       "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/filter"
        "github.com/apache/skywalking-banyandb/pkg/fs"
        "github.com/apache/skywalking-banyandb/pkg/logger"
@@ -367,7 +368,7 @@ func (bw *blockWriter) Flush(pm *partMetadata, tf 
*traceIDFilter, tt *tagType) {
                tf.filter.SetN(len(bw.traceIDs))
                tf.filter.ResizeBits((len(bw.traceIDs)*filter.B + 63) / 64)
                for _, traceID := range bw.traceIDs {
-                       tf.filter.Add([]byte(traceID))
+                       tf.filter.Add(convert.StringToBytes(traceID))
                }
        }
        tt.copyFrom(bw.tagType)
diff --git a/banyand/trace/query.go b/banyand/trace/query.go
index 6da868fa..265e7840 100644
--- a/banyand/trace/query.go
+++ b/banyand/trace/query.go
@@ -97,7 +97,7 @@ func (t *trace) Query(ctx context.Context, tqo 
model.TraceQueryOptions) (model.T
                return nilResult, nil
        }
 
-       parts := t.attachSnapshots(&result, tables, qo.minTimestamp, 
qo.maxTimestamp)
+       parts := t.attachSnapshots(&result, tables, qo.minTimestamp, 
qo.maxTimestamp, qo.traceIDs)
 
        pipelineCtx, cancel := context.WithCancel(ctx)
        result.ctx = pipelineCtx
@@ -240,6 +240,7 @@ func (t *trace) attachSnapshots(
        tables []*tsTable,
        minTimestamp int64,
        maxTimestamp int64,
+       traceIDs []string,
 ) []*part {
        parts := make([]*part, 0)
        for i := range tables {
@@ -249,7 +250,7 @@ func (t *trace) attachSnapshots(
                }
 
                var count int
-               parts, count = s.getParts(parts, minTimestamp, maxTimestamp)
+               parts, count = s.getParts(parts, minTimestamp, maxTimestamp, 
traceIDs)
                if count < 1 {
                        s.decRef()
                        continue
diff --git a/banyand/trace/query_test.go b/banyand/trace/query_test.go
index b5bace8c..1194f177 100644
--- a/banyand/trace/query_test.go
+++ b/banyand/trace/query_test.go
@@ -98,7 +98,7 @@ func TestQueryResult(t *testing.T) {
                                s := tst.currentSnapshot()
                                require.NotNil(t, s)
                                defer s.decRef()
-                               pp, _ := s.getParts(nil, 
queryOpts.minTimestamp, queryOpts.maxTimestamp)
+                               pp, _ := s.getParts(nil, 
queryOpts.minTimestamp, queryOpts.maxTimestamp, []string{tt.traceID})
                                bma := generateBlockMetadataArray()
                                defer releaseBlockMetadataArray(bma)
                                ti := &tstIter{}
diff --git a/banyand/trace/snapshot.go b/banyand/trace/snapshot.go
index ca741782..915ab951 100644
--- a/banyand/trace/snapshot.go
+++ b/banyand/trace/snapshot.go
@@ -27,6 +27,7 @@ import (
        "github.com/pkg/errors"
 
        "github.com/apache/skywalking-banyandb/banyand/internal/storage"
+       "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/logger"
 )
 
@@ -59,14 +60,28 @@ type snapshot struct {
        ref int32
 }
 
-func (s *snapshot) getParts(dst []*part, minTimestamp int64, maxTimestamp 
int64) ([]*part, int) {
+func (s *snapshot) getParts(dst []*part, minTimestamp int64, maxTimestamp 
int64, traceIDs []string) ([]*part, int) {
+       shouldSkip := func(p *part) bool {
+               if p.traceIDFilter.filter == nil || len(traceIDs) == 0 {
+                       return false
+               }
+               for _, traceID := range traceIDs {
+                       if 
p.traceIDFilter.filter.MightContain(convert.StringToBytes(traceID)) {
+                               return false
+                       }
+               }
+               return true
+       }
+
        var count int
        for _, p := range s.parts {
                pm := p.p.partMetadata
                if maxTimestamp < pm.MinTimestamp || minTimestamp > 
pm.MaxTimestamp {
                        continue
                }
-               // TODO: filter parts
+               if shouldSkip(p.p) {
+                       continue
+               }
                dst = append(dst, p.p)
                count++
        }
diff --git a/banyand/trace/snapshot_test.go b/banyand/trace/snapshot_test.go
index d033fdd6..65082ac2 100644
--- a/banyand/trace/snapshot_test.go
+++ b/banyand/trace/snapshot_test.go
@@ -29,6 +29,8 @@ import (
        "github.com/apache/skywalking-banyandb/api/common"
        "github.com/apache/skywalking-banyandb/banyand/protector"
        "github.com/apache/skywalking-banyandb/pkg/bytes"
+       "github.com/apache/skywalking-banyandb/pkg/convert"
+       "github.com/apache/skywalking-banyandb/pkg/filter"
        "github.com/apache/skywalking-banyandb/pkg/fs"
        "github.com/apache/skywalking-banyandb/pkg/logger"
        "github.com/apache/skywalking-banyandb/pkg/test"
@@ -119,13 +121,107 @@ func TestSnapshotGetParts(t *testing.T) {
                        },
                        count: 2,
                },
+               {
+                       name: "Test with non-empty snapshot and matching 
traceID",
+                       snapshot: func() *snapshot {
+                               bf1 := filter.NewBloomFilter(0)
+                               bf1.SetN(2)
+                               bf1.ResizeBits((2*filter.B + 63) / 64)
+                               bf1.Add(convert.StringToBytes("trace1"))
+                               bf1.Add(convert.StringToBytes("trace2"))
+
+                               bf2 := filter.NewBloomFilter(0)
+                               bf2.SetN(1)
+                               bf2.ResizeBits((1*filter.B + 63) / 64)
+                               bf2.Add(convert.StringToBytes("trace3"))
+
+                               return &snapshot{
+                                       parts: []*partWrapper{
+                                               {
+                                                       p: &part{
+                                                               partMetadata: 
partMetadata{
+                                                                       
MinTimestamp: 0,
+                                                                       
MaxTimestamp: 5,
+                                                               },
+                                                               traceIDFilter: 
traceIDFilter{
+                                                                       filter: 
bf1,
+                                                               },
+                                                       },
+                                               },
+                                               {
+                                                       p: &part{
+                                                               partMetadata: 
partMetadata{
+                                                                       
MinTimestamp: 6,
+                                                                       
MaxTimestamp: 10,
+                                                               },
+                                                               traceIDFilter: 
traceIDFilter{
+                                                                       filter: 
bf2,
+                                                               },
+                                                       },
+                                               },
+                                       },
+                               }
+                       }(),
+                       dst: []*part{},
+                       opts: queryOptions{
+                               minTimestamp: 0,
+                               maxTimestamp: 10,
+                               traceIDs:     []string{"trace1"},
+                       },
+                       expected: []*part{
+                               {
+                                       partMetadata: partMetadata{
+                                               MinTimestamp: 0,
+                                               MaxTimestamp: 5,
+                                       },
+                               },
+                       },
+                       count: 1,
+               },
+               {
+                       name: "Test with non-empty snapshot and non-matching 
traceID",
+                       snapshot: func() *snapshot {
+                               bf := filter.NewBloomFilter(0)
+                               bf.SetN(2)
+                               bf.ResizeBits((2*filter.B + 63) / 64)
+                               bf.Add(convert.StringToBytes("trace1"))
+                               bf.Add(convert.StringToBytes("trace2"))
+
+                               return &snapshot{
+                                       parts: []*partWrapper{
+                                               {
+                                                       p: &part{
+                                                               partMetadata: 
partMetadata{
+                                                                       
MinTimestamp: 0,
+                                                                       
MaxTimestamp: 5,
+                                                               },
+                                                               traceIDFilter: 
traceIDFilter{
+                                                                       filter: 
bf,
+                                                               },
+                                                       },
+                                               },
+                                       },
+                               }
+                       }(),
+                       dst: []*part{},
+                       opts: queryOptions{
+                               minTimestamp: 0,
+                               maxTimestamp: 10,
+                               traceIDs:     []string{"trace0"},
+                       },
+                       expected: []*part{},
+                       count:    0,
+               },
        }
 
        for _, tt := range tests {
                t.Run(tt.name, func(t *testing.T) {
-                       result, count := tt.snapshot.getParts(tt.dst, 
tt.opts.minTimestamp, tt.opts.maxTimestamp)
-                       assert.Equal(t, tt.expected, result)
+                       result, count := tt.snapshot.getParts(tt.dst, 
tt.opts.minTimestamp, tt.opts.maxTimestamp, tt.opts.traceIDs)
                        assert.Equal(t, tt.count, count)
+                       require.Equal(t, len(tt.expected), len(result))
+                       for i := range tt.expected {
+                               assert.Equal(t, tt.expected[i].partMetadata, 
result[i].partMetadata)
+                       }
                })
        }
 }
diff --git a/banyand/trace/tstable_test.go b/banyand/trace/tstable_test.go
index c9ac7b5b..c64e4486 100644
--- a/banyand/trace/tstable_test.go
+++ b/banyand/trace/tstable_test.go
@@ -137,7 +137,7 @@ func Test_tstIter(t *testing.T) {
                        s = new(snapshot)
                }
                defer s.decRef()
-               pp, n := s.getParts(nil, tt.minTimestamp, tt.maxTimestamp)
+               pp, n := s.getParts(nil, tt.minTimestamp, tt.maxTimestamp, 
[]string{tt.tid})
                require.Equal(t, len(s.parts), n)
                ti := &tstIter{}
                ti.init(bma, pp, []string{tt.tid})

Reply via email to