This is an automated email from the ASF dual-hosted git repository.

mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 183ff6643a Support `centroids` config for `approx_percentile_cont_with_weight` (#17003)
183ff6643a is described below

commit 183ff6643ace3b4708103b5fda8eecd6c31ca3f7
Author: Liam Bao <liam.zw....@gmail.com>
AuthorDate: Wed Aug 6 06:00:14 2025 -0400

    Support `centroids` config for `approx_percentile_cont_with_weight` (#17003)
    
    * Support centroids config for `approx_percentile_cont_with_weight`
    
    * Match two functions' signature
    
    * Update docs
    
    * Address comments and unify centroids config
---
 .../src/approx_percentile_cont.rs                  |  15 ++-
 .../src/approx_percentile_cont_with_weight.rs      | 111 +++++++++++++++------
 .../proto/tests/cases/roundtrip_logical_plan.rs    |  13 ++-
 datafusion/sqllogictest/test_files/aggregate.slt   |  10 ++
 docs/source/user-guide/expressions.md              |  42 ++++----
 docs/source/user-guide/sql/aggregate_functions.md  |  17 +++-
 6 files changed, 151 insertions(+), 57 deletions(-)

diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index 55c8c847ad..863ee15d89 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -77,8 +77,14 @@ pub fn approx_percentile_cont(
 #[user_doc(
     doc_section(label = "Approximate Functions"),
     description = "Returns the approximate percentile of input values using 
the t-digest algorithm.",
-    syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN 
GROUP (ORDER BY expression)",
+    syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN 
GROUP (ORDER BY expression)",
     sql_example = r#"```sql
+> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM 
table_name;
++------------------------------------------------------------------+
+| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
++------------------------------------------------------------------+
+| 65.0                                                             |
++------------------------------------------------------------------+
 > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) 
 > FROM table_name;
 +-----------------------------------------------------------------------+
 | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
@@ -313,7 +319,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
         }
         if arg_types.len() == 3 && !arg_types[2].is_integer() {
             return plan_err!(
-                "approx_percentile_cont requires integer max_size input types"
+                "approx_percentile_cont requires integer centroids input types"
             );
         }
         Ok(arg_types[0].clone())
@@ -360,6 +366,11 @@ impl ApproxPercentileAccumulator {
         }
     }
 
+    // public for approx_percentile_cont_with_weight
+    pub(crate) fn max_size(&self) -> usize {
+        self.digest.max_size()
+    }
+
     // public for approx_percentile_cont_with_weight
     pub fn merge_digests(&mut self, digests: &[TDigest]) {
         let digests = digests.iter().chain(std::iter::once(&self.digest));
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
index ab847e8388..d30ea624ca 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
@@ -25,32 +25,53 @@ use arrow::datatypes::FieldRef;
 use arrow::{array::ArrayRef, datatypes::DataType};
 use datafusion_common::ScalarValue;
 use datafusion_common::{not_impl_err, plan_err, Result};
+use datafusion_expr::expr::{AggregateFunction, Sort};
 use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
-use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
 use datafusion_expr::Volatility::Immutable;
 use datafusion_expr::{
-    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature,
-};
-use datafusion_functions_aggregate_common::tdigest::{
-    Centroid, TDigest, DEFAULT_MAX_SIZE,
+    Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, 
TypeSignature,
 };
+use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest};
 use datafusion_macros::user_doc;
 
 use crate::approx_percentile_cont::{ApproxPercentileAccumulator, 
ApproxPercentileCont};
 
-make_udaf_expr_and_func!(
+create_func!(
     ApproxPercentileContWithWeight,
-    approx_percentile_cont_with_weight,
-    expression weight percentile,
-    "Computes the approximate percentile continuous with weight of a set of 
numbers",
     approx_percentile_cont_with_weight_udaf
 );
 
+/// Computes the approximate percentile continuous with weight of a set of 
numbers
+pub fn approx_percentile_cont_with_weight(
+    order_by: Sort,
+    weight: Expr,
+    percentile: Expr,
+    centroids: Option<Expr>,
+) -> Expr {
+    let expr = order_by.expr.clone();
+
+    let args = if let Some(centroids) = centroids {
+        vec![expr, weight, percentile, centroids]
+    } else {
+        vec![expr, weight, percentile]
+    };
+
+    Expr::AggregateFunction(AggregateFunction::new_udf(
+        approx_percentile_cont_with_weight_udaf(),
+        args,
+        false,
+        None,
+        vec![order_by],
+        None,
+    ))
+}
+
 /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
 #[user_doc(
     doc_section(label = "Approximate Functions"),
     description = "Returns the weighted approximate percentile of input values 
using the t-digest algorithm.",
-    syntax_example = "approx_percentile_cont_with_weight(weight, percentile) 
WITHIN GROUP (ORDER BY expression)",
+    syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, 
centroids]) WITHIN GROUP (ORDER BY expression)",
     sql_example = r#"```sql
 > SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP 
 > (ORDER BY column_name) FROM table_name;
 
+---------------------------------------------------------------------------------------------+
@@ -58,6 +79,12 @@ make_udaf_expr_and_func!(
 
+---------------------------------------------------------------------------------------------+
 | 78.5                                                                         
               |
 
+---------------------------------------------------------------------------------------------+
+> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN 
GROUP (ORDER BY column_name) FROM table_name;
++--------------------------------------------------------------------------------------------------+
+| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP 
(ORDER BY column_name) |
++--------------------------------------------------------------------------------------------------+
+| 78.5                                                                         
                    |
++--------------------------------------------------------------------------------------------------+
 ```"#,
     standard_argument(name = "expression", prefix = "The"),
     argument(
@@ -67,6 +94,10 @@ make_udaf_expr_and_func!(
     argument(
         name = "percentile",
         description = "Percentile to compute. Must be a float value between 0 
and 1 (inclusive)."
+    ),
+    argument(
+        name = "centroids",
+        description = "Number of centroids to use in the t-digest algorithm. 
_Default is 100_. A higher number results in more accurate approximation but 
requires more memory."
     )
 )]
 pub struct ApproxPercentileContWithWeight {
@@ -91,21 +122,26 @@ impl Default for ApproxPercentileContWithWeight {
 impl ApproxPercentileContWithWeight {
     /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
     pub fn new() -> Self {
+        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() 
+ 1));
+        // Accept any numeric value paired with weight and float64 percentile
+        for num in NUMERICS {
+            variants.push(TypeSignature::Exact(vec![
+                num.clone(),
+                num.clone(),
+                DataType::Float64,
+            ]));
+            // Additionally accept an integer number of centroids for T-Digest
+            for int in INTEGERS {
+                variants.push(TypeSignature::Exact(vec![
+                    num.clone(),
+                    num.clone(),
+                    DataType::Float64,
+                    int.clone(),
+                ]));
+            }
+        }
         Self {
-            signature: Signature::one_of(
-                // Accept any numeric value paired with a float64 percentile
-                NUMERICS
-                    .iter()
-                    .map(|t| {
-                        TypeSignature::Exact(vec![
-                            t.clone(),
-                            t.clone(),
-                            DataType::Float64,
-                        ])
-                    })
-                    .collect(),
-                Immutable,
-            ),
+            signature: Signature::one_of(variants, Immutable),
             approx_percentile_cont: ApproxPercentileCont::new(),
         }
     }
@@ -138,6 +174,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
         if arg_types[2] != DataType::Float64 {
             return plan_err!("approx_percentile_cont_with_weight requires 
float64 percentile input types");
         }
+        if arg_types.len() == 4 && !arg_types[3].is_integer() {
+            return plan_err!(
+                "approx_percentile_cont_with_weight requires integer centroids 
input types"
+            );
+        }
         Ok(arg_types[0].clone())
     }
 
@@ -148,17 +189,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight {
             );
         }
 
-        if acc_args.exprs.len() != 3 {
+        if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 {
             return plan_err!(
-                "approx_percentile_cont_with_weight requires three arguments: 
value, weight, percentile"
+                "approx_percentile_cont_with_weight requires three or four 
arguments: value, weight, percentile[, centroids]"
             );
         }
 
         let sub_args = AccumulatorArgs {
-            exprs: &[
-                Arc::clone(&acc_args.exprs[0]),
-                Arc::clone(&acc_args.exprs[2]),
-            ],
+            exprs: if acc_args.exprs.len() == 4 {
+                &[
+                    Arc::clone(&acc_args.exprs[0]), // value
+                    Arc::clone(&acc_args.exprs[2]), // percentile
+                    Arc::clone(&acc_args.exprs[3]), // centroids
+                ]
+            } else {
+                &[
+                    Arc::clone(&acc_args.exprs[0]), // value
+                    Arc::clone(&acc_args.exprs[2]), // percentile
+                ]
+            },
             ..acc_args
         };
         let approx_percentile_cont_accumulator =
@@ -244,7 +293,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
         let mut digests: Vec<TDigest> = vec![];
         for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
             digests.push(TDigest::new_with_centroid(
-                DEFAULT_MAX_SIZE,
+                self.approx_percentile_cont_accumulator.max_size(),
                 Centroid::new(*mean, *weight),
             ))
         }
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 170c2675f7..96d4ea7642 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -982,7 +982,18 @@ async fn roundtrip_expr_api() -> Result<()> {
         approx_median(lit(2)),
         approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None),
         approx_percentile_cont(lit(2).sort(true, false), lit(0.5), 
Some(lit(50))),
-        approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)),
+        approx_percentile_cont_with_weight(
+            lit(2).sort(true, false),
+            lit(1),
+            lit(0.5),
+            None,
+        ),
+        approx_percentile_cont_with_weight(
+            lit(2).sort(true, false),
+            lit(1),
+            lit(0.5),
+            Some(lit(50)),
+        ),
         grouping(lit(1)),
         bit_and(lit(2)),
         bit_or(lit(2)),
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 882de5dc54..1af0bbf6e8 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1840,6 +1840,16 @@ c 123
 d 124
 e 115
 
+# approx_percentile_cont_with_weight with centroids
+query TI
+SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP 
(ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1
+----
+a 74
+b 68
+c 123
+d 124
+e 115
+
 # csv_query_sum_crossjoin
 query TTI
 SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN 
aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1
diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index 03ab86eeb8..abf0286fa8 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1);
 
 ## Aggregate Functions
 
-| Syntax                                                            | 
Description                                                                     
        |
-| ----------------------------------------------------------------- | 
---------------------------------------------------------------------------------------
 |
-| avg(expr)                                                         | 
Сalculates the average value for `expr`.                                        
        |
-| approx_distinct(expr)                                             | 
Calculates an approximate count of the number of distinct values for `expr`.    
        |
-| approx_median(expr)                                               | 
Calculates an approximation of the median for `expr`.                           
        |
-| approx_percentile_cont(expr, percentile)                          | 
Calculates an approximation of the specified `percentile` for `expr`.           
        |
-| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | 
Calculates an approximation of the specified `percentile` for `expr` and 
`weight_expr`. |
-| bit_and(expr)                                                     | Computes 
the bitwise AND of all non-null input values for `expr`.                       |
-| bit_or(expr)                                                      | Computes 
the bitwise OR of all non-null input values for `expr`.                        |
-| bit_xor(expr)                                                     | Computes 
the bitwise exclusive OR of all non-null input values for `expr`.              |
-| bool_and(expr)                                                    | Returns 
true if all non-null input values (`expr`) are true, otherwise false.           
|
-| bool_or(expr)                                                     | Returns 
true if any non-null input value (`expr`) is true, otherwise false.             
|
-| count(expr)                                                       | Returns 
the number of rows for `expr`.                                                  
|
-| count_distinct                                                    | Creates 
an expression to represent the count(distinct) aggregate function               
|
-| cube(exprs)                                                       | Creates 
a grouping set for all combination of `exprs`                                   
|
-| grouping_set(exprs)                                               | Create a 
grouping set.                                                                  |
-| max(expr)                                                         | Finds 
the maximum value of `expr`.                                                    
  |
-| median(expr)                                                      | 
Сalculates the median of `expr`.                                                
        |
-| min(expr)                                                         | Finds 
the minimum value of `expr`.                                                    
  |
-| rollup(exprs)                                                     | Creates 
a grouping set for rollup sets.                                                 
|
-| sum(expr)                                                         | 
Сalculates the sum of `expr`.                                                   
        |
+| Syntax                                                                       
   | Description                                                                
                                                                              |
+| 
------------------------------------------------------------------------------- 
| 
--------------------------------------------------------------------------------------------------------------------------------------------------------
 |
+| avg(expr)                                                                    
   | Сalculates the average value for `expr`.                                   
                                                                              |
+| approx_distinct(expr)                                                        
   | Calculates an approximate count of the number of distinct values for 
`expr`.                                                                         
    |
+| approx_median(expr)                                                          
   | Calculates an approximation of the median for `expr`.                      
                                                                              |
+| approx_percentile_cont(expr, percentile [, centroids])                       
   | Calculates an approximation of the specified `percentile` for `expr`. 
Optional `centroids` parameter controls accuracy (default: 100).                
   |
+| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, 
centroids]) | Calculates an approximation of the specified `percentile` for 
`expr` and `weight_expr`. Optional `centroids` parameter controls accuracy 
(default: 100). |
+| bit_and(expr)                                                                
   | Computes the bitwise AND of all non-null input values for `expr`.          
                                                                              |
+| bit_or(expr)                                                                 
   | Computes the bitwise OR of all non-null input values for `expr`.           
                                                                              |
+| bit_xor(expr)                                                                
   | Computes the bitwise exclusive OR of all non-null input values for `expr`. 
                                                                              |
+| bool_and(expr)                                                               
   | Returns true if all non-null input values (`expr`) are true, otherwise 
false.                                                                          
  |
+| bool_or(expr)                                                                
   | Returns true if any non-null input value (`expr`) is true, otherwise 
false.                                                                          
    |
+| count(expr)                                                                  
   | Returns the number of rows for `expr`.                                     
                                                                              |
+| count_distinct                                                               
   | Creates an expression to represent the count(distinct) aggregate function  
                                                                              |
+| cube(exprs)                                                                  
   | Creates a grouping set for all combination of `exprs`                      
                                                                              |
+| grouping_set(exprs)                                                          
   | Create a grouping set.                                                     
                                                                              |
+| max(expr)                                                                    
   | Finds the maximum value of `expr`.                                         
                                                                              |
+| median(expr)                                                                 
   | Сalculates the median of `expr`.                                           
                                                                              |
+| min(expr)                                                                    
   | Finds the minimum value of `expr`.                                         
                                                                              |
+| rollup(exprs)                                                                
   | Creates a grouping set for rollup sets.                                    
                                                                              |
+| sum(expr)                                                                    
   | Сalculates the sum of `expr`.                                              
                                                                              |
 
 ## Aggregate Function Builder
 
diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md
index 88241770a4..e3396cd7bd 100644
--- a/docs/source/user-guide/sql/aggregate_functions.md
+++ b/docs/source/user-guide/sql/aggregate_functions.md
@@ -1039,7 +1039,7 @@ approx_median(expression)
 Returns the approximate percentile of input values using the t-digest algorithm.
 
 ```sql
-approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY 
expression)
+approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY 
expression)
 ```
 
 #### Arguments
@@ -1051,6 +1051,12 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)
 #### Example
 
 ```sql
+> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM 
table_name;
++------------------------------------------------------------------+
+| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) |
++------------------------------------------------------------------+
+| 65.0                                                             |
++------------------------------------------------------------------+
 > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) 
 > FROM table_name;
 +-----------------------------------------------------------------------+
 | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) |
@@ -1064,7 +1070,7 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)
 Returns the weighted approximate percentile of input values using the t-digest algorithm.
 
 ```sql
-approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY 
expression)
+approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN 
GROUP (ORDER BY expression)
 ```
 
 #### Arguments
@@ -1072,6 +1078,7 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex
 - **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators.
 - **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators.
 - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive).
+- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory.
 
 #### Example
 
@@ -1082,4 +1089,10 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex
 
+---------------------------------------------------------------------------------------------+
 | 78.5                                                                         
               |
 
+---------------------------------------------------------------------------------------------+
+> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN 
GROUP (ORDER BY column_name) FROM table_name;
++--------------------------------------------------------------------------------------------------+
+| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP 
(ORDER BY column_name) |
++--------------------------------------------------------------------------------------------------+
+| 78.5                                                                         
                    |
++--------------------------------------------------------------------------------------------------+
 ```


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to