hanahmily commented on code in PR #957:
URL:
https://github.com/apache/skywalking-banyandb/pull/957#discussion_r2757048824
##########
banyand/measure/topn_post_processor.go:
##########
@@ -250,7 +250,7 @@ func (taggr *topNPostProcessor) Flush()
([]*topNAggregatorItem, error) {
continue
}
- aggrFunc, err :=
aggregation.NewFunc[int64](taggr.aggrFunc)
+ aggrFunc, err :=
aggregation.NewFunc[int64](taggr.aggrFunc, false)
Review Comment:
```suggestion
aggrFunc, err :=
aggregation.NewFunc[int64](taggr.aggrFunc)
```
##########
pkg/query/aggregation/aggregation.go:
##########
@@ -44,11 +44,16 @@ type Number interface {
}
// NewFunc returns a aggregation function based on function type.
-func NewFunc[N Number](af modelv1.AggregationFunction) (Func[N], error) {
+// If forDistributedMean is true and af is MEAN, it returns a
distributedMeanFunc that aggregates sum and count.
+func NewFunc[N Number](af modelv1.AggregationFunction, forDistributedMean
bool) (Func[N], error) {
Review Comment:
```suggestion
func NewFunc[N Number](af modelv1.AggregationFunction) (Func[N], error) {
```
##########
pkg/query/logical/measure/measure_plan_aggregation.go:
##########
@@ -177,7 +184,11 @@ func (ami *aggGroupIterator[N]) Current()
[]*measurev1.InternalDataPoint {
ami.err = err
return nil
}
- ami.aggrFunc.In(v)
+ if aggregation.IsDistributedMean(ami.aggrFunc) {
+ ami.aggrFunc.In(v, 1)
+ } else {
+ ami.aggrFunc.In(v)
+ }
Review Comment:
```suggestion
ami.aggrFunc.In(v)
```
##########
pkg/query/logical/measure/measure_plan_aggregation.go:
##########
@@ -189,17 +200,43 @@ func (ami *aggGroupIterator[N]) Current()
[]*measurev1.InternalDataPoint {
if resultDp == nil {
return nil
}
- val, err := aggregation.ToFieldValue(ami.aggrFunc.Val())
- if err != nil {
- ami.err = err
- return nil
- }
- resultDp.Fields = []*measurev1.DataPoint_Field{
- {
- Name: ami.aggregationFieldRef.Field.Name,
- Value: val,
- },
+ var fields []*measurev1.DataPoint_Field
+ sumVal, countVal, isDistributedMean :=
aggregation.GetSumCount(ami.aggrFunc)
+ if isDistributedMean {
+ sumFieldVal, sumErr := aggregation.ToFieldValue(sumVal)
+ if sumErr != nil {
+ ami.err = sumErr
+ return nil
+ }
+ countFieldVal, countErr := aggregation.ToFieldValue(countVal)
+ if countErr != nil {
+ ami.err = countErr
+ return nil
+ }
+ fields = []*measurev1.DataPoint_Field{
+ {
+ Name: ami.aggregationFieldRef.Field.Name +
"_sum",
Review Comment:
Extract this field-naming logic into a helper function so that the distributed query can reuse it.
##########
pkg/query/aggregation/function.go:
##########
@@ -115,3 +125,28 @@ func (m minFunc[N]) Val() N {
func (m *minFunc[N]) Reset() {
m.val = m.max
}
+
+// distributedMeanFunc is used for distributed mean aggregation on data nodes.
+type distributedMeanFunc[N Number] struct {
+ sum N
+ count N
+ zero N
+}
+
+func (m *distributedMeanFunc[N]) In(vals ...N) {
+ if len(vals) != 2 {
+ panic("expected 2 values for distributed mean: (sum, count)")
+ }
+ m.sum += vals[0]
+ m.count += vals[1]
+}
+
+func (m *distributedMeanFunc[N]) Val() N {
+ // For distributed mean, this value is not used
+ return m.zero
Review Comment:
```suggestion
if m.count == m.zero {
return m.zero
}
v := m.sum / m.count
if v < 1 {
return 1
}
return v
```
##########
pkg/query/logical/measure/measure_plan_distributed.go:
##########
@@ -568,10 +589,222 @@ func deduplicateAggregatedDataPointsWithShard(dataPoints
[]*measurev1.InternalDa
return result, nil
}
+// mergeNonGroupByAggregation merges aggregation results from multiple shards
when there's no groupBy.
+// Uses aggregation.Func to properly merge results according to the
aggregation function type.
+func mergeNonGroupByAggregation(
Review Comment:
Delete the function. The aggregation executor should aggregate the sum and
count from data nodes.
##########
pkg/query/logical/measure/measure_plan_aggregation.go:
##########
@@ -91,7 +99,7 @@ type aggregationPlan[N aggregation.Number] struct {
func newAggregationPlan[N aggregation.Number](gba *unresolvedAggregation,
prevPlan logical.Plan,
measureSchema logical.Schema, fieldRef *logical.FieldRef,
) (*aggregationPlan[N], error) {
- aggrFunc, err := aggregation.NewFunc[N](gba.aggrFunc)
+ aggrFunc, err := aggregation.NewFunc[N](gba.aggrFunc,
gba.distributedMean)
Review Comment:
```suggestion
aggrFunc, err := aggregation.NewFunc[N](gba.aggrFunc)
```
##########
pkg/query/aggregation/aggregation.go:
##########
@@ -44,11 +44,16 @@ type Number interface {
}
// NewFunc returns a aggregation function based on function type.
-func NewFunc[N Number](af modelv1.AggregationFunction) (Func[N], error) {
+// If forDistributedMean is true and af is MEAN, it returns a
distributedMeanFunc that aggregates sum and count.
+func NewFunc[N Number](af modelv1.AggregationFunction, forDistributedMean
bool) (Func[N], error) {
var result Func[N]
switch af {
case modelv1.AggregationFunction_AGGREGATION_FUNCTION_MEAN:
- result = &meanFunc[N]{zero: zero[N]()}
+ if forDistributedMean {
+ result = &distributedMeanFunc[N]{zero: zero[N]()}
+ } else {
+ result = &meanFunc[N]{zero: zero[N]()}
+ }
Review Comment:
```suggestion
result = &meanFunc[N]{zero: zero[N]()}
case modelv1.AggregationFunction_AGGREGATION_FUNCTION_DISTRIBUTED_MEAN:
result = &distributedMeanFunc[N]{zero: zero[N]()}
```
##########
pkg/query/logical/measure/measure_plan_distributed.go:
##########
@@ -568,10 +589,222 @@ func deduplicateAggregatedDataPointsWithShard(dataPoints
[]*measurev1.InternalDa
return result, nil
}
+// mergeNonGroupByAggregation merges aggregation results from multiple shards
when there's no groupBy.
+// Uses aggregation.Func to properly merge results according to the
aggregation function type.
+func mergeNonGroupByAggregation(
+ dataPoints []*measurev1.InternalDataPoint,
+ agg *measurev1.QueryRequest_Aggregation,
+) ([]*measurev1.InternalDataPoint, error) {
+ if len(dataPoints) == 0 {
+ return nil, nil
+ }
+ if len(dataPoints) == 1 {
+ return dataPoints, nil
+ }
+ // Deduplicate by shard_id first (keep the one with highest version)
+ shardMap := make(map[uint32]*measurev1.InternalDataPoint)
+ for _, idp := range dataPoints {
+ existing, exists := shardMap[idp.ShardId]
+ if !exists || idp.GetDataPoint().Version >
existing.GetDataPoint().Version {
+ shardMap[idp.ShardId] = idp
+ }
+ }
+ // Now merge results from different shards
+ deduplicatedDps := make([]*measurev1.InternalDataPoint, 0,
len(shardMap))
+ for _, idp := range shardMap {
+ deduplicatedDps = append(deduplicatedDps, idp)
+ }
+ if len(deduplicatedDps) == 1 {
+ return deduplicatedDps, nil
+ }
+ // Determine field type from the first data point
+ fieldName := agg.FieldName
+ firstFieldVal := getFieldValue(deduplicatedDps[0].GetDataPoint(),
fieldName)
+ if firstFieldVal == nil {
+ return deduplicatedDps[:1], nil
+ }
+ // Create aggregation function based on field type
+ var isInt bool
+ switch firstFieldVal.Value.(type) {
+ case *modelv1.FieldValue_Int:
+ isInt = true
+ case *modelv1.FieldValue_Float:
+ isInt = false
+ default:
+ return deduplicatedDps[:1], nil
+ }
+ // Merge using aggregation.Func
+ if isInt {
+ return
mergeNonGroupByAggregationWithFunc[int64](deduplicatedDps, agg, fieldName)
+ }
+ return mergeNonGroupByAggregationWithFunc[float64](deduplicatedDps,
agg, fieldName)
+}
+
+// mergeNonGroupByAggregationWithFunc merges aggregation results using
aggregation.Func.
+func mergeNonGroupByAggregationWithFunc[N aggregation.Number](
+ dataPoints []*measurev1.InternalDataPoint,
+ agg *measurev1.QueryRequest_Aggregation,
+ fieldName string,
+) ([]*measurev1.InternalDataPoint, error) {
+ // Create aggregation function
+ aggrFunc, aggrErr := aggregation.NewFunc[N](agg.Function, false)
+ if aggrErr != nil {
+ return nil, fmt.Errorf("failed to create aggregation function:
%w", aggrErr)
+ }
+ // Feed aggregated values from each shard into the aggregation function
+ for _, idp := range dataPoints {
+ dp := idp.GetDataPoint()
+ fieldVal := getFieldValue(dp, fieldName)
+ if fieldVal == nil {
+ continue
+ }
+ val, fromErr := aggregation.FromFieldValue[N](fieldVal)
+ if fromErr != nil {
+ return nil, fmt.Errorf("failed to convert field value:
%w", fromErr)
+ }
+ aggrFunc.In(val)
+ }
+ resultVal := aggrFunc.Val()
+ resultFieldVal, toErr := aggregation.ToFieldValue(resultVal)
+ if toErr != nil {
+ return nil, fmt.Errorf("failed to convert result value: %w",
toErr)
+ }
+ // Create a new result data point (don't modify the original)
+ firstDp := dataPoints[0].GetDataPoint()
+ resultDp := &measurev1.DataPoint{
+ TagFamilies: firstDp.TagFamilies,
+ Fields: []*measurev1.DataPoint_Field{
+ {
+ Name: fieldName,
+ Value: resultFieldVal,
+ },
+ },
+ }
+ result := &measurev1.InternalDataPoint{
+ DataPoint: resultDp,
+ ShardId: dataPoints[0].ShardId,
+ }
+ return []*measurev1.InternalDataPoint{result}, nil
+}
+
+// getFieldValue extracts the field value from a data point.
+func getFieldValue(dp *measurev1.DataPoint, fieldName string)
*modelv1.FieldValue {
+ for _, field := range dp.Fields {
+ if field.Name == fieldName {
+ return field.Value
+ }
+ }
+ return nil
+}
+
// hashWithShard combines shard_id and group_key into a single hash.
func hashWithShard(shardID, groupKey uint64) uint64 {
h := uint64(offset64)
h = (h ^ shardID) * prime64
h = (h ^ groupKey) * prime64
return h
}
+
+type meanGroup struct {
Review Comment:
Delete the structure as well.
##########
pkg/query/logical/measure/measure_analyzer.go:
##########
@@ -119,10 +125,17 @@ func Analyze(criteria *measurev1.QueryRequest, metadata
[]*commonv1.Metadata, ss
}
if criteria.GetAgg() != nil {
+ // Check if this is a distributed mean aggregation that needs
to return sum and count
+ // This happens when the query is pushed down from liaison node
to data node
+ distributedMean := false
+ if isDistributed && criteria.GetAgg().GetFunction() ==
modelv1.AggregationFunction_AGGREGATION_FUNCTION_MEAN {
+ distributedMean = true
+ }
plan = newUnresolvedAggregation(plan,
logical.NewField(criteria.GetAgg().GetFieldName()),
criteria.GetAgg().GetFunction(),
criteria.GetGroupBy() != nil,
+ distributedMean,
)
Review Comment:
```suggestion
aggFunc := criteria.GetAgg().GetFunction()
if distributedMean {
fields = append(fields, logical.NewField("sum"))
fields = append(fields, logical.NewField("count"))
} else {
fields = append(fields,
logical.NewField(criteria.GetAgg().GetFieldName()))
}
if isDistributed && criteria.GetAgg().GetFunction() ==
modelv1.AggregationFunction_AGGREGATION_FUNCTION_MEAN {
aggFunc =
modelv1.AggregationFunction_AGGREGATION_FUNCTION_DISTRIBUTED_MEAN
}
plan = newUnresolvedAggregation(plan,
fields,
aggFunc,
criteria.GetGroupBy() != nil,
)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]