Copilot commented on code in PR #3312:
URL: https://github.com/apache/kvrocks/pull/3312#discussion_r2882106812
##########
src/types/tdigest.h:
##########
@@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const
std::vector<double>& inputs, std::vecto
}
return Status::OK();
}
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile,
double high_cut_quantile) {
+ if (td.Size() == 0) {
+ return Status{Status::InvalidArgument, "empty tdigest"};
+ }
+
+ if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
+ return Status{Status::InvalidArgument, "low cut quantile must be between 0
and 1"};
+ }
+ if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
+ return Status{Status::InvalidArgument, "high cut quantile must be between
0 and 1"};
+ }
+ if (low_cut_quantile >= high_cut_quantile) {
+ return Status{Status::InvalidArgument, "low cut quantile must be less than
high cut quantile"};
+ }
+
+ double low_boundary = std::numeric_limits<double>::quiet_NaN();
+ double high_boundary = std::numeric_limits<double>::quiet_NaN();
+
+ if (low_cut_quantile == 0.0) {
+ low_boundary = td.Min();
+ } else {
+ auto low_result = TDigestQuantile(td, low_cut_quantile);
+ if (!low_result) {
+ return low_result;
+ }
+ low_boundary = *low_result;
+ }
+
+ if (high_cut_quantile == 1.0) {
+ high_boundary = td.Max();
+ } else {
+ auto high_result = TDigestQuantile(td, high_cut_quantile);
+ if (!high_result) {
+ return high_result;
+ }
+ high_boundary = *high_result;
+ }
+
+ auto iter = td.Begin();
+ double total_weight_in_range = 0;
+ double weighted_sum = 0;
+
+ while (iter->Valid()) {
+ auto centroid = GET_OR_RET(iter->GetCentroid());
+
+ if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) ||
+ (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) {
+ total_weight_in_range += centroid.weight;
+ weighted_sum += centroid.mean * centroid.weight;
+ }
+
+ iter->Next();
+ }
+
+ if (total_weight_in_range == 0) {
Review Comment:
TDigestTrimmedMean can incorrectly return NaN when the low/high cut
boundaries fall between centroid means (e.g., after quantile interpolation).
The current logic includes only whole centroids whose mean lies within
[low_boundary, high_boundary], so it may exclude every centroid even though
the quantile range contains weight. Consider computing the overlap by weight
position (e.g., iterate the centroids while tracking cumulative weight and
add the partial centroid weight for the boundary centroids) rather than
filtering by centroid.mean.
```suggestion
// First, compute the total weight of the t-digest.
double total_weight = 0.0;
{
auto iter = td.Begin();
while (iter->Valid()) {
auto centroid = GET_OR_RET(iter->GetCentroid());
total_weight += centroid.weight;
iter->Next();
}
}
if (total_weight == 0.0) {
return std::numeric_limits<double>::quiet_NaN();
}
// If no trimming is requested, just return the global weighted mean.
if (low_cut_quantile == 0.0 && high_cut_quantile == 1.0) {
double weighted_sum = 0.0;
auto iter = td.Begin();
while (iter->Valid()) {
auto centroid = GET_OR_RET(iter->GetCentroid());
weighted_sum += centroid.mean * centroid.weight;
iter->Next();
}
return weighted_sum / total_weight;
}
// Compute rank boundaries in weight space.
const double low_rank = low_cut_quantile * total_weight;
const double high_rank = high_cut_quantile * total_weight;
double cumulative_weight = 0.0;
double total_weight_in_range = 0.0;
double weighted_sum = 0.0;
auto iter = td.Begin();
while (iter->Valid()) {
auto centroid = GET_OR_RET(iter->GetCentroid());
const double start_rank = cumulative_weight;
const double end_rank = cumulative_weight + centroid.weight;
// If this centroid is entirely before the trimmed region, skip it.
if (end_rank <= low_rank) {
cumulative_weight = end_rank;
iter->Next();
continue;
}
// If we've passed the trimmed region, we can stop.
if (start_rank >= high_rank) {
break;
}
// Compute overlap of this centroid's weight with [low_rank, high_rank).
double overlap_start = start_rank;
if (overlap_start < low_rank) {
overlap_start = low_rank;
}
double overlap_end = end_rank;
if (overlap_end > high_rank) {
overlap_end = high_rank;
}
const double overlap = overlap_end - overlap_start;
if (overlap > 0.0) {
total_weight_in_range += overlap;
weighted_sum += centroid.mean * overlap;
}
cumulative_weight = end_rank;
iter->Next();
}
if (total_weight_in_range == 0.0) {
```
##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs
util.KvrocksServerConfigs) {
require.EqualValues(t, expected[i], rank, "REVRANK
mismatch at index %d", i)
}
})
+
+ t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+ emptyKey := "tdigest_empty"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey,
"compression", "100").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1",
"0.9")
+ require.NoError(t, result.Err())
+ require.Equal(t, "nan", result.Val())
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+ key := "tdigest_basic"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 5.5, mean, 1.0)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+ key := "tdigest_no_trim"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 5.5, mean, 0.1)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+ key := "tdigest_skewed"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1",
"1", "1", "1", "10", "100").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.Less(t, mean, 50.0)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t
*testing.T) {
+ require.ErrorContains(t, rdb.Do(ctx,
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"key").Err(), errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"key", "0.1").Err(), errMsgWrongNumberArg)
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t
*testing.T) {
+ key := "tdigest_invalid"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5").Err())
+
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut
quantile")
+ require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN",
key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut
quantile")
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
+ key := "tdigest_single"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+ require.NoError(t, result.Err())
+ mean, err := strconv.ParseFloat(result.Val().(string), 64)
+ require.NoError(t, err)
+ require.InDelta(t, 42.0, mean, 0.001)
+ })
+
+ t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) {
+ key := "tdigest_extreme"
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key,
"compression", "100").Err())
+ require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2",
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+ result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6")
+ require.NoError(t, result.Err())
+ meanStr := result.Val().(string)
+ if meanStr == "nan" {
+ return
+ }
+ mean, err := strconv.ParseFloat(meanStr, 64)
+ require.NoError(t, err)
+ require.Greater(t, mean, 0.0)
Review Comment:
This test allows "nan" and returns early, which can mask real correctness
issues: a non-empty digest with low_cut < high_cut should always have some
weight in the trimmed range. It would be better to assert that the result is
not NaN for this dataset and to verify that it falls within the expected
numeric range.
```suggestion
mean, err := strconv.ParseFloat(meanStr, 64)
require.NoError(t, err)
require.False(t, math.IsNaN(mean))
require.Greater(t, mean, 4.0)
require.Less(t, mean, 7.0)
```
##########
src/types/redis_tdigest.cc:
##########
@@ -759,6 +759,41 @@ rocksdb::Status
TDigest::applyNewCentroids(ObserverOrUniquePtr<rocksdb::WriteBat
return rocksdb::Status::OK();
}
+rocksdb::Status TDigest::TrimmedMean(engine::Context& ctx, const Slice&
digest_name, double low_cut_quantile,
+ double high_cut_quantile,
TDigestTrimmedMeanResult* result) {
Review Comment:
TDigest::TrimmedMean can leave TDigestTrimmedMeanResult populated with a
stale value if the caller reuses the result object: on success, result->mean
is assigned only when there are observations, and the empty-digest early
return does not reset it. Reset/clear result->mean at function entry (and
before the early return) so the output is always well-defined.
```suggestion
double high_cut_quantile,
TDigestTrimmedMeanResult* result) {
result->mean = 0;
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]