This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c884bdb692 Convert ApproxPercentileCont and
ApproxPercentileContWithWeight to UDAF (#10917)
c884bdb692 is described below
commit c884bdb692020d8feb9599c9e455a406b98a6f46
Author: Jax Liu <[email protected]>
AuthorDate: Sun Jun 16 22:42:44 2024 +0800
Convert ApproxPercentileCont and ApproxPercentileContWithWeight to UDAF
(#10917)
* pass logical expr of arguments for udaf
* implement approx_percentile_cont udaf
* register udaf
* remove ApproxPercentileCont
* convert with_weight to udaf and remove original
* fix conflict
* fix compile check
* fix doc and testing
* evaluate args through physical plan
* public use Literal
* fix tests
* rollback the experimental tests
* remove unused import
* rename args and inline code
* remove unnecessary partial eq trait
* fix error message
---
.../src/physical_optimizer/aggregate_statistics.rs | 1 +
.../combine_partial_final_agg.rs | 2 +
.../core/src/physical_optimizer/test_utils.rs | 1 +
datafusion/core/src/physical_planner.rs | 16 +-
.../core/tests/dataframe/dataframe_functions.rs | 6 +-
datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs | 1 +
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 4 +
datafusion/expr/src/aggregate_function.rs | 50 +----
datafusion/expr/src/expr_fn.rs | 28 ---
datafusion/expr/src/function.rs | 4 +-
datafusion/expr/src/type_coercion/aggregates.rs | 82 -------
.../functions-aggregate/src/approx_median.rs | 10 -
.../src/approx_percentile_cont.rs | 235 ++++++++++++++++++-
.../src}/approx_percentile_cont_with_weight.rs | 159 ++++++++-----
datafusion/functions-aggregate/src/count.rs | 2 +-
datafusion/functions-aggregate/src/lib.rs | 7 +
datafusion/functions-aggregate/src/stddev.rs | 4 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 25 ---
.../physical-expr-common/src/aggregate/mod.rs | 13 +-
.../physical-expr-common/src/expressions/mod.rs | 2 +-
datafusion/physical-expr-common/src/utils.rs | 26 ++-
.../src/aggregate/approx_percentile_cont.rs | 249 ---------------------
datafusion/physical-expr/src/aggregate/build_in.rs | 101 +--------
datafusion/physical-expr/src/aggregate/mod.rs | 2 -
datafusion/physical-expr/src/expressions/mod.rs | 4 +-
datafusion/physical-plan/src/aggregates/mod.rs | 8 +-
.../src/windows/bounded_window_agg_exec.rs | 6 +-
datafusion/physical-plan/src/windows/mod.rs | 3 +
datafusion/proto/proto/datafusion.proto | 4 +-
datafusion/proto/src/generated/pbjson.rs | 6 -
datafusion/proto/src/generated/prost.rs | 12 +-
datafusion/proto/src/logical_plan/from_proto.rs | 6 -
datafusion/proto/src/logical_plan/to_proto.rs | 10 -
datafusion/proto/src/physical_plan/from_proto.rs | 4 +-
datafusion/proto/src/physical_plan/mod.rs | 5 +-
datafusion/proto/src/physical_plan/to_proto.rs | 18 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 25 +--
.../proto/tests/cases/roundtrip_physical_plan.rs | 2 +
datafusion/sqllogictest/test_files/aggregate.slt | 14 +-
39 files changed, 443 insertions(+), 714 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index eeacc48b85..ca1582bcb3 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -390,6 +390,7 @@ pub(crate) mod tests {
&[self.column()],
&[],
&[],
+ &[],
schema,
self.column_name(),
false,
diff --git
a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index 38b92959e8..b57f36f728 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -315,6 +315,7 @@ mod tests {
&[expr],
&[],
&[],
+ &[],
schema,
name,
false,
@@ -404,6 +405,7 @@ mod tests {
&[col("b", &schema)?],
&[],
&[],
+ &[],
&schema,
"Sum(b)",
false,
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs
b/datafusion/core/src/physical_optimizer/test_utils.rs
index 154e77cd23..5320938d2e 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -245,6 +245,7 @@ pub fn bounded_window_exec(
"count".to_owned(),
&[col(col_name, &schema).unwrap()],
&[],
+ &[],
&sort_exprs,
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 4f91875950..404bcbb2e7 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1766,7 +1766,8 @@ pub fn create_window_expr_with_name(
window_frame,
null_treatment,
}) => {
- let args = create_physical_exprs(args, logical_schema,
execution_props)?;
+ let physical_args =
+ create_physical_exprs(args, logical_schema, execution_props)?;
let partition_by =
create_physical_exprs(partition_by, logical_schema,
execution_props)?;
let order_by =
@@ -1780,13 +1781,13 @@ pub fn create_window_expr_with_name(
}
let window_frame = Arc::new(window_frame.clone());
- let ignore_nulls = null_treatment
- .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
+ let ignore_nulls =
null_treatment.unwrap_or(NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
windows::create_window_expr(
fun,
name,
- &args,
+ &physical_args,
+ args,
&partition_by,
&order_by,
window_frame,
@@ -1837,7 +1838,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
order_by,
null_treatment,
}) => {
- let args =
+ let physical_args =
create_physical_exprs(args, logical_input_schema,
execution_props)?;
let filter = match filter {
Some(e) => Some(create_physical_expr(
@@ -1867,7 +1868,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
- &args,
+ &physical_args,
&ordering_reqs,
physical_input_schema,
name,
@@ -1889,7 +1890,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
fun,
- &args,
+ &physical_args,
+ args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs
b/datafusion/core/tests/dataframe/dataframe_functions.rs
index b05769a6ce..1c55c48fea 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -33,7 +33,7 @@ use datafusion::assert_batches_eq;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;
-use datafusion_functions_aggregate::expr_fn::approx_median;
+use datafusion_functions_aggregate::expr_fn::{approx_median,
approx_percentile_cont};
fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
@@ -363,7 +363,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
let expected = [
"+---------------------------------------------+",
- "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |",
+ "| approx_percentile_cont(test.b,Float64(0.5)) |",
"+---------------------------------------------+",
"| 10 |",
"+---------------------------------------------+",
@@ -384,7 +384,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
let df = create_test_table().await?;
let expected = [
"+--------------------------------------+",
- "| APPROX_PERCENTILE_CONT(test.b,arg_2) |",
+ "| approx_percentile_cont(test.b,arg_2) |",
"+--------------------------------------+",
"| 10 |",
"+--------------------------------------+",
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index c76c1fc2c7..a04f4f3491 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -108,6 +108,7 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>,
group_by_columns: Vec<&str
&[col("d", &schema).unwrap()],
&[],
&[],
+ &[],
&schema,
"sum1",
false,
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 4358691ee5..5bd19850ca 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -252,6 +252,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
let partitionby_exprs = vec![];
let orderby_exprs = vec![];
+ let logical_exprs = vec![];
// Window frame starts with "UNBOUNDED PRECEDING":
let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None));
@@ -283,6 +284,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
&window_fn,
fn_name.to_string(),
&args,
+ &logical_exprs,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame),
@@ -699,6 +701,7 @@ async fn run_window_test(
&window_fn,
fn_name.clone(),
&args,
+ &[],
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
@@ -717,6 +720,7 @@ async fn run_window_test(
&window_fn,
fn_name,
&args,
+ &[],
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
diff --git a/datafusion/expr/src/aggregate_function.rs
b/datafusion/expr/src/aggregate_function.rs
index 81562bf124..441e8953df 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
use std::{fmt, str::FromStr};
use crate::utils;
-use crate::{type_coercion::aggregates::*, Signature, TypeSignature,
Volatility};
+use crate::{type_coercion::aggregates::*, Signature, Volatility};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError,
Result};
@@ -45,10 +45,6 @@ pub enum AggregateFunction {
NthValue,
/// Correlation
Correlation,
- /// Approximate continuous percentile function
- ApproxPercentileCont,
- /// Approximate continuous percentile function with weight
- ApproxPercentileContWithWeight,
/// Grouping
Grouping,
/// Bit And
@@ -75,8 +71,6 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
- ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
- ApproxPercentileContWithWeight =>
"APPROX_PERCENTILE_CONT_WITH_WEIGHT",
Grouping => "GROUPING",
BitAnd => "BIT_AND",
BitOr => "BIT_OR",
@@ -113,11 +107,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
- // approximate
- "approx_percentile_cont" =>
AggregateFunction::ApproxPercentileCont,
- "approx_percentile_cont_with_weight" => {
- AggregateFunction::ApproxPercentileContWithWeight
- }
// other
"grouping" => AggregateFunction::Grouping,
_ => {
@@ -170,10 +159,6 @@ impl AggregateFunction {
coerced_data_types[0].clone(),
true,
)))),
- AggregateFunction::ApproxPercentileCont =>
Ok(coerced_data_types[0].clone()),
- AggregateFunction::ApproxPercentileContWithWeight => {
- Ok(coerced_data_types[0].clone())
- }
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
@@ -230,39 +215,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
- AggregateFunction::ApproxPercentileCont => {
- let mut variants =
- Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
- // Accept any numeric value paired with a float64 percentile
- for num in NUMERICS {
- variants
- .push(TypeSignature::Exact(vec![num.clone(),
DataType::Float64]));
- // Additionally accept an integer number of centroids for
T-Digest
- for int in INTEGERS {
- variants.push(TypeSignature::Exact(vec![
- num.clone(),
- DataType::Float64,
- int.clone(),
- ]))
- }
- }
-
- Signature::one_of(variants, Volatility::Immutable)
- }
- AggregateFunction::ApproxPercentileContWithWeight =>
Signature::one_of(
- // Accept any numeric value paired with a float64 percentile
- NUMERICS
- .iter()
- .map(|t| {
- TypeSignature::Exact(vec![
- t.clone(),
- t.clone(),
- DataType::Float64,
- ])
- })
- .collect(),
- Volatility::Immutable,
- ),
AggregateFunction::StringAgg => {
Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index fb5b3991ec..099851aece 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -242,34 +242,6 @@ pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool)
-> Expr {
Expr::InList(InList::new(Box::new(expr), list, negated))
}
-/// Calculate an approximation of the specified `percentile` for `expr`.
-pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
- Expr::AggregateFunction(AggregateFunction::new(
- aggregate_function::AggregateFunction::ApproxPercentileCont,
- vec![expr, percentile],
- false,
- None,
- None,
- None,
- ))
-}
-
-/// Calculate an approximation of the specified `percentile` for `expr` and
`weight_expr`.
-pub fn approx_percentile_cont_with_weight(
- expr: Expr,
- weight_expr: Expr,
- percentile: Expr,
-) -> Expr {
- Expr::AggregateFunction(AggregateFunction::new(
- aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
- vec![expr, weight_expr, percentile],
- false,
- None,
- None,
- None,
- ))
-}
-
/// Create an EXISTS subquery expression
pub fn exists(subquery: Arc<LogicalPlan>) -> Expr {
let outer_ref_columns = subquery.all_out_ref_exprs();
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index c06f177510..169436145a 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -83,8 +83,8 @@ pub struct AccumulatorArgs<'a> {
/// The input type of the aggregate function.
pub input_type: &'a DataType,
- /// The number of arguments the aggregate function takes.
- pub args_num: usize,
+ /// The logical expression of arguments the aggregate function takes.
+ pub input_exprs: &'a [Expr],
}
/// [`StateFieldsArgs`] contains information about the fields that an
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index 6c9a71bab4..98324ed612 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -17,7 +17,6 @@
use std::ops::Deref;
-use super::functions::can_coerce_from;
use crate::{AggregateFunction, Signature, TypeSignature};
use arrow::datatypes::{
@@ -158,55 +157,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
- AggregateFunction::ApproxPercentileCont => {
- if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
- return plan_err!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun,
- input_types[0]
- );
- }
- if input_types.len() == 3 && !input_types[2].is_integer() {
- return plan_err!(
- "The percentile sample points count for {:?} must be
integer, not {:?}.",
- agg_fun, input_types[2]
- );
- }
- let mut result = input_types.to_vec();
- if can_coerce_from(&Float64, &input_types[1]) {
- result[1] = Float64;
- } else {
- return plan_err!(
- "Could not coerce the percent argument for {:?} to
Float64. Was {:?}.",
- agg_fun, input_types[1]
- );
- }
- Ok(result)
- }
- AggregateFunction::ApproxPercentileContWithWeight => {
- if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
- return plan_err!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun,
- input_types[0]
- );
- }
- if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
- return plan_err!(
- "The weight argument for {:?} does not support inputs of
type {:?}.",
- agg_fun,
- input_types[1]
- );
- }
- if !matches!(input_types[2], Float64) {
- return plan_err!(
- "The percentile argument for {:?} must be Float64, not
{:?}.",
- agg_fun,
- input_types[2]
- );
- }
- Ok(input_types.to_vec())
- }
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
@@ -459,15 +409,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
arg_type.is_integer()
}
-/// Return `true` if `arg_type` is of a [`DataType`] that the
-/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on.
-pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) ->
bool {
- matches!(
- arg_type,
- arg_type if NUMERICS.contains(arg_type)
- )
-}
-
/// Return `true` if `arg_type` is of a [`DataType`] that the
/// [`AggregateFunction::StringAgg`] aggregation can operate on.
pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
@@ -532,29 +473,6 @@ mod tests {
assert_eq!(r[0], DataType::Decimal128(20, 3));
let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)],
&signature).unwrap();
assert_eq!(r[0], DataType::Decimal256(20, 3));
-
- // ApproxPercentileCont input types
- let input_types = vec![
- vec![DataType::Int8, DataType::Float64],
- vec![DataType::Int16, DataType::Float64],
- vec![DataType::Int32, DataType::Float64],
- vec![DataType::Int64, DataType::Float64],
- vec![DataType::UInt8, DataType::Float64],
- vec![DataType::UInt16, DataType::Float64],
- vec![DataType::UInt32, DataType::Float64],
- vec![DataType::UInt64, DataType::Float64],
- vec![DataType::Float32, DataType::Float64],
- vec![DataType::Float64, DataType::Float64],
- ];
- for input_type in &input_types {
- let signature =
AggregateFunction::ApproxPercentileCont.signature();
- let result = coerce_types(
- &AggregateFunction::ApproxPercentileCont,
- input_type,
- &signature,
- );
- assert_eq!(*input_type, result.unwrap());
- }
}
#[test]
diff --git a/datafusion/functions-aggregate/src/approx_median.rs
b/datafusion/functions-aggregate/src/approx_median.rs
index b8b86d3055..bc723c8629 100644
--- a/datafusion/functions-aggregate/src/approx_median.rs
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -28,7 +28,6 @@ use datafusion_expr::function::{AccumulatorArgs,
StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
-use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref;
use crate::approx_percentile_cont::ApproxPercentileAccumulator;
@@ -118,12 +117,3 @@ impl AggregateUDFImpl for ApproxMedian {
)))
}
}
-
-impl PartialEq<dyn Any> for ApproxMedian {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| self.signature == x.signature)
- .unwrap_or(false)
- }
-}
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index e75417efc6..5ae5684d9c 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -15,6 +15,11 @@
// specific language governing permissions and limitations
// under the License.
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+use std::sync::Arc;
+
+use arrow::array::RecordBatch;
use arrow::{
array::{
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array,
@@ -22,12 +27,238 @@ use arrow::{
},
datatypes::DataType,
};
+use arrow_schema::{Field, Schema};
-use datafusion_common::{downcast_value, internal_err, DataFusionError,
ScalarValue};
-use datafusion_expr::Accumulator;
+use datafusion_common::{
+ downcast_value, internal_err, not_impl_err, plan_err, DataFusionError,
ScalarValue,
+};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+ Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature,
TypeSignature,
+ Volatility,
+};
use datafusion_physical_expr_common::aggregate::tdigest::{
TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
};
+use
datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr;
+
+make_udaf_expr_and_func!(
+ ApproxPercentileCont,
+ approx_percentile_cont,
+ expression percentile,
+ "Computes the approximate percentile continuous of a set of numbers",
+ approx_percentile_cont_udaf
+);
+
+pub struct ApproxPercentileCont {
+ signature: Signature,
+}
+
+impl Debug for ApproxPercentileCont {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ f.debug_struct("ApproxPercentileCont")
+ .field("name", &self.name())
+ .field("signature", &self.signature)
+ .finish()
+ }
+}
+
+impl Default for ApproxPercentileCont {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl ApproxPercentileCont {
+ /// Create a new [`ApproxPercentileCont`] aggregate function.
+ pub fn new() -> Self {
+ let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len()
+ 1));
+ // Accept any numeric value paired with a float64 percentile
+ for num in NUMERICS {
+ variants.push(TypeSignature::Exact(vec![num.clone(),
DataType::Float64]));
+ // Additionally accept an integer number of centroids for T-Digest
+ for int in INTEGERS {
+ variants.push(TypeSignature::Exact(vec![
+ num.clone(),
+ DataType::Float64,
+ int.clone(),
+ ]))
+ }
+ }
+ Self {
+ signature: Signature::one_of(variants, Volatility::Immutable),
+ }
+ }
+
+ pub(crate) fn create_accumulator(
+ &self,
+ args: AccumulatorArgs,
+ ) -> datafusion_common::Result<ApproxPercentileAccumulator> {
+ let percentile = validate_input_percentile_expr(&args.input_exprs[1])?;
+ let tdigest_max_size = if args.input_exprs.len() == 3 {
+ Some(validate_input_max_size_expr(&args.input_exprs[2])?)
+ } else {
+ None
+ };
+
+ let accumulator: ApproxPercentileAccumulator = match args.input_type {
+ t @ (DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::Float32
+ | DataType::Float64) => {
+ if let Some(max_size) = tdigest_max_size {
+ ApproxPercentileAccumulator::new_with_max_size(percentile,
t.clone(), max_size)
+ }else{
+ ApproxPercentileAccumulator::new(percentile, t.clone())
+
+ }
+ }
+ other => {
+ return not_impl_err!(
+ "Support for 'APPROX_PERCENTILE_CONT' for data type
{other} is not implemented"
+ )
+ }
+ };
+
+ Ok(accumulator)
+ }
+}
+
+fn get_lit_value(expr: &Expr) -> datafusion_common::Result<ScalarValue> {
+ let empty_schema = Arc::new(Schema::empty());
+ let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
+ let expr = limited_convert_logical_expr_to_physical_expr(expr,
&empty_schema)?;
+ let result = expr.evaluate(&empty_batch)?;
+ match result {
+ ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
+ "The expr {:?} can't be evaluated to scalar value",
+ expr
+ ))),
+ ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
+ }
+}
+
+fn validate_input_percentile_expr(expr: &Expr) ->
datafusion_common::Result<f64> {
+ let lit = get_lit_value(expr)?;
+ let percentile = match &lit {
+ ScalarValue::Float32(Some(q)) => *q as f64,
+ ScalarValue::Float64(Some(q)) => *q,
+ got => return not_impl_err!(
+ "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or
Float64 literal (got data type {})",
+ got.data_type()
+ )
+ };
+
+ // Ensure the percentile is between 0 and 1.
+ if !(0.0..=1.0).contains(&percentile) {
+ return plan_err!(
+ "Percentile value must be between 0.0 and 1.0 inclusive,
{percentile} is invalid"
+ );
+ }
+ Ok(percentile)
+}
+
+fn validate_input_max_size_expr(expr: &Expr) ->
datafusion_common::Result<usize> {
+ let lit = get_lit_value(expr)?;
+ let max_size = match &lit {
+ ScalarValue::UInt8(Some(q)) => *q as usize,
+ ScalarValue::UInt16(Some(q)) => *q as usize,
+ ScalarValue::UInt32(Some(q)) => *q as usize,
+ ScalarValue::UInt64(Some(q)) => *q as usize,
+ ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize,
+ ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize,
+ ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize,
+ ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
+ got => return not_impl_err!(
+ "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt
> 0 literal (got data type {}).",
+ got.data_type()
+ )
+ };
+ Ok(max_size)
+}
+
+impl AggregateUDFImpl for ApproxPercentileCont {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ #[allow(rustdoc::private_intra_doc_links)]
+ /// See
[`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`]
for a description of the serialised
+ /// state.
+ fn state_fields(
+ &self,
+ args: StateFieldsArgs,
+ ) -> datafusion_common::Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ format_state_name(args.name, "max_size"),
+ DataType::UInt64,
+ false,
+ ),
+ Field::new(
+ format_state_name(args.name, "sum"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ format_state_name(args.name, "count"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ format_state_name(args.name, "max"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ format_state_name(args.name, "min"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new_list(
+ format_state_name(args.name, "centroids"),
+ Field::new("item", DataType::Float64, true),
+ false,
+ ),
+ ])
+ }
+
+ fn name(&self) -> &str {
+ "approx_percentile_cont"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ #[inline]
+ fn accumulator(
+ &self,
+ acc_args: AccumulatorArgs,
+ ) -> datafusion_common::Result<Box<dyn Accumulator>> {
+ Ok(Box::new(self.create_accumulator(acc_args)?))
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) ->
datafusion_common::Result<DataType> {
+ if !arg_types[0].is_numeric() {
+ return plan_err!("approx_percentile_cont requires numeric input
types");
+ }
+ if arg_types.len() == 3 && !arg_types[2].is_integer() {
+ return plan_err!(
+ "approx_percentile_cont requires integer max_size input types"
+ );
+ }
+ Ok(arg_types[0].clone())
+ }
+}
#[derive(Debug)]
pub struct ApproxPercentileAccumulator {
diff --git
a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
similarity index 51%
rename from
datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
rename to
datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
index 07c2aff343..a64218c606 100644
---
a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs
@@ -15,105 +15,140 @@
// specific language governing permissions and limitations
// under the License.
-use crate::expressions::ApproxPercentileCont;
-use crate::{AggregateExpr, PhysicalExpr};
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
};
-use
datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+use datafusion_common::ScalarValue;
+use datafusion_common::{not_impl_err, plan_err, Result};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::Volatility::Immutable;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature};
use datafusion_physical_expr_common::aggregate::tdigest::{
Centroid, TDigest, DEFAULT_MAX_SIZE,
};
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use crate::approx_percentile_cont::{ApproxPercentileAccumulator,
ApproxPercentileCont};
-use crate::aggregate::utils::down_cast_any_ref;
-use std::{any::Any, sync::Arc};
+make_udaf_expr_and_func!(
+ ApproxPercentileContWithWeight,
+ approx_percentile_cont_with_weight,
+ expression weight percentile,
+ "Computes the approximate percentile continuous with weight of a set of
numbers",
+ approx_percentile_cont_with_weight_udaf
+);
/// APPROX_PERCENTILE_CONT_WITH_WEIGTH aggregate expression
-#[derive(Debug)]
pub struct ApproxPercentileContWithWeight {
+ signature: Signature,
approx_percentile_cont: ApproxPercentileCont,
- column_expr: Arc<dyn PhysicalExpr>,
- weight_expr: Arc<dyn PhysicalExpr>,
- percentile_expr: Arc<dyn PhysicalExpr>,
+}
+
+impl Debug for ApproxPercentileContWithWeight {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("ApproxPercentileContWithWeight")
+ .field("signature", &self.signature)
+ .finish()
+ }
+}
+
+impl Default for ApproxPercentileContWithWeight {
+ fn default() -> Self {
+ Self::new()
+ }
}
impl ApproxPercentileContWithWeight {
/// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
- pub fn new(
- expr: Vec<Arc<dyn PhysicalExpr>>,
- name: impl Into<String>,
- return_type: DataType,
- ) -> Result<Self> {
- // Arguments should be [ColumnExpr, WeightExpr,
DesiredPercentileLiteral]
- debug_assert_eq!(expr.len(), 3);
-
- let sub_expr = vec![expr[0].clone(), expr[2].clone()];
- let approx_percentile_cont =
- ApproxPercentileCont::new(sub_expr, name, return_type)?;
-
- Ok(Self {
- approx_percentile_cont,
- column_expr: expr[0].clone(),
- weight_expr: expr[1].clone(),
- percentile_expr: expr[2].clone(),
- })
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::one_of(
+ // Accept any numeric value paired with a float64 percentile
+ NUMERICS
+ .iter()
+ .map(|t| {
+ TypeSignature::Exact(vec![
+ t.clone(),
+ t.clone(),
+ DataType::Float64,
+ ])
+ })
+ .collect(),
+ Immutable,
+ ),
+ approx_percentile_cont: ApproxPercentileCont::new(),
+ }
}
}
-impl AggregateExpr for ApproxPercentileContWithWeight {
+impl AggregateUDFImpl for ApproxPercentileContWithWeight {
fn as_any(&self) -> &dyn Any {
self
}
- fn field(&self) -> Result<Field> {
- self.approx_percentile_cont.field()
+ fn name(&self) -> &str {
+ "approx_percentile_cont_with_weight"
}
- #[allow(rustdoc::private_intra_doc_links)]
- /// See [`TDigest::to_scalar_state()`] for a description of the serialised
- /// state.
- fn state_fields(&self) -> Result<Vec<Field>> {
- self.approx_percentile_cont.state_fields()
+ fn signature(&self) -> &Signature {
+ &self.signature
}
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![
- self.column_expr.clone(),
- self.weight_expr.clone(),
- self.percentile_expr.clone(),
- ]
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ if !arg_types[0].is_numeric() {
+ return plan_err!(
+ "approx_percentile_cont_with_weight requires numeric input
types"
+ );
+ }
+ if !arg_types[1].is_numeric() {
+ return plan_err!(
+ "approx_percentile_cont_with_weight requires numeric weight
input types"
+ );
+ }
+ if arg_types[2] != DataType::Float64 {
+ return plan_err!("approx_percentile_cont_with_weight requires
float64 percentile input types");
+ }
+ Ok(arg_types[0].clone())
}
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ if acc_args.is_distinct {
+ return not_impl_err!(
+ "approx_percentile_cont_with_weight(DISTINCT) aggregations are
not available"
+ );
+ }
+
+ if acc_args.input_exprs.len() != 3 {
+ return plan_err!(
+ "approx_percentile_cont_with_weight requires three arguments:
value, weight, percentile"
+ );
+ }
+
+ let sub_args = AccumulatorArgs {
+ input_exprs: &[
+ acc_args.input_exprs[0].clone(),
+ acc_args.input_exprs[2].clone(),
+ ],
+ ..acc_args
+ };
let approx_percentile_cont_accumulator =
- self.approx_percentile_cont.create_plain_accumulator()?;
+ self.approx_percentile_cont.create_accumulator(sub_args)?;
let accumulator = ApproxPercentileWithWeightAccumulator::new(
approx_percentile_cont_accumulator,
);
Ok(Box::new(accumulator))
}
- fn name(&self) -> &str {
- self.approx_percentile_cont.name()
- }
-}
-
-impl PartialEq<dyn Any> for ApproxPercentileContWithWeight {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| {
- self.approx_percentile_cont == x.approx_percentile_cont
- && self.column_expr.eq(&x.column_expr)
- && self.weight_expr.eq(&x.weight_expr)
- && self.percentile_expr.eq(&x.percentile_expr)
- })
- .unwrap_or(false)
+ #[allow(rustdoc::private_intra_doc_links)]
+ /// See [`TDigest::to_scalar_state()`] for a description of the serialised
+ /// state.
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ self.approx_percentile_cont.state_fields(args)
}
}
diff --git a/datafusion/functions-aggregate/src/count.rs
b/datafusion/functions-aggregate/src/count.rs
index cfd5661953..062e148975 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -258,7 +258,7 @@ impl AggregateUDFImpl for Count {
if args.is_distinct {
return false;
}
- args.args_num == 1
+ args.input_exprs.len() == 1
}
fn create_groups_accumulator(
diff --git a/datafusion/functions-aggregate/src/lib.rs
b/datafusion/functions-aggregate/src/lib.rs
index fabe15e416..daddb9d93f 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -68,7 +68,10 @@ pub mod variance;
pub mod approx_median;
pub mod approx_percentile_cont;
+pub mod approx_percentile_cont_with_weight;
+use crate::approx_percentile_cont::approx_percentile_cont_udaf;
+use
crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
use datafusion_expr::AggregateUDF;
@@ -79,6 +82,8 @@ use std::sync::Arc;
pub mod expr_fn {
pub use super::approx_distinct;
pub use super::approx_median::approx_median;
+ pub use super::approx_percentile_cont::approx_percentile_cont;
+ pub use
super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight;
pub use super::count::count;
pub use super::count::count_distinct;
pub use super::covariance::covar_pop;
@@ -127,6 +132,8 @@ pub fn all_default_aggregate_functions() ->
Vec<Arc<AggregateUDF>> {
stddev::stddev_pop_udaf(),
approx_median::approx_median_udaf(),
approx_distinct::approx_distinct_udaf(),
+ approx_percentile_cont_udaf(),
+ approx_percentile_cont_with_weight_udaf(),
]
}
diff --git a/datafusion/functions-aggregate/src/stddev.rs
b/datafusion/functions-aggregate/src/stddev.rs
index 4c3effe765..42cf44f65d 100644
--- a/datafusion/functions-aggregate/src/stddev.rs
+++ b/datafusion/functions-aggregate/src/stddev.rs
@@ -332,7 +332,7 @@ mod tests {
name: "a",
is_distinct: false,
input_type: &DataType::Float64,
- args_num: 1,
+ input_exprs: &[datafusion_expr::col("a")],
};
let args2 = AccumulatorArgs {
@@ -343,7 +343,7 @@ mod tests {
name: "a",
is_distinct: false,
input_type: &DataType::Float64,
- args_num: 1,
+ input_exprs: &[datafusion_expr::col("a")],
};
let mut accum1 = agg1.accumulator(args1)?;
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 0c8e4ae34a..acc21f14f4 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -1055,31 +1055,6 @@ mod test {
Ok(())
}
- #[test]
- fn agg_function_invalid_input_percentile() {
- let empty = empty();
- let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont;
- let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
- fun,
- vec![lit(0.95), lit(42.0), lit(100.0)],
- false,
- None,
- None,
- None,
- ));
-
- let err = Projection::try_new(vec![agg_expr], empty)
- .err()
- .unwrap()
- .strip_backtrace();
-
- let prefix = "Error during planning: No function matches the given
name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'.
You might need to add explicit type casts.\n\tCandidate functions:";
- assert!(!err
- .strip_prefix(prefix)
- .unwrap()
- .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)"));
- }
-
#[test]
fn binary_op_date32_op_interval() -> Result<()> {
// CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...")
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 21884f840d..432267e045 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -46,6 +46,7 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
pub fn create_aggregate_expr(
fun: &AggregateUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
+ input_exprs: &[Expr],
sort_exprs: &[Expr],
ordering_req: &[PhysicalSortExpr],
schema: &Schema,
@@ -76,6 +77,7 @@ pub fn create_aggregate_expr(
Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
+ logical_args: input_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
name: name.into(),
schema: schema.clone(),
@@ -231,6 +233,7 @@ pub struct AggregatePhysicalExpressions {
pub struct AggregateFunctionExpr {
fun: AggregateUDF,
args: Vec<Arc<dyn PhysicalExpr>>,
+ logical_args: Vec<Expr>,
/// Output / return type of this aggregate
data_type: DataType,
name: String,
@@ -293,7 +296,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
- args_num: self.args.len(),
+ input_exprs: &self.logical_args,
name: &self.name,
};
@@ -308,7 +311,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
- args_num: self.args.len(),
+ input_exprs: &self.logical_args,
name: &self.name,
};
@@ -378,7 +381,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
- args_num: self.args.len(),
+ input_exprs: &self.logical_args,
name: &self.name,
};
self.fun.groups_accumulator_supported(args)
@@ -392,7 +395,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
- args_num: self.args.len(),
+ input_exprs: &self.logical_args,
name: &self.name,
};
self.fun.create_groups_accumulator(args)
@@ -434,6 +437,7 @@ impl AggregateExpr for AggregateFunctionExpr {
create_aggregate_expr(
&updated_fn,
&self.args,
+ &self.logical_args,
&self.sort_exprs,
&self.ordering_req,
&self.schema,
@@ -468,6 +472,7 @@ impl AggregateExpr for AggregateFunctionExpr {
let reverse_aggr = create_aggregate_expr(
&reverse_udf,
&self.args,
+ &self.logical_args,
&reverse_sort_exprs,
&reverse_ordering_req,
&self.schema,
diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs
b/datafusion/physical-expr-common/src/expressions/mod.rs
index ea21c8e9a9..dd534cc07d 100644
--- a/datafusion/physical-expr-common/src/expressions/mod.rs
+++ b/datafusion/physical-expr-common/src/expressions/mod.rs
@@ -17,7 +17,7 @@
mod cast;
pub mod column;
-mod literal;
+pub mod literal;
pub use cast::{cast, cast_with_options, CastExpr};
pub use literal::{lit, Literal};
diff --git a/datafusion/physical-expr-common/src/utils.rs
b/datafusion/physical-expr-common/src/utils.rs
index f661400fcb..d5cd3c6f4a 100644
--- a/datafusion/physical-expr-common/src/utils.rs
+++ b/datafusion/physical-expr-common/src/utils.rs
@@ -17,18 +17,21 @@
use std::sync::Arc;
-use crate::expressions::{self, CastExpr};
-use crate::physical_expr::PhysicalExpr;
-use crate::sort_expr::PhysicalSortExpr;
-use crate::tree_node::ExprContext;
-
use arrow::array::{make_array, Array, ArrayRef, BooleanArray,
MutableArrayData};
use arrow::compute::{and_kleene, is_not_null, SlicesIterator};
use arrow::datatypes::Schema;
+
use datafusion_common::{exec_err, Result};
+use datafusion_expr::expr::Alias;
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::Expr;
+use crate::expressions::literal::Literal;
+use crate::expressions::{self, CastExpr};
+use crate::physical_expr::PhysicalExpr;
+use crate::sort_expr::PhysicalSortExpr;
+use crate::tree_node::ExprContext;
+
/// Represents a [`PhysicalExpr`] node with associated properties (order and
/// range) in a context where properties are tracked.
pub type ExprPropertiesNode = ExprContext<ExprProperties>;
@@ -115,6 +118,9 @@ pub fn limited_convert_logical_expr_to_physical_expr(
schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
match expr {
+ Expr::Alias(Alias { expr, .. }) => {
+ Ok(limited_convert_logical_expr_to_physical_expr(expr, schema)?)
+ }
Expr::Column(col) => expressions::column::col(&col.name, schema),
Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new(
limited_convert_logical_expr_to_physical_expr(
@@ -124,10 +130,7 @@ pub fn limited_convert_logical_expr_to_physical_expr(
cast_expr.data_type.clone(),
None,
))),
- Expr::Alias(alias_expr) =>
limited_convert_logical_expr_to_physical_expr(
- alias_expr.expr.as_ref(),
- schema,
- ),
+ Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
_ => exec_err!(
"Unsupported expression: {expr} for conversion to Arc<dyn
PhysicalExpr>"
),
@@ -138,11 +141,12 @@ pub fn limited_convert_logical_expr_to_physical_expr(
mod tests {
use std::sync::Arc;
- use super::*;
-
use arrow::array::Int32Array;
+
use datafusion_common::cast::{as_boolean_array, as_int32_array};
+ use super::*;
+
#[test]
fn scatter_int() -> Result<()> {
let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100]));
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
deleted file mode 100644
index f2068bbc92..0000000000
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ /dev/null
@@ -1,249 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::{any::Any, sync::Arc};
-
-use arrow::datatypes::{DataType, Field};
-use arrow_array::RecordBatch;
-use arrow_schema::Schema;
-
-use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result,
ScalarValue};
-use datafusion_expr::{Accumulator, ColumnarValue};
-use
datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
-
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
-
-/// APPROX_PERCENTILE_CONT aggregate expression
-#[derive(Debug)]
-pub struct ApproxPercentileCont {
- name: String,
- input_data_type: DataType,
- expr: Vec<Arc<dyn PhysicalExpr>>,
- percentile: f64,
- tdigest_max_size: Option<usize>,
-}
-
-impl ApproxPercentileCont {
- /// Create a new [`ApproxPercentileCont`] aggregate function.
- pub fn new(
- expr: Vec<Arc<dyn PhysicalExpr>>,
- name: impl Into<String>,
- input_data_type: DataType,
- ) -> Result<Self> {
- // Arguments should be [ColumnExpr, DesiredPercentileLiteral]
- debug_assert_eq!(expr.len(), 2);
-
- let percentile = validate_input_percentile_expr(&expr[1])?;
-
- Ok(Self {
- name: name.into(),
- input_data_type,
- // The physical expr to evaluate during accumulation
- expr,
- percentile,
- tdigest_max_size: None,
- })
- }
-
- /// Create a new [`ApproxPercentileCont`] aggregate function.
- pub fn new_with_max_size(
- expr: Vec<Arc<dyn PhysicalExpr>>,
- name: impl Into<String>,
- input_data_type: DataType,
- ) -> Result<Self> {
- // Arguments should be [ColumnExpr, DesiredPercentileLiteral,
TDigestMaxSize]
- debug_assert_eq!(expr.len(), 3);
- let percentile = validate_input_percentile_expr(&expr[1])?;
- let max_size = validate_input_max_size_expr(&expr[2])?;
- Ok(Self {
- name: name.into(),
- input_data_type,
- // The physical expr to evaluate during accumulation
- expr,
- percentile,
- tdigest_max_size: Some(max_size),
- })
- }
-
- pub(crate) fn create_plain_accumulator(&self) ->
Result<ApproxPercentileAccumulator> {
- let accumulator: ApproxPercentileAccumulator = match
&self.input_data_type {
- t @ (DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64) => {
- if let Some(max_size) = self.tdigest_max_size {
-
ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(),
max_size)
-
- }else{
- ApproxPercentileAccumulator::new(self.percentile,
t.clone())
-
- }
- }
- other => {
- return not_impl_err!(
- "Support for 'APPROX_PERCENTILE_CONT' for data type
{other} is not implemented"
- )
- }
- };
- Ok(accumulator)
- }
-}
-
-impl PartialEq for ApproxPercentileCont {
- fn eq(&self, other: &ApproxPercentileCont) -> bool {
- self.name == other.name
- && self.input_data_type == other.input_data_type
- && self.percentile == other.percentile
- && self.tdigest_max_size == other.tdigest_max_size
- && self.expr.len() == other.expr.len()
- && self
- .expr
- .iter()
- .zip(other.expr.iter())
- .all(|(this, other)| this.eq(other))
- }
-}
-
-fn get_lit_value(expr: &Arc<dyn PhysicalExpr>) -> Result<ScalarValue> {
- let empty_schema = Schema::empty();
- let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema));
- let result = expr.evaluate(&empty_batch)?;
- match result {
- ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
- "The expr {:?} can't be evaluated to scalar value",
- expr
- ))),
- ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
- }
-}
-
-fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64>
{
- let lit = get_lit_value(expr)?;
- let percentile = match &lit {
- ScalarValue::Float32(Some(q)) => *q as f64,
- ScalarValue::Float64(Some(q)) => *q,
- got => return not_impl_err!(
- "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or
Float64 literal (got data type {})",
- got.data_type()
- )
- };
-
- // Ensure the percentile is between 0 and 1.
- if !(0.0..=1.0).contains(&percentile) {
- return plan_err!(
- "Percentile value must be between 0.0 and 1.0 inclusive,
{percentile} is invalid"
- );
- }
- Ok(percentile)
-}
-
-fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize>
{
- let lit = get_lit_value(expr)?;
- let max_size = match &lit {
- ScalarValue::UInt8(Some(q)) => *q as usize,
- ScalarValue::UInt16(Some(q)) => *q as usize,
- ScalarValue::UInt32(Some(q)) => *q as usize,
- ScalarValue::UInt64(Some(q)) => *q as usize,
- ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize,
- ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize,
- ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize,
- ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
- got => return not_impl_err!(
- "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt
> 0 literal (got data type {}).",
- got.data_type()
- )
- };
- Ok(max_size)
-}
-
-impl AggregateExpr for ApproxPercentileCont {
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.input_data_type.clone(), false))
- }
-
- #[allow(rustdoc::private_intra_doc_links)]
- /// See
[`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`]
for a description of the serialised
- /// state.
- fn state_fields(&self) -> Result<Vec<Field>> {
- Ok(vec![
- Field::new(
- format_state_name(&self.name, "max_size"),
- DataType::UInt64,
- false,
- ),
- Field::new(
- format_state_name(&self.name, "sum"),
- DataType::Float64,
- false,
- ),
- Field::new(
- format_state_name(&self.name, "count"),
- DataType::Float64,
- false,
- ),
- Field::new(
- format_state_name(&self.name, "max"),
- DataType::Float64,
- false,
- ),
- Field::new(
- format_state_name(&self.name, "min"),
- DataType::Float64,
- false,
- ),
- Field::new_list(
- format_state_name(&self.name, "centroids"),
- Field::new("item", DataType::Float64, true),
- false,
- ),
- ])
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- self.expr.clone()
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- let accumulator = self.create_plain_accumulator()?;
- Ok(Box::new(accumulator))
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
-
-impl PartialEq<dyn Any> for ApproxPercentileCont {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| self.eq(x))
- .unwrap_or(false)
- }
-}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index df87a2e261..a1f5f153a9 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -36,6 +36,7 @@ use datafusion_expr::AggregateFunction;
use crate::aggregate::average::Avg;
use crate::expressions::{self, Literal};
use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
+
/// Create a physical aggregation expression.
/// This function errors when `input_phy_exprs`' can't be coerced to a valid
argument type of the aggregation function.
pub fn create_aggregate_expr(
@@ -154,41 +155,6 @@ pub fn create_aggregate_expr(
(AggregateFunction::Correlation, true) => {
return not_impl_err!("CORR(DISTINCT) aggregations are not
available");
}
- (AggregateFunction::ApproxPercentileCont, false) => {
- if input_phy_exprs.len() == 2 {
- Arc::new(expressions::ApproxPercentileCont::new(
- // Pass in the desired percentile expr
- input_phy_exprs,
- name,
- data_type,
- )?)
- } else {
- Arc::new(expressions::ApproxPercentileCont::new_with_max_size(
- // Pass in the desired percentile expr
- input_phy_exprs,
- name,
- data_type,
- )?)
- }
- }
- (AggregateFunction::ApproxPercentileCont, true) => {
- return not_impl_err!(
- "approx_percentile_cont(DISTINCT) aggregations are not
available"
- );
- }
- (AggregateFunction::ApproxPercentileContWithWeight, false) => {
- Arc::new(expressions::ApproxPercentileContWithWeight::new(
- // Pass in the desired percentile expr
- input_phy_exprs,
- name,
- data_type,
- )?)
- }
- (AggregateFunction::ApproxPercentileContWithWeight, true) => {
- return not_impl_err!(
- "approx_percentile_cont_with_weight(DISTINCT) aggregations are
not available"
- );
- }
(AggregateFunction::NthValue, _) => {
let expr = &input_phy_exprs[0];
let Some(n) = input_phy_exprs[1]
@@ -232,15 +198,15 @@ pub fn create_aggregate_expr(
mod tests {
use arrow::datatypes::{DataType, Field};
- use super::*;
+ use datafusion_common::plan_err;
+ use datafusion_expr::{type_coercion, Signature};
+
use crate::expressions::{
- try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor,
BoolAnd,
- BoolOr, DistinctArrayAgg, Max, Min,
+ try_cast, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr,
+ DistinctArrayAgg, Max, Min,
};
- use datafusion_common::{plan_err, DataFusionError, ScalarValue};
- use datafusion_expr::type_coercion::aggregates::NUMERICS;
- use datafusion_expr::{type_coercion, Signature};
+ use super::*;
#[test]
fn test_approx_expr() -> Result<()> {
@@ -304,59 +270,6 @@ mod tests {
Ok(())
}
- #[test]
- fn test_agg_approx_percentile_phy_expr() {
- for data_type in NUMERICS {
- let input_schema =
- Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
- Arc::new(
- expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
- ),
-
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
- ];
- let result_agg_phy_exprs = create_physical_agg_expr_for_test(
- &AggregateFunction::ApproxPercentileCont,
- false,
- &input_phy_exprs[..],
- &input_schema,
- "c1",
- )
- .expect("failed to create aggregate expr");
-
-
assert!(result_agg_phy_exprs.as_any().is::<ApproxPercentileCont>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", data_type.clone(), false),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- }
-
- #[test]
- fn test_agg_approx_percentile_invalid_phy_expr() {
- for data_type in NUMERICS {
- let input_schema =
- Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
- Arc::new(
- expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
- ),
-
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
- ];
- let err = create_physical_agg_expr_for_test(
- &AggregateFunction::ApproxPercentileCont,
- false,
- &input_phy_exprs[..],
- &input_schema,
- "c1",
- )
- .expect_err("should fail due to invalid percentile");
-
- assert!(matches!(err, DataFusionError::Plan(_)));
- }
- }
-
#[test]
fn test_min_max_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 9079a81e62..c20902c11b 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -17,8 +17,6 @@
pub use datafusion_physical_expr_common::aggregate::AggregateExpr;
-pub(crate) mod approx_percentile_cont;
-pub(crate) mod approx_percentile_cont_with_weight;
pub(crate) mod array_agg;
pub(crate) mod array_agg_distinct;
pub(crate) mod array_agg_ordered;
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index 592393f800..b9a159b21e 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -35,8 +35,6 @@ mod try_cast;
pub mod helpers {
pub use crate::aggregate::min_max::{max, min};
}
-pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont;
-pub use
crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight;
pub use crate::aggregate::array_agg::ArrayAgg;
pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg;
pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg;
@@ -65,8 +63,8 @@ pub use column::UnKnownColumn;
pub use datafusion_expr::utils::format_state_name;
pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue};
pub use datafusion_physical_expr_common::expressions::column::{col, Column};
+pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal};
pub use datafusion_physical_expr_common::expressions::{cast, CastExpr};
-pub use datafusion_physical_expr_common::expressions::{lit, Literal};
pub use in_list::{in_list, InListExpr};
pub use is_not_null::{is_not_null, IsNotNullExpr};
pub use is_null::{is_null, IsNullExpr};
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index b6fc70be7c..b7d8d60f4f 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1339,6 +1339,7 @@ mod tests {
let aggregates = vec![create_aggregate_expr(
&count_udaf(),
&[lit(1i8)],
+ &[datafusion_expr::lit(1i8)],
&[],
&[],
&input_schema,
@@ -1787,6 +1788,7 @@ mod tests {
&args,
&[],
&[],
+ &[],
schema,
"MEDIAN(a)",
false,
@@ -1975,10 +1977,12 @@ mod tests {
options: sort_options,
}];
let args = vec![col("b", schema)?];
+ let logical_args = vec![datafusion_expr::col("b")];
let func =
datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new());
datafusion_physical_expr_common::aggregate::create_aggregate_expr(
&func,
&args,
+ &logical_args,
&sort_exprs,
&ordering_req,
schema,
@@ -2005,10 +2009,12 @@ mod tests {
options: sort_options,
}];
let args = vec![col("b", schema)?];
+ let logical_args = vec![datafusion_expr::col("b")];
let func =
datafusion_expr::AggregateUDF::new_from_impl(LastValue::new());
- datafusion_physical_expr_common::aggregate::create_aggregate_expr(
+ create_aggregate_expr(
&func,
&args,
+ &logical_args,
&sort_exprs,
&ordering_req,
schema,
diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
index 56d780e513..fc60ab9973 100644
--- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
+++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs
@@ -1194,7 +1194,7 @@ mod tests {
RecordBatchStream, SendableRecordBatchStream, TaskContext,
};
use datafusion_expr::{
- WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
+ Expr, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::{col, Column, NthValue};
@@ -1301,7 +1301,10 @@ mod tests {
let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf());
let col_expr =
Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc<dyn
PhysicalExpr>;
+ let log_expr =
+
Expr::Column(datafusion_common::Column::from(schema.fields[0].name()));
let args = vec![col_expr];
+ let log_args = vec![log_expr];
let partitionby_exprs = vec![col(hash, &schema)?];
let orderby_exprs = vec![PhysicalSortExpr {
expr: col(order_by, &schema)?,
@@ -1322,6 +1325,7 @@ mod tests {
&window_fn,
fn_name,
&args,
+ &log_args,
&partitionby_exprs,
&orderby_exprs,
Arc::new(window_frame.clone()),
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index 63ce473fc5..ecfe123a43 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -90,6 +90,7 @@ pub fn create_window_expr(
fun: &WindowFunctionDefinition,
name: String,
args: &[Arc<dyn PhysicalExpr>],
+ logical_args: &[Expr],
partition_by: &[Arc<dyn PhysicalExpr>],
order_by: &[PhysicalSortExpr],
window_frame: Arc<WindowFrame>,
@@ -144,6 +145,7 @@ pub fn create_window_expr(
let aggregate = udaf::create_aggregate_expr(
fun.as_ref(),
args,
+ logical_args,
&sort_exprs,
order_by,
input_schema,
@@ -754,6 +756,7 @@ mod tests {
&[col("a", &schema)?],
&[],
&[],
+ &[],
Arc::new(WindowFrame::new(None)),
schema.as_ref(),
false,
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 83223a04d0..e5578ae62f 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -486,9 +486,9 @@ enum AggregateFunction {
// STDDEV = 11;
// STDDEV_POP = 12;
CORRELATION = 13;
- APPROX_PERCENTILE_CONT = 14;
+ // APPROX_PERCENTILE_CONT = 14;
// APPROX_MEDIAN = 15;
- APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
+ // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
GROUPING = 17;
// MEDIAN = 18;
BIT_AND = 19;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index f298dd241a..4a7b9610e5 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -537,8 +537,6 @@ impl serde::Serialize for AggregateFunction {
Self::Avg => "AVG",
Self::ArrayAgg => "ARRAY_AGG",
Self::Correlation => "CORRELATION",
- Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
- Self::ApproxPercentileContWithWeight =>
"APPROX_PERCENTILE_CONT_WITH_WEIGHT",
Self::Grouping => "GROUPING",
Self::BitAnd => "BIT_AND",
Self::BitOr => "BIT_OR",
@@ -563,8 +561,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"AVG",
"ARRAY_AGG",
"CORRELATION",
- "APPROX_PERCENTILE_CONT",
- "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
"GROUPING",
"BIT_AND",
"BIT_OR",
@@ -618,8 +614,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"AVG" => Ok(AggregateFunction::Avg),
"ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg),
"CORRELATION" => Ok(AggregateFunction::Correlation),
- "APPROX_PERCENTILE_CONT" =>
Ok(AggregateFunction::ApproxPercentileCont),
- "APPROX_PERCENTILE_CONT_WITH_WEIGHT" =>
Ok(AggregateFunction::ApproxPercentileContWithWeight),
"GROUPING" => Ok(AggregateFunction::Grouping),
"BIT_AND" => Ok(AggregateFunction::BitAnd),
"BIT_OR" => Ok(AggregateFunction::BitOr),
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index fa0217e9ef..ffaef445d6 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1940,9 +1940,9 @@ pub enum AggregateFunction {
/// STDDEV = 11;
/// STDDEV_POP = 12;
Correlation = 13,
- ApproxPercentileCont = 14,
+ /// APPROX_PERCENTILE_CONT = 14;
/// APPROX_MEDIAN = 15;
- ApproxPercentileContWithWeight = 16,
+ /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
Grouping = 17,
/// MEDIAN = 18;
BitAnd = 19,
@@ -1974,10 +1974,6 @@ impl AggregateFunction {
AggregateFunction::Avg => "AVG",
AggregateFunction::ArrayAgg => "ARRAY_AGG",
AggregateFunction::Correlation => "CORRELATION",
- AggregateFunction::ApproxPercentileCont =>
"APPROX_PERCENTILE_CONT",
- AggregateFunction::ApproxPercentileContWithWeight => {
- "APPROX_PERCENTILE_CONT_WITH_WEIGHT"
- }
AggregateFunction::Grouping => "GROUPING",
AggregateFunction::BitAnd => "BIT_AND",
AggregateFunction::BitOr => "BIT_OR",
@@ -1996,10 +1992,6 @@ impl AggregateFunction {
"AVG" => Some(Self::Avg),
"ARRAY_AGG" => Some(Self::ArrayAgg),
"CORRELATION" => Some(Self::Correlation),
- "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont),
- "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => {
- Some(Self::ApproxPercentileContWithWeight)
- }
"GROUPING" => Some(Self::Grouping),
"BIT_AND" => Some(Self::BitAnd),
"BIT_OR" => Some(Self::BitOr),
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index ed7b0129cc..25b7413a98 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -147,12 +147,6 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::BoolOr => Self::BoolOr,
protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg,
protobuf::AggregateFunction::Correlation => Self::Correlation,
- protobuf::AggregateFunction::ApproxPercentileCont => {
- Self::ApproxPercentileCont
- }
- protobuf::AggregateFunction::ApproxPercentileContWithWeight => {
- Self::ApproxPercentileContWithWeight
- }
protobuf::AggregateFunction::Grouping => Self::Grouping,
protobuf::AggregateFunction::NthValueAgg => Self::NthValue,
protobuf::AggregateFunction::StringAgg => Self::StringAgg,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 04f7b596fe..d9548325da 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -118,10 +118,6 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::BoolOr => Self::BoolOr,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
AggregateFunction::Correlation => Self::Correlation,
- AggregateFunction::ApproxPercentileCont =>
Self::ApproxPercentileCont,
- AggregateFunction::ApproxPercentileContWithWeight => {
- Self::ApproxPercentileContWithWeight
- }
AggregateFunction::Grouping => Self::Grouping,
AggregateFunction::NthValue => Self::NthValueAgg,
AggregateFunction::StringAgg => Self::StringAgg,
@@ -381,12 +377,6 @@ pub fn serialize_expr(
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let aggr_function = match fun {
- AggregateFunction::ApproxPercentileCont => {
- protobuf::AggregateFunction::ApproxPercentileCont
- }
- AggregateFunction::ApproxPercentileContWithWeight => {
-
protobuf::AggregateFunction::ApproxPercentileContWithWeight
- }
AggregateFunction::ArrayAgg =>
protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index 0a91df568a..b636c77641 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -126,7 +126,6 @@ pub fn parse_physical_window_expr(
) -> Result<Arc<dyn WindowExpr>> {
let window_node_expr =
parse_physical_exprs(&proto.args, registry, input_schema, codec)?;
-
let partition_by =
parse_physical_exprs(&proto.partition_by, registry, input_schema,
codec)?;
@@ -178,10 +177,13 @@ pub fn parse_physical_window_expr(
// TODO: Remove extended_schema if functions are all UDAF
let extended_schema =
schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?;
+ // approx_percentile_cont and approx_percentile_cont_weight are not
supported for UDAF from protobuf yet.
+ let logical_exprs = &[];
create_window_expr(
&fun,
name,
&window_node_expr,
+ logical_exprs,
&partition_by,
&order_by,
Arc::new(window_frame),
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index d0011e4917..8a488d30cf 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -496,11 +496,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
}
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
let agg_udf =
registry.udaf(udaf_name)?;
+ // TODO: 'logical_exprs' is not
supported for UDAF yet.
+ // approx_percentile_cont and
approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
+ let logical_exprs = &[];
// TODO: `order by` is not
supported for UDAF yet
let sort_exprs = &[];
let ordering_req = &[];
let ignore_nulls = false;
-
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs,
ordering_req, &physical_schema, name, ignore_nulls, false)
+
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs,
sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false)
}
}
}).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index ef462ac94b..3a4c35a93e 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -23,12 +23,11 @@ use
datafusion::datasource::file_format::parquet::ParquetSink;
use datafusion::physical_expr::window::{NthValueKind,
SlidingAggregateWindowExpr};
use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
use datafusion::physical_plan::expressions::{
- ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg,
BinaryExpr,
- BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column,
Correlation,
- CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr,
IsNotNullExpr,
- IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue,
NthValueAgg, Ntile,
- OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr,
- WindowShift,
+ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr,
CaseExpr,
+ CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor,
Grouping,
+ InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr,
NotExpr,
+ NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType,
RowNumber,
+ StringAgg, TryCastExpr, WindowShift,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::{BuiltInWindowExpr,
PlainAggregateWindowExpr};
@@ -270,13 +269,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) ->
Result<AggrFn> {
protobuf::AggregateFunction::Avg
} else if aggr_expr.downcast_ref::<Correlation>().is_some() {
protobuf::AggregateFunction::Correlation
- } else if aggr_expr.downcast_ref::<ApproxPercentileCont>().is_some() {
- protobuf::AggregateFunction::ApproxPercentileCont
- } else if aggr_expr
- .downcast_ref::<ApproxPercentileContWithWeight>()
- .is_some()
- {
- protobuf::AggregateFunction::ApproxPercentileContWithWeight
} else if aggr_expr.downcast_ref::<StringAgg>().is_some() {
protobuf::AggregateFunction::StringAgg
} else if aggr_expr.downcast_ref::<NthValueAgg>().is_some() {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index d0f1c4aade..a496e22685 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -26,7 +26,6 @@ use arrow::datatypes::{
DataType, Field, Fields, Int32Type, IntervalDayTimeType,
IntervalMonthDayNanoType,
IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode,
};
-use datafusion_functions_aggregate::count::count_udaf;
use prost::Message;
use datafusion::datasource::provider::TableProviderFactory;
@@ -34,10 +33,11 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::FunctionRegistry;
-use datafusion::functions_aggregate::approx_median::approx_median;
+use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::expr_fn::{
- count, count_distinct, covar_pop, covar_samp, first_value, median, stddev,
- stddev_pop, sum, var_pop, var_sample,
+ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight,
count,
+ count_distinct, covar_pop, covar_samp, first_value, median, stddev,
stddev_pop, sum,
+ var_pop, var_sample,
};
use datafusion::prelude::*;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
@@ -663,6 +663,8 @@ async fn roundtrip_expr_api() -> Result<()> {
stddev(lit(2.2)),
stddev_pop(lit(2.2)),
approx_median(lit(2)),
+ approx_percentile_cont(lit(2), lit(0.5)),
+ approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)),
];
// ensure expressions created with the expr api can be round tripped
@@ -1799,21 +1801,6 @@ fn roundtrip_count_distinct() {
roundtrip_expr_test(test_expr, ctx);
}
-#[test]
-fn roundtrip_approx_percentile_cont() {
- let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
- AggregateFunction::ApproxPercentileCont,
- vec![col("bananas"), lit(0.42_f32)],
- false,
- None,
- None,
- None,
- ));
-
- let ctx = SessionContext::new();
- roundtrip_expr_test(test_expr, ctx);
-}
-
#[test]
fn roundtrip_aggregate_udf() {
#[derive(Debug)]
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index e517482f1d..7f66cdbf76 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -303,6 +303,7 @@ fn roundtrip_window() -> Result<()> {
&args,
&[],
&[],
+ &[],
&schema,
"SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING",
false,
@@ -458,6 +459,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
&[col("b", &schema)?],
&[],
&[],
+ &[],
&schema,
"example_agg",
false,
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 7ba1893bb1..0a6def3d6f 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -76,26 +76,26 @@ statement error DataFusion error: Schema error: Schema
contains duplicate unqual
SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar))
count_c9_str FROM aggregate_test_100
# csv_query_approx_percentile_cont_with_weight
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Utf8,
Int8, Float64\)'. You might need to add explicit type casts.
+statement error DataFusion error: Error during planning: Error during
planning: Coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*)
failed(.|\n)*
SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16,
Utf8, Float64\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during
planning: Coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*)
failed(.|\n)*
SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16,
Int8, Utf8\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during
planning: Coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*)
failed(.|\n)*
SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100
# csv_query_approx_percentile_cont_with_histogram_bins
-statement error This feature is not implemented: Tdigest max_size value for
'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\).
+statement error DataFusion error: External error: This feature is not
implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt >
0 literal \(got data type Int64\)\.
SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM
aggregate_test_100 GROUP BY 1 ORDER BY 1
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64,
Utf8\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during
planning: Coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*)
failed(.|\n)*
SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64,
Float64\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during
planning: Coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*)
failed(.|\n)*
SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100
-statement error DataFusion error: Error during planning: No function matches
the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64,
Float64\)'\. You might need to add explicit type casts\.
+statement error DataFusion error: Error during planning: Error during
planning: Coercion from \[Float64, Float64, Float64\] to the signature
OneOf(.*) failed(.|\n)*
SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100
# array agg can use order by
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]