This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1bcaac4835 Minor: support complex expr as the arg in the
ApproxPercentileCont function (#8580)
1bcaac4835 is described below
commit 1bcaac4835457627d881f755a87dbd140ec3388c
Author: Kun Liu <[email protected]>
AuthorDate: Wed Dec 20 10:11:29 2023 +0800
Minor: support complex expr as the arg in the ApproxPercentileCont function
(#8580)
* support complex lit expr for the arg
* enhance the percentile
---
.../core/tests/dataframe/dataframe_functions.rs | 20 ++++++++++
.../src/aggregate/approx_percentile_cont.rs | 45 ++++++++++------------
2 files changed, 41 insertions(+), 24 deletions(-)
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 9677003ec2..fe56fc22ea 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -31,6 +31,7 @@ use datafusion::prelude::*;
use datafusion::execution::context::SessionContext;
use datafusion::assert_batches_eq;
+use datafusion_expr::expr::Alias;
use datafusion_expr::{approx_median, cast};
async fn create_test_table() -> Result<DataFrame> {
@@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
assert_batches_eq!(expected, &batches);
+ // the arg2 parameter is a complex expr, but it can be evaluated to the literal value
+ let alias_expr = Expr::Alias(Alias::new(
+ cast(lit(0.5), DataType::Float32),
+ None::<&str>,
+ "arg_2".to_string(),
+ ));
+ let expr = approx_percentile_cont(col("b"), alias_expr);
+ let df = create_test_table().await?;
+ let expected = [
+ "+--------------------------------------+",
+ "| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
+ "+--------------------------------------+",
+ "| 10 |",
+ "+--------------------------------------+",
+ ];
+ let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+ assert_batches_eq!(expected, &batches);
+
Ok(())
}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index aa4749f64a..15c0fb3ace 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -18,7 +18,7 @@
use crate::aggregate::tdigest::TryIntoF64;
use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{format_state_name, Literal};
+use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{
array::{
@@ -27,11 +27,13 @@ use arrow::{
},
datatypes::{DataType, Field},
};
+use arrow_array::RecordBatch;
+use arrow_schema::Schema;
use datafusion_common::{
downcast_value, exec_err, internal_err, not_impl_err, plan_err, DataFusionError,
Result, ScalarValue,
};
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, ColumnarValue};
use std::{any::Any, iter, sync::Arc};
/// APPROX_PERCENTILE_CONT aggregate expression
@@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont {
}
}
+fn get_lit_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
+ let empty_schema = Schema::empty();
+ let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema));
+ let result = expr.evaluate(&empty_batch)?;
+ match result {
+ ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
+ "The expr {:?} can't be evaluated to scalar value",
+ expr
+ ))),
+ ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
+ }
+}
+
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
- // Extract the desired percentile literal
- let lit = expr
- .as_any()
- .downcast_ref::<Literal>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "desired percentile argument must be float literal".to_string(),
- )
- })?
- .value();
- let percentile = match lit {
+ let lit = get_lit_value(expr)?;
+ let percentile = match &lit {
ScalarValue::Float32(Some(q)) => *q as f64,
ScalarValue::Float64(Some(q)) => *q,
got => return not_impl_err!(
@@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
}
fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
- // Extract the desired percentile literal
- let lit = expr
- .as_any()
- .downcast_ref::<Literal>()
- .ok_or_else(|| {
- DataFusionError::Internal(
- "desired percentile argument must be float literal".to_string(),
- )
- })?
- .value();
- let max_size = match lit {
+ let lit = get_lit_value(expr)?;
+ let max_size = match &lit {
ScalarValue::UInt8(Some(q)) => *q as usize,
ScalarValue::UInt16(Some(q)) => *q as usize,
ScalarValue::UInt32(Some(q)) => *q as usize,