This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1bcaac4835 Minor: support complex expr as the arg in the ApproxPercentileCont function (#8580)
1bcaac4835 is described below

commit 1bcaac4835457627d881f755a87dbd140ec3388c
Author: Kun Liu <[email protected]>
AuthorDate: Wed Dec 20 10:11:29 2023 +0800

    Minor: support complex expr as the arg in the ApproxPercentileCont function (#8580)
    
    * support complex lit expr for the arg
    
    * enhance the percentile argument handling
---
 .../core/tests/dataframe/dataframe_functions.rs    | 20 ++++++++++
 .../src/aggregate/approx_percentile_cont.rs        | 45 ++++++++++------------
 2 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 9677003ec2..fe56fc22ea 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -31,6 +31,7 @@ use datafusion::prelude::*;
 use datafusion::execution::context::SessionContext;
 
 use datafusion::assert_batches_eq;
+use datafusion_expr::expr::Alias;
 use datafusion_expr::{approx_median, cast};
 
 async fn create_test_table() -> Result<DataFrame> {
@@ -186,6 +187,25 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
 
     assert_batches_eq!(expected, &batches);
 
+    // the arg2 parameter is a complex expr, but it can be evaluated to a literal value
+    let alias_expr = Expr::Alias(Alias::new(
+        cast(lit(0.5), DataType::Float32),
+        None::<&str>,
+        "arg_2".to_string(),
+    ));
+    let expr = approx_percentile_cont(col("b"), alias_expr);
+    let df = create_test_table().await?;
+    let expected = [
+        "+--------------------------------------+",
+        "| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
+        "+--------------------------------------+",
+        "| 10                                   |",
+        "+--------------------------------------+",
+    ];
+    let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+    assert_batches_eq!(expected, &batches);
+
     Ok(())
 }
 
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index aa4749f64a..15c0fb3ace 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -18,7 +18,7 @@
 use crate::aggregate::tdigest::TryIntoF64;
 use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
 use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{format_state_name, Literal};
+use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use arrow::{
     array::{
@@ -27,11 +27,13 @@ use arrow::{
     },
     datatypes::{DataType, Field},
 };
+use arrow_array::RecordBatch;
+use arrow_schema::Schema;
 use datafusion_common::{
     downcast_value, exec_err, internal_err, not_impl_err, plan_err, 
DataFusionError,
     Result, ScalarValue,
 };
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, ColumnarValue};
 use std::{any::Any, iter, sync::Arc};
 
 /// APPROX_PERCENTILE_CONT aggregate expression
@@ -131,18 +133,22 @@ impl PartialEq for ApproxPercentileCont {
     }
 }
 
+/// Evaluates `expr` against an empty, zero-column batch and returns its scalar result,
+/// so constant expressions such as a cast literal are accepted; array results are errors.
+fn get_lit_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
+    let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
+    match expr.evaluate(&empty_batch)? {
+        ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
+        ColumnarValue::Array(_) => internal_err!(
+            "The expr {:?} can't be evaluated to scalar value",
+            expr
+        ),
+    }
+}
+
 fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> 
{
-    // Extract the desired percentile literal
-    let lit = expr
-        .as_any()
-        .downcast_ref::<Literal>()
-        .ok_or_else(|| {
-            DataFusionError::Internal(
-                "desired percentile argument must be float 
literal".to_string(),
-            )
-        })?
-        .value();
-    let percentile = match lit {
+    let lit = get_lit_value(expr)?;
+    let percentile = match &lit {
         ScalarValue::Float32(Some(q)) => *q as f64,
         ScalarValue::Float64(Some(q)) => *q,
         got => return not_impl_err!(
@@ -161,17 +167,8 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
 }
 
 fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> 
{
-    // Extract the desired percentile literal
-    let lit = expr
-        .as_any()
-        .downcast_ref::<Literal>()
-        .ok_or_else(|| {
-            DataFusionError::Internal(
-                "desired percentile argument must be float 
literal".to_string(),
-            )
-        })?
-        .value();
-    let max_size = match lit {
+    let lit = get_lit_value(expr)?;
+    let max_size = match &lit {
         ScalarValue::UInt8(Some(q)) => *q as usize,
         ScalarValue::UInt16(Some(q)) => *q as usize,
         ScalarValue::UInt32(Some(q)) => *q as usize,

Reply via email to