Jefffrey commented on code in PR #18837:
URL: https://github.com/apache/datafusion/pull/18837#discussion_r2548304856
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
+
+ let value_expr = params.args[0].clone();
+ let input_type = match info.get_data_type(&value_expr) {
+ Ok(data_type) => data_type,
+ Err(_) => return Ok(original_expr),
+ };
+
+ let expected_return_type = match percentile_cont_result_type(&input_type) {
+ Some(data_type) => data_type,
+ None => return Ok(original_expr),
+ };
+
+ let udaf = match rewrite_target {
+ PercentileRewriteTarget::Min => min_udaf(),
+ PercentileRewriteTarget::Max => max_udaf(),
+ };
+
+ let mut agg_arg = value_expr;
+ if expected_return_type != input_type {
+ agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+ }
+
+ let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+ udaf,
+ vec![agg_arg],
+ params.distinct,
+ params.filter.clone(),
+ vec![],
+ params.null_treatment,
+ ));
+ Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+ percentile_value: f64,
+ is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+ if nearly_equals_fraction(percentile_value, 0.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Max
+ } else {
+ PercentileRewriteTarget::Min
+ })
+ } else if nearly_equals_fraction(percentile_value, 1.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Min
+ } else {
+ PercentileRewriteTarget::Max
+ })
+ } else {
+ None
+ }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+ (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
Review Comment:
I'm personally of the mind to check directly against 0.0 and 1.0 instead of
doing an epsilon check; I think it's more likely a user would input an expr
like `SELECT percentile_cont(column1, 0.0)` than doing something like `SELECT
percentile_cont(column1, expr)` where `expr` might be some math that could make
it `0.0000001` 🤔
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
+
+ let value_expr = params.args[0].clone();
+ let input_type = match info.get_data_type(&value_expr) {
+ Ok(data_type) => data_type,
+ Err(_) => return Ok(original_expr),
+ };
+
+ let expected_return_type = match percentile_cont_result_type(&input_type) {
+ Some(data_type) => data_type,
+ None => return Ok(original_expr),
+ };
+
+ let udaf = match rewrite_target {
+ PercentileRewriteTarget::Min => min_udaf(),
+ PercentileRewriteTarget::Max => max_udaf(),
+ };
+
+ let mut agg_arg = value_expr;
+ if expected_return_type != input_type {
+ agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+ }
Review Comment:
Can we explain why this is necessary in a comment here?
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
Review Comment:
```suggestion
let [value, percentile] = take_function_args("percentile_cont", &params.args)?;
```
More ergonomic this way; technically this error path should never occur as
the signature should already guard us by now.
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
+
+ let value_expr = params.args[0].clone();
+ let input_type = match info.get_data_type(&value_expr) {
+ Ok(data_type) => data_type,
+ Err(_) => return Ok(original_expr),
+ };
+
+ let expected_return_type = match percentile_cont_result_type(&input_type) {
+ Some(data_type) => data_type,
+ None => return Ok(original_expr),
+ };
+
+ let udaf = match rewrite_target {
+ PercentileRewriteTarget::Min => min_udaf(),
+ PercentileRewriteTarget::Max => max_udaf(),
+ };
+
+ let mut agg_arg = value_expr;
+ if expected_return_type != input_type {
+ agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+ }
+
+ let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+ udaf,
+ vec![agg_arg],
+ params.distinct,
+ params.filter.clone(),
+ vec![],
+ params.null_treatment,
+ ));
+ Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+ percentile_value: f64,
+ is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+ if nearly_equals_fraction(percentile_value, 0.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Max
+ } else {
+ PercentileRewriteTarget::Min
+ })
+ } else if nearly_equals_fraction(percentile_value, 1.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Min
+ } else {
+ PercentileRewriteTarget::Max
+ })
+ } else {
+ None
+ }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+ (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {
+ if !input_type.is_numeric() {
+ return None;
+ }
+
+ let result_type = match input_type {
+ DataType::Float16 | DataType::Float32 | DataType::Float64 => input_type.clone(),
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _) => input_type.clone(),
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64 => DataType::Float64,
+ _ => return None,
+ };
+
+ Some(result_type)
+}
+
+fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
+ match expr {
+ Expr::Literal(value, _) => literal_scalar_to_f64(value),
+ Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()),
+ Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+ Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()),
Review Comment:
How strictly necessary are these other arms? Is checking only for `Literal`
not sufficient?
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
+
+ let value_expr = params.args[0].clone();
+ let input_type = match info.get_data_type(&value_expr) {
+ Ok(data_type) => data_type,
+ Err(_) => return Ok(original_expr),
+ };
+
+ let expected_return_type = match percentile_cont_result_type(&input_type) {
+ Some(data_type) => data_type,
+ None => return Ok(original_expr),
+ };
+
+ let udaf = match rewrite_target {
+ PercentileRewriteTarget::Min => min_udaf(),
+ PercentileRewriteTarget::Max => max_udaf(),
+ };
+
+ let mut agg_arg = value_expr;
+ if expected_return_type != input_type {
+ agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+ }
+
+ let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+ udaf,
+ vec![agg_arg],
+ params.distinct,
+ params.filter.clone(),
+ vec![],
+ params.null_treatment,
+ ));
+ Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+ percentile_value: f64,
+ is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+ if nearly_equals_fraction(percentile_value, 0.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Max
+ } else {
+ PercentileRewriteTarget::Min
+ })
+ } else if nearly_equals_fraction(percentile_value, 1.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Min
+ } else {
+ PercentileRewriteTarget::Max
+ })
+ } else {
+ None
+ }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+ (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {
Review Comment:
We should reuse the code from `return_type` if possible instead of
duplicating it here
https://github.com/apache/datafusion/blob/f1ecaccd183367086ecb5b7736d93b3aba109e01/datafusion/functions-aggregate/src/percentile_cont.rs#L232-L261
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
+
+ let value_expr = params.args[0].clone();
+ let input_type = match info.get_data_type(&value_expr) {
+ Ok(data_type) => data_type,
+ Err(_) => return Ok(original_expr),
+ };
+
+ let expected_return_type = match percentile_cont_result_type(&input_type) {
+ Some(data_type) => data_type,
+ None => return Ok(original_expr),
+ };
+
+ let udaf = match rewrite_target {
+ PercentileRewriteTarget::Min => min_udaf(),
+ PercentileRewriteTarget::Max => max_udaf(),
+ };
+
+ let mut agg_arg = value_expr;
+ if expected_return_type != input_type {
+ agg_arg = Expr::Cast(Cast::new(Box::new(agg_arg), expected_return_type.clone()));
+ }
+
+ let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf(
+ udaf,
+ vec![agg_arg],
+ params.distinct,
+ params.filter.clone(),
+ vec![],
+ params.null_treatment,
+ ));
+ Ok(rewritten)
+}
+
+fn classify_rewrite_target(
+ percentile_value: f64,
+ is_descending: bool,
+) -> Option<PercentileRewriteTarget> {
+ if nearly_equals_fraction(percentile_value, 0.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Max
+ } else {
+ PercentileRewriteTarget::Min
+ })
+ } else if nearly_equals_fraction(percentile_value, 1.0) {
+ Some(if is_descending {
+ PercentileRewriteTarget::Min
+ } else {
+ PercentileRewriteTarget::Max
+ })
+ } else {
+ None
+ }
+}
+
+fn nearly_equals_fraction(value: f64, target: f64) -> bool {
+ (value - target).abs() < PERCENTILE_LITERAL_EPSILON
+}
+
+fn percentile_cont_result_type(input_type: &DataType) -> Option<DataType> {
+ if !input_type.is_numeric() {
+ return None;
+ }
+
+ let result_type = match input_type {
+ DataType::Float16 | DataType::Float32 | DataType::Float64 => input_type.clone(),
+ DataType::Decimal32(_, _)
+ | DataType::Decimal64(_, _)
+ | DataType::Decimal128(_, _)
+ | DataType::Decimal256(_, _) => input_type.clone(),
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64 => DataType::Float64,
+ _ => return None,
+ };
+
+ Some(result_type)
+}
+
+fn extract_percentile_literal(expr: &Expr) -> Option<f64> {
+ match expr {
+ Expr::Literal(value, _) => literal_scalar_to_f64(value),
+ Expr::Alias(alias) => extract_percentile_literal(alias.expr.as_ref()),
+ Expr::Cast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+ Expr::TryCast(cast) => extract_percentile_literal(cast.expr.as_ref()),
+ _ => None,
+ }
+}
+
+fn literal_scalar_to_f64(value: &ScalarValue) -> Option<f64> {
Review Comment:
Can we have percentiles that are not of type `Float64`? I thought the
signature guarded us against this
https://github.com/apache/datafusion/blob/f1ecaccd183367086ecb5b7736d93b3aba109e01/datafusion/functions-aggregate/src/percentile_cont.rs#L142-L154
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
+
+ let value_expr = params.args[0].clone();
+ let input_type = match info.get_data_type(&value_expr) {
+ Ok(data_type) => data_type,
+ Err(_) => return Ok(original_expr),
+ };
Review Comment:
```suggestion
let input_type = info.get_data_type(&value_expr)?;
```
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -760,3 +914,80 @@ fn calculate_percentile<T: ArrowNumericType>(
}
}
}
+
+#[cfg(test)]
+mod tests {
Review Comment:
We should remove the unit tests if they duplicate the sqllogictests
##########
datafusion/functions-aggregate/src/percentile_cont.rs:
##########
@@ -367,6 +377,151 @@ impl AggregateUDFImpl for PercentileCont {
}
}
+const PERCENTILE_LITERAL_EPSILON: f64 = 1e-12;
+
+#[derive(Clone, Copy)]
+enum PercentileRewriteTarget {
+ Min,
+ Max,
+}
+
+#[expect(clippy::needless_pass_by_value)]
+fn simplify_percentile_cont_aggregate(
+ aggregate_function: AggregateFunction,
+ info: &dyn SimplifyInfo,
+) -> Result<Expr> {
+ let original_expr = Expr::AggregateFunction(aggregate_function.clone());
+ let params = &aggregate_function.params;
+
+ if params.args.len() != 2 {
+ return Ok(original_expr);
+ }
+
+ let percentile_value = match extract_percentile_literal(&params.args[1]) {
+ Some(value) if (0.0..=1.0).contains(&value) => value,
+ _ => return Ok(original_expr),
+ };
+
+ let is_descending = params
+ .order_by
+ .first()
+ .map(|sort| !sort.asc)
+ .unwrap_or(false);
+
+ let rewrite_target = match classify_rewrite_target(percentile_value, is_descending) {
+ Some(target) => target,
+ None => return Ok(original_expr),
+ };
Review Comment:
I feel this should be folded directly into line 400 above, instead of
splitting it like this
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]