LindaSummer commented on code in PR #3312:
URL: https://github.com/apache/kvrocks/pull/3312#discussion_r2884334698


##########
src/commands/cmd_tdigest.cc:
##########
@@ -492,6 +492,57 @@ class CommandTDigestMerge : public Commander {
   TDigestMergeOptions options_;
 };
 
+class CommandTDigestTrimmedMean : public Commander {
+ public:
+  Status Parse(const std::vector<std::string> &args) override {
+    if (args.size() != 4) {
+      return {Status::RedisParseErr, errWrongNumOfArguments};
+    }
+
+    key_name_ = args[1];
+
+    auto low_cut_quantile = ParseFloat(args[2]);
+    if (!low_cut_quantile) {
+      return {Status::RedisParseErr, errValueIsNotFloat};
+    }
+    low_cut_quantile_ = *low_cut_quantile;
+
+    auto high_cut_quantile = ParseFloat(args[3]);
+    if (!high_cut_quantile) {
+      return {Status::RedisParseErr, errValueIsNotFloat};
+    }
+    high_cut_quantile_ = *high_cut_quantile;

Review Comment:
   Please validate `low_cut_quantile` and `high_cut_quantile` here.
   Parameter validation should happen as early as possible (in the parse step) rather than during command execution.



##########
src/types/tdigest.h:
##########
@@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const 
std::vector<double>& inputs, std::vecto
   }
   return Status::OK();
 }
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, 
double high_cut_quantile) {
+  if (td.Size() == 0) {
+    return Status{Status::InvalidArgument, "empty tdigest"};
+  }
+
+  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "low cut quantile must be between 0 
and 1"};
+  }
+  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "high cut quantile must be between 
0 and 1"};
+  }
+  if (low_cut_quantile >= high_cut_quantile) {
+    return Status{Status::InvalidArgument, "low cut quantile must be less than 
high cut quantile"};
+  }

Review Comment:
   Move this validation to the command's parse step.
   We can keep a guard here, but the primary validation should happen during parsing.



##########
src/types/tdigest.h:
##########
@@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const 
std::vector<double>& inputs, std::vecto
   }
   return Status::OK();
 }
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, 
double high_cut_quantile) {
+  if (td.Size() == 0) {
+    return Status{Status::InvalidArgument, "empty tdigest"};
+  }
+
+  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "low cut quantile must be between 0 
and 1"};
+  }
+  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "high cut quantile must be between 
0 and 1"};
+  }
+  if (low_cut_quantile >= high_cut_quantile) {
+    return Status{Status::InvalidArgument, "low cut quantile must be less than 
high cut quantile"};
+  }
+
+  double low_boundary = std::numeric_limits<double>::quiet_NaN();
+  double high_boundary = std::numeric_limits<double>::quiet_NaN();
+
+  if (low_cut_quantile == 0.0) {
+    low_boundary = td.Min();
+  } else {
+    auto low_result = TDigestQuantile(td, low_cut_quantile);
+    if (!low_result) {
+      return low_result;
+    }
+    low_boundary = *low_result;
+  }
+
+  if (high_cut_quantile == 1.0) {
+    high_boundary = td.Max();
+  } else {
+    auto high_result = TDigestQuantile(td, high_cut_quantile);

Review Comment:
   Iterate over all centroids and collect the ones that fall within the boundaries.
   `TDigestQuantile` returns a linearly interpolated estimate (with its edge cases already handled), not the actual centroids you need here.



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+               key := "tdigest_no_trim"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 0.1)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+               key := "tdigest_skewed"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", 
"1", "1", "1", "10", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.Less(t, mean, 50.0)

Review Comment:
   Why do we use `Less` rather than asserting a precise result?



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+               key := "tdigest_no_trim"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 0.1)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+               key := "tdigest_skewed"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", 
"1", "1", "1", "10", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.Less(t, mean, 50.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t 
*testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, 
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t 
*testing.T) {
+               key := "tdigest_invalid"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5").Err())
+
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut 
quantile")

Review Comment:
   The error messages could be defined as constant strings to reduce duplication.



##########
src/types/tdigest.h:
##########
@@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const 
std::vector<double>& inputs, std::vecto
   }
   return Status::OK();
 }
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, 
double high_cut_quantile) {
+  if (td.Size() == 0) {
+    return Status{Status::InvalidArgument, "empty tdigest"};
+  }
+
+  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "low cut quantile must be between 0 
and 1"};
+  }
+  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "high cut quantile must be between 
0 and 1"};
+  }
+  if (low_cut_quantile >= high_cut_quantile) {
+    return Status{Status::InvalidArgument, "low cut quantile must be less than 
high cut quantile"};
+  }
+
+  double low_boundary = std::numeric_limits<double>::quiet_NaN();
+  double high_boundary = std::numeric_limits<double>::quiet_NaN();
+
+  if (low_cut_quantile == 0.0) {

Review Comment:
   Use a more robust way of comparing doubles than `==` (e.g., an epsilon-based comparison).



##########
tests/cppunit/types/tdigest_test.cc:
##########
@@ -524,3 +524,32 @@ TEST_F(RedisTDigestTest, ByRank_And_ByRevRank) {
   EXPECT_EQ(result[0], 1.0) << "Rank 0 should be minimum";
   EXPECT_TRUE(std::isinf(result[3])) << "Rank >= total_weight should be 
infinity";
 }
+
+TEST_F(RedisTDigestTest, TrimmedMean) {

Review Comment:
   Add test cases for invalid arguments, as well as more complex and unordered inputs.



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+               key := "tdigest_no_trim"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 0.1)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+               key := "tdigest_skewed"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", 
"1", "1", "1", "10", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.Less(t, mean, 50.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t 
*testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, 
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t 
*testing.T) {
+               key := "tdigest_invalid"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5").Err())
+
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut 
quantile")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut 
quantile")
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
+               key := "tdigest_single"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 42.0, mean, 0.001)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) {
+               key := "tdigest_extreme"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6")
+               require.NoError(t, result.Err())
+               meanStr := result.Val().(string)
+               if meanStr == "nan" {

Review Comment:
   The result should not be `nan`.



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)

Review Comment:
   Is the delta `1.0` too large for this test case?



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+               key := "tdigest_no_trim"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 0.1)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+               key := "tdigest_skewed"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", 
"1", "1", "1", "10", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.Less(t, mean, 50.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t 
*testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, 
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t 
*testing.T) {
+               key := "tdigest_invalid"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5").Err())
+
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut 
quantile")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut 
quantile")
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
+               key := "tdigest_single"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 42.0, mean, 0.001)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with extreme trimming", func(t *testing.T) {
+               key := "tdigest_extreme"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.4", "0.6")
+               require.NoError(t, result.Err())
+               meanStr := result.Val().(string)
+               if meanStr == "nan" {
+                       return
+               }
+               mean, err := strconv.ParseFloat(meanStr, 64)
+               require.NoError(t, err)
+               require.Greater(t, mean, 0.0)

Review Comment:
   We should assert precise values in test cases so that they are stable and verify correctness.



##########
tests/gocase/unit/type/tdigest/tdigest_test.go:
##########
@@ -717,6 +717,101 @@ func tdigestTests(t *testing.T, configs 
util.KvrocksServerConfigs) {
                        require.EqualValues(t, expected[i], rank, "REVRANK 
mismatch at index %d", i)
                }
        })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with non-existent key", func(t *testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"nonexistent", "0.1", "0.9").Err(), errMsgKeyNotExist)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with empty tdigest", func(t *testing.T) {
+               emptyKey := "tdigest_empty"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey, 
"compression", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", emptyKey, "0.1", 
"0.9")
+               require.NoError(t, result.Err())
+               require.Equal(t, "nan", result.Val())
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with basic data set", func(t *testing.T) {
+               key := "tdigest_basic"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 1.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with no trimming", func(t *testing.T) {
+               key := "tdigest_no_trim"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5", "6", "7", "8", "9", "10").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0", "1")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 5.5, mean, 0.1)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with skewed data", func(t *testing.T) {
+               key := "tdigest_skewed"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "1", 
"1", "1", "1", "10", "100").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.2", "0.8")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.Less(t, mean, 50.0)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN wrong number of arguments", func(t 
*testing.T) {
+               require.ErrorContains(t, rdb.Do(ctx, 
"TDIGEST.TRIMMED_MEAN").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1").Err(), errMsgWrongNumberArg)
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
"key", "0.1", "0.9", "extra").Err(), errMsgWrongNumberArg)
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN invalid quantile ranges", func(t 
*testing.T) {
+               key := "tdigest_invalid"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "1", "2", 
"3", "4", "5").Err())
+
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "-0.1", "0.9").Err(), "low cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.1", "1.1").Err(), "high cut quantile must be between 0 and 1")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.9", "0.1").Err(), "low cut quantile must be less than high cut 
quantile")
+               require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", 
key, "0.5", "0.5").Err(), "low cut quantile must be less than high cut 
quantile")
+       })
+
+       t.Run("TDIGEST.TRIMMED_MEAN with single value", func(t *testing.T) {
+               key := "tdigest_single"
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, 
"compression", "100").Err())
+               require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "42").Err())
+
+               result := rdb.Do(ctx, "TDIGEST.TRIMMED_MEAN", key, "0.1", "0.9")
+               require.NoError(t, result.Err())
+               mean, err := strconv.ParseFloat(result.Val().(string), 64)
+               require.NoError(t, err)
+               require.InDelta(t, 42.0, mean, 0.001)

Review Comment:
   Why don't we use a consistent delta precision across all cases?



##########
src/types/tdigest.h:
##########
@@ -309,3 +309,65 @@ inline Status TDigestRank(TD&& td, const 
std::vector<double>& inputs, std::vecto
   }
   return Status::OK();
 }
+
+template <typename TD>
+inline StatusOr<double> TDigestTrimmedMean(TD&& td, double low_cut_quantile, 
double high_cut_quantile) {
+  if (td.Size() == 0) {
+    return Status{Status::InvalidArgument, "empty tdigest"};
+  }
+
+  if (low_cut_quantile < 0.0 || low_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "low cut quantile must be between 0 
and 1"};
+  }
+  if (high_cut_quantile < 0.0 || high_cut_quantile > 1.0) {
+    return Status{Status::InvalidArgument, "high cut quantile must be between 
0 and 1"};
+  }
+  if (low_cut_quantile >= high_cut_quantile) {
+    return Status{Status::InvalidArgument, "low cut quantile must be less than 
high cut quantile"};
+  }
+
+  double low_boundary = std::numeric_limits<double>::quiet_NaN();
+  double high_boundary = std::numeric_limits<double>::quiet_NaN();
+
+  if (low_cut_quantile == 0.0) {
+    low_boundary = td.Min();
+  } else {
+    auto low_result = TDigestQuantile(td, low_cut_quantile);
+    if (!low_result) {
+      return low_result;
+    }
+    low_boundary = *low_result;
+  }
+
+  if (high_cut_quantile == 1.0) {
+    high_boundary = td.Max();
+  } else {
+    auto high_result = TDigestQuantile(td, high_cut_quantile);

Review Comment:
   Also, after getting the quantiles you iterate over the centroids twice.
   If you iterate over the centroids directly, a single scan is enough.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to