Copilot commented on code in PR #3312:
URL: https://github.com/apache/kvrocks/pull/3312#discussion_r2882106812


##########
src/types/tdigest.h:
##########
@@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const 
std::vector<double>& inputs, std::vecto
   }
   return Status::OK();
 }
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, 
double high_cut_quantile) {
+  if (td.Size() == 0) {
+    return Status{Status::InvalidArgument, "empty tdigest"};
+  }
+
+  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "low cut quantile must be between 0 
and 1"};
+  }
+  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "high cut quantile must be between 
0 and 1"};
+  }
+  if (low_cut_quantile >= high_cut_quantile) {
+    return Status{Status::InvalidArgument, "low cut quantile must be less than 
high cut quantile"};
+  }
+
+  double low_boundary = std::numeric_limits<double>::quiet_NaN();
+  double high_boundary = std::numeric_limits<double>::quiet_NaN();
+
+  if (low_cut_quantile == 0.0) {
+    low_boundary = td.Min();
+  } else {
+    auto low_result = TDigestQuantile(td, low_cut_quantile);
+    if (!low_result) {
+      return low_result;
+    }
+    low_boundary = *low_result;
+  }
+
+  if (high_cut_quantile == 1.0) {
+    high_boundary = td.Max();
+  } else {
+    auto high_result = TDigestQuantile(td, high_cut_quantile);
+    if (!high_result) {
+      return high_result;
+    }
+    high_boundary = *high_result;
+  }
+
+  auto iter = td.Begin();
+  double total_weight_in_range = 0;
+  double weighted_sum = 0;
+
+  while (iter->Valid()) {
+    auto centroid = GET_OR_RET(iter->GetCentroid());
+
+    if ((low_cut_quantile == 0.0 && high_cut_quantile == 1.0) ||
+        (centroid.mean >= low_boundary && centroid.mean <= high_boundary)) {
+      total_weight_in_range += centroid.weight;
+      weighted_sum += centroid.mean * centroid.weight;
+    }
+
+    iter->Next();
+  }
+
+  if (total_weight_in_range == 0) {

Review Comment:
   TDigestTrimmedMean can incorrectly return NaN when the low/high cut 
boundaries fall between centroid means (e.g., after quantile interpolation). 
The current logic only includes whole centroids whose mean is within 
[low_boundary, high_boundary], so it may exclude all centroids even though the 
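quantile range contains weight. Consider computing overlap by weight position 
(e.g., iterate centroids with cumulative weight and add partial centroid weight 
for the boundary centroids) rather than filtering by centroid.mean. If this 
approach is adopted, the low_boundary/high_boundary values computed earlier 
would no longer be used by the loop and could likely be dropped.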
   ```suggestion
     // First, compute the total weight of the t-digest.
     double total_weight = 0.0;
     {
       auto iter = td.Begin();
       while (iter->Valid()) {
         auto centroid = GET_OR_RET(iter->GetCentroid());
         total_weight += centroid.weight;
         iter->Next();
       }
     }
   
     if (total_weight == 0.0) {
       return std::numeric_limits<double>::quiet_NaN();
     }
   
     // If no trimming is requested, just return the global weighted mean.
     if (low_cut_quantile == 0.0 && high_cut_quantile == 1.0) {
       double weighted_sum = 0.0;
       auto iter = td.Begin();
       while (iter->Valid()) {
         auto centroid = GET_OR_RET(iter->GetCentroid());
         weighted_sum += centroid.mean * centroid.weight;
         iter->Next();
       }
       return weighted_sum / total_weight;
     }
   
     // Compute rank boundaries in weight space.
     const double low_rank = low_cut_quantile * total_weight;
     const double high_rank = high_cut_quantile * total_weight;
   
     double cumulative_weight = 0.0;
     double total_weight_in_range = 0.0;
     double weighted_sum = 0.0;
   
     auto iter = td.Begin();
     while (iter->Valid()) {
       auto centroid = GET_OR_RET(iter->GetCentroid());
   
       const double start_rank = cumulative_weight;
       const double end_rank = cumulative_weight + centroid.weight;
   
       // If this centroid is entirely before the trimmed region, skip it.
       if (end_rank <= low_rank) {
         cumulative_weight = end_rank;
         iter->Next();
         continue;
       }
   
       // If we've passed the trimmed region, we can stop.
       if (start_rank >= high_rank) {
         break;
       }
   
       // Compute overlap of this centroid's weight with [low_rank, high_rank).
       double overlap_start = start_rank;
       if (overlap_start < low_rank) {
         overlap_start = low_rank;
       }
       double overlap_end = end_rank;
       if (overlap_end > high_rank) {
         overlap_end = high_rank;
       }
   
       const double overlap = overlap_end - overlap_start;
       if (overlap > 0.0) {
         total_weight_in_range += overlap;
         weighted_sum += centroid.mean * overlap;
       }
   
       cumulative_weight = end_rank;
       iter->Next();
     }
   
     if (total_weight_in_range == 0.0) {
   ```



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+               key := "tdigest_no_trim"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 0.1)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+               key := "tdigest_skewed"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", 
"1", "1", "1", "10", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.Less(t, mean, 50.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t 
*testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, 
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t 
*testing.T) {
+               key := "tdigest_invalid"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5").Err())
+
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut 
quantile")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut 
quantile")
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
+               key := "tdigest_single"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 42.0, mean, 0.001)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) {
+               key := "tdigest_extreme"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6")
+               require.NoError(t, result.Err())
+               meanStr := result.Val().(string)
+               if meanStr == "nan" {
+                       return
+               }
+               mean, err := strconv.ParseFloat(meanStr, 64)
+               require.NoError(t, err)
+               require.Greater(t, mean, 0.0)

Review Comment:
   This test allows "nan" and returns early, which can mask real correctness 
issues (a non-empty digest with low_cut < high_cut should always have some 
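weight in the trimmed range). It would be better to assert the result is not 
NaN for this dataset and verify it’s within an expected numeric range/value. 
Note that the suggestion below uses math.IsNaN, so the test file needs a 
"math" import if it does not already have one.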
   ```suggestion
                mean, err := strconv.ParseFloat(meanStr, 64)
                require.NoError(t, err)
                require.False(t, math.IsNaN(mean))
                require.Greater(t, mean, 4.0)
                require.Less(t, mean, 7.0)
   ```



##########
src/types/redis_tdigest.cc:
##########
@@ -759,6 +759,41 @@ rocksdb::Status 
TDigest::applyNewCentroids(ObserverOrUniquePtr<rocksdb::WriteBat
   return rocksdb::Status::OK();
 }
 
+rocksdb::Status TDigest::TrimmedMean(engine::Context& ctx, const Slice& 
digest_name, double low_cut_quantile,
+                                     double high_cut_quantile, 
TDigestTrimmedMeanResult* result) {

Review Comment:
   TDigest::TrimmedMean can leave TDigestTrimmedMeanResult populated with a 
stale value if the caller reuses the result object: on success you only assign 
result->mean when there are observations, and the empty-digest early return 
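doesn’t reset it. Reset/clear result->mean at function entry (and before the 
early return) so the output is well-defined. If the empty-digest path is meant 
to report NaN (the empty-tdigest test expects "nan"), consider resetting to 
std::numeric_limits<double>::quiet_NaN() rather than 0, so that case does not 
risk being reported as a numeric zero.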
   ```suggestion
                                        double high_cut_quantile, 
TDigestTrimmedMeanResult* result) {
     result->mean = 0;
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to