This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3773fb7fb5 Convert `approx_median` to UDAF (#10840)
3773fb7fb5 is described below
commit 3773fb7fb54419f889e7d18b73e9eb48069eb08e
Author: Jax Liu <[email protected]>
AuthorDate: Mon Jun 10 14:45:15 2024 +0800
Convert `approx_median` to UDAF (#10840)
* move tdigest to physical-expr-common
* move approx_percentile_cont_accumulator to function-aggregate
* implement approx_median udaf
* remove approx_median aggregation function
* fix sqllogictests
* add removed type tests
* cargo fmt and clippy
* add logical roundtrip test
* fix dataframe test
* fix test and proto gen
* update lock in datafusion-cli
* fix typo
* fix test and doc
* fix sql_integration
* cargo fmt
* follow the same checking style as other udafs
* add comments and modify dependencies
* update lock and fmt
* add missing test annotation
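
With this change `approx_median` is planned through the functions-aggregate
crate instead of the built-in AggregateFunction enum. A minimal sketch of
calling it through the DataFrame API (the table, file and column names are
placeholders, not part of this commit); the resulting column renders as
approx_median(test.b), matching the updated dataframe test below:

    use datafusion::error::Result;
    use datafusion::functions_aggregate::approx_median::approx_median;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();
        // register any table with a numeric column; names here are hypothetical
        ctx.register_csv("test", "data.csv", CsvReadOptions::new()).await?;
        let df = ctx
            .table("test")
            .await?
            .aggregate(vec![], vec![approx_median(col("b"))])?;
        df.show().await?;
        Ok(())
    }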
---
datafusion-cli/Cargo.lock | 20 +-
.../core/tests/dataframe/dataframe_functions.rs | 3 +-
datafusion/expr/src/aggregate_function.rs | 8 +-
datafusion/expr/src/expr_fn.rs | 12 -
datafusion/expr/src/type_coercion/aggregates.rs | 10 -
.../functions-aggregate/src/approx_median.rs | 129 +++++++++++
.../src/approx_percentile_cont.rs | 255 +++++++++++++++++++++
datafusion/functions-aggregate/src/lib.rs | 5 +
datafusion/functions-aggregate/src/sum.rs | 2 +-
.../physical-expr-common/src/aggregate/mod.rs | 1 +
.../src/aggregate/tdigest.rs | 46 ++--
.../physical-expr/src/aggregate/approx_median.rs | 99 --------
.../src/aggregate/approx_percentile_cont.rs | 246 +-------------------
.../approx_percentile_cont_with_weight.rs | 6 +-
datafusion/physical-expr/src/aggregate/build_in.rs | 71 +-----
datafusion/physical-expr/src/aggregate/mod.rs | 2 -
datafusion/physical-expr/src/expressions/mod.rs | 1 -
datafusion/proto/proto/datafusion.proto | 2 +-
datafusion/proto/src/generated/pbjson.rs | 3 -
datafusion/proto/src/generated/prost.rs | 4 +-
datafusion/proto/src/logical_plan/from_proto.rs | 1 -
datafusion/proto/src/logical_plan/to_proto.rs | 4 -
datafusion/proto/src/physical_plan/to_proto.rs | 14 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 2 +
datafusion/sql/tests/sql_integration.rs | 12 +-
datafusion/sqllogictest/test_files/aggregate.slt | 10 +
26 files changed, 471 insertions(+), 497 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 95b114ca4a..932f44d984 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -2880,9 +2880,9 @@ dependencies = [
[[package]]
name = "regex"
-version = "1.10.4"
+version = "1.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c"
+checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f"
dependencies = [
"aho-corasick",
"memchr",
@@ -2892,9 +2892,9 @@ dependencies = [
[[package]]
name = "regex-automata"
-version = "0.4.6"
+version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea"
+checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [
"aho-corasick",
"memchr",
@@ -2903,15 +2903,15 @@ dependencies = [
[[package]]
name = "regex-lite"
-version = "0.1.5"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e"
+checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a"
[[package]]
name = "regex-syntax"
-version = "0.8.3"
+version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56"
+checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "reqwest"
@@ -3846,9 +3846,9 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf8parse"
-version = "0.2.1"
+version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a"
+checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs
b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 7d155bb16c..b05769a6ce 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -33,6 +33,7 @@ use datafusion::assert_batches_eq;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;
+use datafusion_functions_aggregate::expr_fn::approx_median;
fn test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
@@ -342,7 +343,7 @@ async fn test_fn_approx_median() -> Result<()> {
let expected = [
"+-----------------------+",
- "| APPROX_MEDIAN(test.b) |",
+ "| approx_median(test.b) |",
"+-----------------------+",
"| 10 |",
"+-----------------------+",
diff --git a/datafusion/expr/src/aggregate_function.rs
b/datafusion/expr/src/aggregate_function.rs
index 9e4f7a50ac..6227df814f 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -71,8 +71,6 @@ pub enum AggregateFunction {
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
ApproxPercentileContWithWeight,
- /// ApproxMedian
- ApproxMedian,
/// Grouping
Grouping,
/// Bit And
@@ -112,7 +110,6 @@ impl AggregateFunction {
RegrSXY => "REGR_SXY",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
- ApproxMedian => "APPROX_MEDIAN",
Grouping => "GROUPING",
BitAnd => "BIT_AND",
BitOr => "BIT_OR",
@@ -161,7 +158,6 @@ impl FromStr for AggregateFunction {
"regr_sxy" => AggregateFunction::RegrSXY,
// approximate
"approx_distinct" => AggregateFunction::ApproxDistinct,
- "approx_median" => AggregateFunction::ApproxMedian,
"approx_percentile_cont" =>
AggregateFunction::ApproxPercentileCont,
"approx_percentile_cont_with_weight" => {
AggregateFunction::ApproxPercentileContWithWeight
@@ -234,7 +230,6 @@ impl AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
- AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
@@ -284,7 +279,8 @@ impl AggregateFunction {
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
}
- AggregateFunction::Avg | AggregateFunction::ApproxMedian => {
+
+ AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0360478eac..5626d343a6 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -284,18 +284,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
))
}
-/// Calculate an approximation of the median for `expr`.
-pub fn approx_median(expr: Expr) -> Expr {
- Expr::AggregateFunction(AggregateFunction::new(
- aggregate_function::AggregateFunction::ApproxMedian,
- vec![expr],
- false,
- None,
- None,
- None,
- ))
-}
-
/// Calculate an approximation of the specified `percentile` for `expr`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs
b/datafusion/expr/src/type_coercion/aggregates.rs
index 4b4d526532..efd3c9f371 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -231,16 +231,6 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
- AggregateFunction::ApproxMedian => {
- if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
- return plan_err!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun,
- input_types[0]
- );
- }
- Ok(input_types.to_vec())
- }
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
diff --git a/datafusion/functions-aggregate/src/approx_median.rs
b/datafusion/functions-aggregate/src/approx_median.rs
new file mode 100644
index 0000000000..b8b86d3055
--- /dev/null
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -0,0 +1,129 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution
+
+use std::any::Any;
+use std::fmt::Debug;
+
+use arrow::{datatypes::DataType, datatypes::Field};
+use arrow_schema::DataType::{Float64, UInt64};
+
+use datafusion_common::{not_impl_err, plan_err, Result};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
+use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref;
+
+use crate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+make_udaf_expr_and_func!(
+ ApproxMedian,
+ approx_median,
+ expression,
+ "Computes the approximate median of a set of numbers",
+ approx_median_udaf
+);
+
+/// APPROX_MEDIAN aggregate expression
+pub struct ApproxMedian {
+ signature: Signature,
+}
+
+impl Debug for ApproxMedian {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.debug_struct("ApproxMedian")
+ .field("name", &self.name())
+ .field("signature", &self.signature)
+ .finish()
+ }
+}
+
+impl Default for ApproxMedian {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl ApproxMedian {
+ /// Create a new APPROX_MEDIAN aggregate function
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
+ }
+ }
+}
+
+impl AggregateUDFImpl for ApproxMedian {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(format_state_name(args.name, "max_size"), UInt64, false),
+ Field::new(format_state_name(args.name, "sum"), Float64, false),
+ Field::new(format_state_name(args.name, "count"), Float64, false),
+ Field::new(format_state_name(args.name, "max"), Float64, false),
+ Field::new(format_state_name(args.name, "min"), Float64, false),
+ Field::new_list(
+ format_state_name(args.name, "centroids"),
+ Field::new("item", Float64, true),
+ false,
+ ),
+ ])
+ }
+
+ fn name(&self) -> &str {
+ "approx_median"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ if !arg_types[0].is_numeric() {
+ return plan_err!("ApproxMedian requires numeric input types");
+ }
+ Ok(arg_types[0].clone())
+ }
+
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+ if acc_args.is_distinct {
+ return not_impl_err!(
+ "APPROX_MEDIAN(DISTINCT) aggregations are not available"
+ );
+ }
+
+ Ok(Box::new(ApproxPercentileAccumulator::new(
+ 0.5_f64,
+ acc_args.input_type.clone(),
+ )))
+ }
+}
+
+impl PartialEq<dyn Any> for ApproxMedian {
+ fn eq(&self, other: &dyn Any) -> bool {
+ down_cast_any_ref(other)
+ .downcast_ref::<Self>()
+ .map(|x| self.signature == x.signature)
+ .unwrap_or(false)
+ }
+}
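
Downstream code that needs the UDAF object itself (for example to register it
with a custom ContextProvider, as the sql_integration change below does with
approx_median_udaf()) can build the same expression directly. A small sketch
around the generated approx_median_udaf() accessor; the column name is
hypothetical:

    use datafusion_expr::{col, Expr};
    use datafusion_functions_aggregate::approx_median::approx_median_udaf;

    fn approx_median_of_price() -> Expr {
        // approx_median_udaf() returns the shared Arc<AggregateUDF> produced by
        // make_udaf_expr_and_func!; `call` wraps it in an aggregate Expr.
        approx_median_udaf().call(vec![col("price")])
    }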
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
new file mode 100644
index 0000000000..e75417efc6
--- /dev/null
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -0,0 +1,255 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{
+ array::{
+ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
+ Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+ },
+ datatypes::DataType,
+};
+
+use datafusion_common::{downcast_value, internal_err, DataFusionError, ScalarValue};
+use datafusion_expr::Accumulator;
+use datafusion_physical_expr_common::aggregate::tdigest::{
+ TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
+};
+
+#[derive(Debug)]
+pub struct ApproxPercentileAccumulator {
+ digest: TDigest,
+ percentile: f64,
+ return_type: DataType,
+}
+
+impl ApproxPercentileAccumulator {
+ pub fn new(percentile: f64, return_type: DataType) -> Self {
+ Self {
+ digest: TDigest::new(DEFAULT_MAX_SIZE),
+ percentile,
+ return_type,
+ }
+ }
+
+ pub fn new_with_max_size(
+ percentile: f64,
+ return_type: DataType,
+ max_size: usize,
+ ) -> Self {
+ Self {
+ digest: TDigest::new(max_size),
+ percentile,
+ return_type,
+ }
+ }
+
+ // public for approx_percentile_cont_with_weight
+ pub fn merge_digests(&mut self, digests: &[TDigest]) {
+ let digests = digests.iter().chain(std::iter::once(&self.digest));
+ self.digest = TDigest::merge_digests(digests)
+ }
+
+ // public for approx_percentile_cont_with_weight
+ pub fn convert_to_float(values: &ArrayRef) -> datafusion_common::Result<Vec<f64>> {
+ match values.data_type() {
+ DataType::Float64 => {
+ let array = downcast_value!(values, Float64Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::Float32 => {
+ let array = downcast_value!(values, Float32Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::Int64 => {
+ let array = downcast_value!(values, Int64Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::Int32 => {
+ let array = downcast_value!(values, Int32Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::Int16 => {
+ let array = downcast_value!(values, Int16Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::Int8 => {
+ let array = downcast_value!(values, Int8Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::UInt64 => {
+ let array = downcast_value!(values, UInt64Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::UInt32 => {
+ let array = downcast_value!(values, UInt32Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::UInt16 => {
+ let array = downcast_value!(values, UInt16Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ DataType::UInt8 => {
+ let array = downcast_value!(values, UInt8Array);
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<datafusion_common::Result<Vec<_>>>()?)
+ }
+ e => internal_err!(
+ "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
+ ),
+ }
+ }
+}
+
+impl Accumulator for ApproxPercentileAccumulator {
+ fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
+ Ok(self.digest.to_scalar_state().into_iter().collect())
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
+ let values = &values[0];
+ let sorted_values = &arrow::compute::sort(values, None)?;
+ let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
+ self.digest = self.digest.merge_sorted_f64(&sorted_values);
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+ if self.digest.count() == 0.0 {
+ return ScalarValue::try_from(self.return_type.clone());
+ }
+ let q = self.digest.estimate_quantile(self.percentile);
+
+ // These acceptable return types MUST match the validation in
+ // ApproxPercentile::create_accumulator.
+ Ok(match &self.return_type {
+ DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
+ DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
+ DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
+ DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
+ DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
+ DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
+ DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
+ DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
+ DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
+ DataType::Float64 => ScalarValue::Float64(Some(q)),
+ v => unreachable!("unexpected return type {:?}", v),
+ })
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+
+ let states = (0..states[0].len())
+ .map(|index| {
+ states
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<datafusion_common::Result<Vec<_>>>()
+ .map(|state| TDigest::from_scalar_state(&state))
+ })
+ .collect::<datafusion_common::Result<Vec<_>>>()?;
+
+ self.merge_digests(&states);
+
+ Ok(())
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self) + self.digest.size()
+ - std::mem::size_of_val(&self.digest)
+ + self.return_type.size()
+ - std::mem::size_of_val(&self.return_type)
+ }
+
+ fn supports_retract_batch(&self) -> bool {
+ true
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use arrow_schema::DataType;
+
+ use datafusion_physical_expr_common::aggregate::tdigest::TDigest;
+
+ use crate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+ #[test]
+ fn test_combine_approx_percentile_accumulator() {
+ let mut digests: Vec<TDigest> = Vec::new();
+
+ // one TDigest with 50_000 values from 1 to 1_000
+ for _ in 1..=50 {
+ let t = TDigest::new(100);
+ let values: Vec<_> = (1..=1_000).map(f64::from).collect();
+ let t = t.merge_unsorted_f64(values);
+ digests.push(t)
+ }
+
+ let t1 = TDigest::merge_digests(&digests);
+ let t2 = TDigest::merge_digests(&digests);
+
+ let mut accumulator =
+ ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);
+
+ accumulator.merge_digests(&[t1]);
+ assert_eq!(accumulator.digest.count(), 50_000.0);
+ accumulator.merge_digests(&[t2]);
+ assert_eq!(accumulator.digest.count(), 100_000.0);
+ }
+}
diff --git a/datafusion/functions-aggregate/src/lib.rs
b/datafusion/functions-aggregate/src/lib.rs
index b8a2e7032a..274ab8302e 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -62,6 +62,9 @@ pub mod stddev;
pub mod sum;
pub mod variance;
+pub mod approx_median;
+pub mod approx_percentile_cont;
+
use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
use datafusion_expr::AggregateUDF;
@@ -70,6 +73,7 @@ use std::sync::Arc;
/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
+ pub use super::approx_median::approx_median;
pub use super::covariance::covar_pop;
pub use super::covariance::covar_samp;
pub use super::first_last::first_value;
@@ -95,6 +99,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
variance::var_pop_udaf(),
stddev::stddev_udaf(),
stddev::stddev_pop_udaf(),
+ approx_median::approx_median_udaf(),
]
}
diff --git a/datafusion/functions-aggregate/src/sum.rs
b/datafusion/functions-aggregate/src/sum.rs
index 9d3fa25222..b9293bc2ca 100644
--- a/datafusion/functions-aggregate/src/sum.rs
+++ b/datafusion/functions-aggregate/src/sum.rs
@@ -46,7 +46,7 @@ make_udaf_expr_and_func!(
Sum,
sum,
expression,
- "Returns the first value in a group of values.",
+ "Returns the sum of a group of values.",
sum_udaf
);
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 2273418c60..ec02df57b8 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -17,6 +17,7 @@
pub mod groups_accumulator;
pub mod stats;
+pub mod tdigest;
pub mod utils;
use arrow::datatypes::{DataType, Field, Schema};
diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs
b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
similarity index 95%
rename from datafusion/physical-expr/src/aggregate/tdigest.rs
rename to datafusion/physical-expr-common/src/aggregate/tdigest.rs
index e3b23b91d0..5107d0ab8e 100644
--- a/datafusion/physical-expr/src/aggregate/tdigest.rs
+++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
@@ -28,7 +28,7 @@
//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
use arrow::datatypes::DataType;
-use arrow_array::types::Float64Type;
+use arrow::datatypes::Float64Type;
use datafusion_common::cast::as_primitive_array;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
@@ -50,7 +50,7 @@ macro_rules! cast_scalar_f64 {
/// This trait is implemented for each type a [`TDigest`] can operate on,
/// allowing it to support both numerical rust types (obtained from
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
-pub(crate) trait TryIntoF64 {
+pub trait TryIntoF64 {
/// A fallible conversion of a possibly null `self` into a [`f64`].
///
/// If `self` is null, this method must return `Ok(None)`.
@@ -84,7 +84,7 @@ impl_try_ordered_f64!(u8);
/// Centroid implementation to the cluster mentioned in the paper.
#[derive(Debug, PartialEq, Clone)]
-pub(crate) struct Centroid {
+pub struct Centroid {
mean: f64,
weight: f64,
}
@@ -104,21 +104,21 @@ impl Ord for Centroid {
}
impl Centroid {
- pub(crate) fn new(mean: f64, weight: f64) -> Self {
+ pub fn new(mean: f64, weight: f64) -> Self {
Centroid { mean, weight }
}
#[inline]
- pub(crate) fn mean(&self) -> f64 {
+ pub fn mean(&self) -> f64 {
self.mean
}
#[inline]
- pub(crate) fn weight(&self) -> f64 {
+ pub fn weight(&self) -> f64 {
self.weight
}
- pub(crate) fn add(&mut self, sum: f64, weight: f64) -> f64 {
+ pub fn add(&mut self, sum: f64, weight: f64) -> f64 {
let new_sum = sum + self.weight * self.mean;
let new_weight = self.weight + weight;
self.weight = new_weight;
@@ -138,7 +138,7 @@ impl Default for Centroid {
/// T-Digest to be operated on.
#[derive(Debug, PartialEq, Clone)]
-pub(crate) struct TDigest {
+pub struct TDigest {
centroids: Vec<Centroid>,
max_size: usize,
sum: f64,
@@ -148,7 +148,7 @@ pub(crate) struct TDigest {
}
impl TDigest {
- pub(crate) fn new(max_size: usize) -> Self {
+ pub fn new(max_size: usize) -> Self {
TDigest {
centroids: Vec::new(),
max_size,
@@ -159,7 +159,7 @@ impl TDigest {
}
}
- pub(crate) fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
+ pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self {
TDigest {
centroids: vec![centroid.clone()],
max_size,
@@ -171,27 +171,27 @@ impl TDigest {
}
#[inline]
- pub(crate) fn count(&self) -> f64 {
+ pub fn count(&self) -> f64 {
self.count
}
#[inline]
- pub(crate) fn max(&self) -> f64 {
+ pub fn max(&self) -> f64 {
self.max
}
#[inline]
- pub(crate) fn min(&self) -> f64 {
+ pub fn min(&self) -> f64 {
self.min
}
#[inline]
- pub(crate) fn max_size(&self) -> usize {
+ pub fn max_size(&self) -> usize {
self.max_size
}
/// Size in bytes including `Self`.
- pub(crate) fn size(&self) -> usize {
+ pub fn size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<Centroid>() * self.centroids.capacity())
}
@@ -228,14 +228,14 @@ impl TDigest {
v.clamp(lo, hi)
}
- #[cfg(test)]
- pub(crate) fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
+ // public for testing in other modules
+ pub fn merge_unsorted_f64(&self, unsorted_values: Vec<f64>) -> TDigest {
let mut values = unsorted_values;
values.sort_by(|a, b| a.total_cmp(b));
self.merge_sorted_f64(&values)
}
- pub(crate) fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
+ pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest {
#[cfg(debug_assertions)]
debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");
@@ -370,9 +370,7 @@ impl TDigest {
}
// Merge multiple T-Digests
- pub(crate) fn merge_digests<'a>(
- digests: impl IntoIterator<Item = &'a TDigest>,
- ) -> TDigest {
+ pub fn merge_digests<'a>(digests: impl IntoIterator<Item = &'a TDigest>) -> TDigest {
let digests = digests.into_iter().collect::<Vec<_>>();
let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
if n_centroids == 0 {
@@ -465,7 +463,7 @@ impl TDigest {
}
/// To estimate the value located at `q` quantile
- pub(crate) fn estimate_quantile(&self, q: f64) -> f64 {
+ pub fn estimate_quantile(&self, q: f64) -> f64 {
if self.centroids.is_empty() {
return 0.0;
}
@@ -569,7 +567,7 @@ impl TDigest {
/// The [`TDigest::from_scalar_state()`] method reverses this processes,
/// consuming the output of this method and returning an unpacked
/// [`TDigest`].
- pub(crate) fn to_scalar_state(&self) -> Vec<ScalarValue> {
+ pub fn to_scalar_state(&self) -> Vec<ScalarValue> {
// Gather up all the centroids
let centroids: Vec<ScalarValue> = self
.centroids
@@ -598,7 +596,7 @@ impl TDigest {
/// Providing input to this method that was not obtained from
/// [`Self::to_scalar_state()`] results in undefined behaviour and may
/// panic.
- pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self {
+ pub fn from_scalar_state(state: &[ScalarValue]) -> Self {
assert_eq!(state.len(), 6, "invalid TDigest state");
let max_size = match &state[0] {
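
Because the digest now lives in physical-expr-common with a pub API (rather
than pub(crate)), other crates can drive it directly. A rough sketch of the
exposed surface, not taken verbatim from this commit:

    use datafusion_physical_expr_common::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};

    // merge_sorted_f64 expects pre-sorted input; 1..=1000 is already sorted
    let values: Vec<f64> = (1..=1_000).map(f64::from).collect();
    let digest = TDigest::new(DEFAULT_MAX_SIZE).merge_sorted_f64(&values);
    // estimate_quantile(0.5) approximates the median (~500 for this input)
    let approx_median = digest.estimate_quantile(0.5);
    println!("approx median: {approx_median}");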
diff --git a/datafusion/physical-expr/src/aggregate/approx_median.rs
b/datafusion/physical-expr/src/aggregate/approx_median.rs
deleted file mode 100644
index cbbfef5a89..0000000000
--- a/datafusion/physical-expr/src/aggregate/approx_median.rs
+++ /dev/null
@@ -1,99 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution
-
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::{lit, ApproxPercentileCont};
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::{datatypes::DataType, datatypes::Field};
-use datafusion_common::Result;
-use datafusion_expr::Accumulator;
-use std::any::Any;
-use std::sync::Arc;
-
-/// MEDIAN aggregate expression
-#[derive(Debug)]
-pub struct ApproxMedian {
- name: String,
- expr: Arc<dyn PhysicalExpr>,
- data_type: DataType,
- approx_percentile: ApproxPercentileCont,
-}
-
-impl ApproxMedian {
- /// Create a new APPROX_MEDIAN aggregate function
- pub fn try_new(
- expr: Arc<dyn PhysicalExpr>,
- name: impl Into<String>,
- data_type: DataType,
- ) -> Result<Self> {
- let name: String = name.into();
- let approx_percentile = ApproxPercentileCont::new(
- vec![expr.clone(), lit(0.5_f64)],
- name.clone(),
- data_type.clone(),
- )?;
- Ok(Self {
- name,
- expr,
- data_type,
- approx_percentile,
- })
- }
-}
-
-impl AggregateExpr for ApproxMedian {
- /// Return a reference to Any that can be used for downcasting
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.data_type.clone(), true))
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- self.approx_percentile.create_accumulator()
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- self.approx_percentile.state_fields()
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.expr.clone()]
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
-
-impl PartialEq<dyn Any> for ApproxMedian {
- fn eq(&self, other: &dyn Any) -> bool {
- down_cast_any_ref(other)
- .downcast_ref::<Self>()
- .map(|x| {
- self.name == x.name
- && self.data_type == x.data_type
- && self.expr.eq(&x.expr)
- && self.approx_percentile == x.approx_percentile
- })
- .unwrap_or(false)
- }
-}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 63a4c85f9e..f2068bbc92 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -15,26 +15,19 @@
// specific language governing permissions and limitations
// under the License.
-use crate::aggregate::tdigest::TryIntoF64;
-use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE};
-use crate::aggregate::utils::down_cast_any_ref;
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::{
- array::{
- ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
- Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
- },
- datatypes::{DataType, Field},
-};
+use std::{any::Any, sync::Arc};
+
+use arrow::datatypes::{DataType, Field};
use arrow_array::RecordBatch;
use arrow_schema::Schema;
-use datafusion_common::{
- downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result,
- ScalarValue,
-};
+
+use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, ColumnarValue};
-use std::{any::Any, sync::Arc};
+use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+use crate::aggregate::utils::down_cast_any_ref;
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
/// APPROX_PERCENTILE_CONT aggregate expression
#[derive(Debug)]
@@ -195,7 +188,7 @@ impl AggregateExpr for ApproxPercentileCont {
}
#[allow(rustdoc::private_intra_doc_links)]
- /// See [`TDigest::to_scalar_state()`] for a description of the serialised
+ /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised
/// state.
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![
@@ -254,220 +247,3 @@ impl PartialEq<dyn Any> for ApproxPercentileCont {
.unwrap_or(false)
}
}
-
-#[derive(Debug)]
-pub struct ApproxPercentileAccumulator {
- digest: TDigest,
- percentile: f64,
- return_type: DataType,
-}
-
-impl ApproxPercentileAccumulator {
- pub fn new(percentile: f64, return_type: DataType) -> Self {
- Self {
- digest: TDigest::new(DEFAULT_MAX_SIZE),
- percentile,
- return_type,
- }
- }
-
- pub fn new_with_max_size(
- percentile: f64,
- return_type: DataType,
- max_size: usize,
- ) -> Self {
- Self {
- digest: TDigest::new(max_size),
- percentile,
- return_type,
- }
- }
-
- pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
- let digests = digests.iter().chain(std::iter::once(&self.digest));
- self.digest = TDigest::merge_digests(digests)
- }
-
- pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
- match values.data_type() {
- DataType::Float64 => {
- let array = downcast_value!(values, Float64Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::Float32 => {
- let array = downcast_value!(values, Float32Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::Int64 => {
- let array = downcast_value!(values, Int64Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::Int32 => {
- let array = downcast_value!(values, Int32Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::Int16 => {
- let array = downcast_value!(values, Int16Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::Int8 => {
- let array = downcast_value!(values, Int8Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::UInt64 => {
- let array = downcast_value!(values, UInt64Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::UInt32 => {
- let array = downcast_value!(values, UInt32Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::UInt16 => {
- let array = downcast_value!(values, UInt16Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- DataType::UInt8 => {
- let array = downcast_value!(values, UInt8Array);
- Ok(array
- .values()
- .iter()
- .filter_map(|v| v.try_as_f64().transpose())
- .collect::<Result<Vec<_>>>()?)
- }
- e => internal_err!(
- "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}"
- ),
- }
- }
-}
-
-impl Accumulator for ApproxPercentileAccumulator {
- fn state(&mut self) -> Result<Vec<ScalarValue>> {
- Ok(self.digest.to_scalar_state().into_iter().collect())
- }
-
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = &values[0];
- let sorted_values = &arrow::compute::sort(values, None)?;
- let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?;
- self.digest = self.digest.merge_sorted_f64(&sorted_values);
- Ok(())
- }
-
- fn evaluate(&mut self) -> Result<ScalarValue> {
- if self.digest.count() == 0.0 {
- return ScalarValue::try_from(self.return_type.clone());
- }
- let q = self.digest.estimate_quantile(self.percentile);
-
- // These acceptable return types MUST match the validation in
- // ApproxPercentile::create_accumulator.
- Ok(match &self.return_type {
- DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
- DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
- DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
- DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
- DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
- DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
- DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
- DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
- DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
- DataType::Float64 => ScalarValue::Float64(Some(q)),
- v => unreachable!("unexpected return type {:?}", v),
- })
- }
-
- fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
- if states.is_empty() {
- return Ok(());
- }
-
- let states = (0..states[0].len())
- .map(|index| {
- states
- .iter()
- .map(|array| ScalarValue::try_from_array(array, index))
- .collect::<Result<Vec<_>>>()
- .map(|state| TDigest::from_scalar_state(&state))
- })
- .collect::<Result<Vec<_>>>()?;
-
- self.merge_digests(&states);
-
- Ok(())
- }
-
- fn size(&self) -> usize {
- std::mem::size_of_val(self) + self.digest.size()
- - std::mem::size_of_val(&self.digest)
- + self.return_type.size()
- - std::mem::size_of_val(&self.return_type)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
- use crate::aggregate::tdigest::TDigest;
- use arrow_schema::DataType;
-
- #[test]
- fn test_combine_approx_percentile_accumulator() {
- let mut digests: Vec<TDigest> = Vec::new();
-
- // one TDigest with 50_000 values from 1 to 1_000
- for _ in 1..=50 {
- let t = TDigest::new(100);
- let values: Vec<_> = (1..=1_000).map(f64::from).collect();
- let t = t.merge_unsorted_f64(values);
- digests.push(t)
- }
-
- let t1 = TDigest::merge_digests(&digests);
- let t2 = TDigest::merge_digests(&digests);
-
- let mut accumulator =
- ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);
-
- accumulator.merge_digests(&[t1]);
- assert_eq!(accumulator.digest.count(), 50_000.0);
- accumulator.merge_digests(&[t2]);
- assert_eq!(accumulator.digest.count(), 100_000.0);
- }
-}
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
index 3fa715a592..07c2aff343 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs
@@ -15,14 +15,16 @@
// specific language governing permissions and limitations
// under the License.
-use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
-use crate::aggregate::tdigest::{Centroid, TDigest, DEFAULT_MAX_SIZE};
use crate::expressions::ApproxPercentileCont;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::{
array::ArrayRef,
datatypes::{DataType, Field},
};
+use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+use datafusion_physical_expr_common::aggregate::tdigest::{
+ Centroid, TDigest, DEFAULT_MAX_SIZE,
+};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index f0cff53fb3..89de6ad49c 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -280,18 +280,6 @@ pub fn create_aggregate_expr(
"approx_percentile_cont_with_weight(DISTINCT) aggregations are not available"
);
}
- (AggregateFunction::ApproxMedian, false) => {
- Arc::new(expressions::ApproxMedian::try_new(
- input_phy_exprs[0].clone(),
- name,
- data_type,
- )?)
- }
- (AggregateFunction::ApproxMedian, true) => {
- return not_impl_err!(
- "APPROX_MEDIAN(DISTINCT) aggregations are not available"
- );
- }
(AggregateFunction::NthValue, _) => {
let expr = &input_phy_exprs[0];
let Some(n) = input_phy_exprs[1]
@@ -337,9 +325,8 @@ mod tests {
use super::*;
use crate::expressions::{
- try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg,
- BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount,
- Max, Min,
+ try_cast, ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr,
+ BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min,
};
use datafusion_common::{plan_err, DataFusionError, ScalarValue};
@@ -686,60 +673,6 @@ mod tests {
Ok(())
}
- #[test]
- fn test_median_expr() -> Result<()> {
- let funcs = vec![AggregateFunction::ApproxMedian];
- let data_types = vec![
- DataType::UInt32,
- DataType::UInt64,
- DataType::Int32,
- DataType::Int64,
- DataType::Float32,
- DataType::Float64,
- ];
- for fun in funcs {
- for data_type in &data_types {
- let input_schema =
- Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
- let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
- expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
- )];
- let result_agg_phy_exprs = create_physical_agg_expr_for_test(
- &fun,
- false,
- &input_phy_exprs[0..1],
- &input_schema,
- "c1",
- )?;
-
- if fun == AggregateFunction::ApproxMedian {
- assert!(result_agg_phy_exprs.as_any().is::<ApproxMedian>());
- assert_eq!("c1", result_agg_phy_exprs.name());
- assert_eq!(
- Field::new("c1", data_type.clone(), true),
- result_agg_phy_exprs.field().unwrap()
- );
- }
- }
- }
- Ok(())
- }
-
- #[test]
- fn test_median() -> Result<()> {
- let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Utf8]);
- assert!(observed.is_err());
-
- let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Int32])?;
- assert_eq!(DataType::Int32, observed);
-
- let observed =
- AggregateFunction::ApproxMedian.return_type(&[DataType::Decimal128(10, 6)]);
- assert!(observed.is_err());
-
- Ok(())
- }
-
#[test]
fn test_min_max() -> Result<()> {
let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?;
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index 2c14c1550e..9db80f155a 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -18,10 +18,8 @@
pub use datafusion_physical_expr_common::aggregate::AggregateExpr;
mod hyperloglog;
-mod tdigest;
pub(crate) mod approx_distinct;
-pub(crate) mod approx_median;
pub(crate) mod approx_percentile_cont;
pub(crate) mod approx_percentile_cont_with_weight;
pub(crate) mod array_agg;
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index 476cbe3907..656cc570ca 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -38,7 +38,6 @@ pub mod helpers {
}
pub use crate::aggregate::approx_distinct::ApproxDistinct;
-pub use crate::aggregate::approx_median::ApproxMedian;
pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont;
pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight;
pub use crate::aggregate::array_agg::ArrayAgg;
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 0071a43bbe..9f23824b3a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -487,7 +487,7 @@ enum AggregateFunction {
// STDDEV_POP = 12;
CORRELATION = 13;
APPROX_PERCENTILE_CONT = 14;
- APPROX_MEDIAN = 15;
+ // APPROX_MEDIAN = 15;
APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
GROUPING = 17;
// MEDIAN = 18;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index e6aded8901..28f80c5ee1 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -540,7 +540,6 @@ impl serde::Serialize for AggregateFunction {
Self::ArrayAgg => "ARRAY_AGG",
Self::Correlation => "CORRELATION",
Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
- Self::ApproxMedian => "APPROX_MEDIAN",
Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
Self::Grouping => "GROUPING",
Self::BitAnd => "BIT_AND",
@@ -578,7 +577,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"ARRAY_AGG",
"CORRELATION",
"APPROX_PERCENTILE_CONT",
- "APPROX_MEDIAN",
"APPROX_PERCENTILE_CONT_WITH_WEIGHT",
"GROUPING",
"BIT_AND",
@@ -645,7 +643,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
"ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg),
"CORRELATION" => Ok(AggregateFunction::Correlation),
"APPROX_PERCENTILE_CONT" =>
Ok(AggregateFunction::ApproxPercentileCont),
- "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian),
"APPROX_PERCENTILE_CONT_WITH_WEIGHT" =>
Ok(AggregateFunction::ApproxPercentileContWithWeight),
"GROUPING" => Ok(AggregateFunction::Grouping),
"BIT_AND" => Ok(AggregateFunction::BitAnd),
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 7ec9187491..9741b2bc42 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1931,7 +1931,7 @@ pub enum AggregateFunction {
/// STDDEV_POP = 12;
Correlation = 13,
ApproxPercentileCont = 14,
- ApproxMedian = 15,
+ /// APPROX_MEDIAN = 15;
ApproxPercentileContWithWeight = 16,
Grouping = 17,
/// MEDIAN = 18;
@@ -1967,7 +1967,6 @@ impl AggregateFunction {
AggregateFunction::ArrayAgg => "ARRAY_AGG",
AggregateFunction::Correlation => "CORRELATION",
AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
- AggregateFunction::ApproxMedian => "APPROX_MEDIAN",
AggregateFunction::ApproxPercentileContWithWeight => {
"APPROX_PERCENTILE_CONT_WITH_WEIGHT"
}
@@ -2001,7 +2000,6 @@ impl AggregateFunction {
"ARRAY_AGG" => Some(Self::ArrayAgg),
"CORRELATION" => Some(Self::Correlation),
"APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont),
- "APPROX_MEDIAN" => Some(Self::ApproxMedian),
"APPROX_PERCENTILE_CONT_WITH_WEIGHT" => {
Some(Self::ApproxPercentileContWithWeight)
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index a77d361983..5c083fa27a 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -164,7 +164,6 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ApproxPercentileContWithWeight => {
Self::ApproxPercentileContWithWeight
}
- protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
protobuf::AggregateFunction::Grouping => Self::Grouping,
protobuf::AggregateFunction::NthValueAgg => Self::NthValue,
protobuf::AggregateFunction::StringAgg => Self::StringAgg,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 9c4c7685b3..e2259896b2 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -133,7 +133,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Self::ApproxPercentileContWithWeight
}
- AggregateFunction::ApproxMedian => Self::ApproxMedian,
AggregateFunction::Grouping => Self::Grouping,
AggregateFunction::NthValue => Self::NthValueAgg,
AggregateFunction::StringAgg => Self::StringAgg,
@@ -430,9 +429,6 @@ pub fn serialize_expr(
AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx,
AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy,
AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy,
- AggregateFunction::ApproxMedian => {
- protobuf::AggregateFunction::ApproxMedian
- }
AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
AggregateFunction::NthValue => {
protobuf::AggregateFunction::NthValueAgg
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs
b/datafusion/proto/src/physical_plan/to_proto.rs
index 5d07d5c0fa..19ba4a40d5 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -23,12 +23,12 @@ use datafusion::datasource::file_format::parquet::ParquetSink;
use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr};
use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
use datafusion::physical_plan::expressions::{
- ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight,
- ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr,
- CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor,
- DistinctCount, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min,
- NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank,
- RankType, Regr, RegrType, RowNumber, StringAgg, TryCastExpr, WindowShift,
+ ApproxDistinct, ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg,
+ BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column,
+ Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount,
+ Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr,
+ NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr,
+ RegrType, RowNumber, StringAgg, TryCastExpr, WindowShift,
};
use datafusion::physical_plan::udaf::AggregateFunctionExpr;
use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr};
@@ -296,8 +296,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> {
.is_some()
{
protobuf::AggregateFunction::ApproxPercentileContWithWeight
- } else if aggr_expr.downcast_ref::<ApproxMedian>().is_some() {
- protobuf::AggregateFunction::ApproxMedian
} else if aggr_expr.downcast_ref::<StringAgg>().is_some() {
protobuf::AggregateFunction::StringAgg
} else if aggr_expr.downcast_ref::<NthValueAgg>().is_some() {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index b1cad69b14..699697dd2f 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -33,6 +33,7 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::FunctionRegistry;
+use datafusion::functions_aggregate::approx_median::approx_median;
use datafusion::functions_aggregate::expr_fn::{
covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_pop,
var_sample,
@@ -658,6 +659,7 @@ async fn roundtrip_expr_api() -> Result<()> {
var_pop(lit(2.2)),
stddev(lit(2.2)),
stddev_pop(lit(2.2)),
+ approx_median(lit(2)),
];
// ensure expressions created with the expr api can be round tripped
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index 6a99f9719d..7b9d39a2b5 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -37,6 +37,7 @@ use datafusion_sql::{
planner::{ParserOptions, SqlToRel},
};
+use datafusion_functions_aggregate::approx_median::approx_median_udaf;
use rstest::rstest;
use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
@@ -1649,8 +1650,8 @@ fn select_count_column() {
#[test]
fn select_approx_median() {
let sql = "SELECT approx_median(age) FROM person";
- let expected = "Projection: APPROX_MEDIAN(person.age)\
- \n Aggregate: groupBy=[[]], aggr=[[APPROX_MEDIAN(person.age)]]\
+ let expected = "Projection: approx_median(person.age)\
+ \n Aggregate: groupBy=[[]], aggr=[[approx_median(person.age)]]\
\n TableScan: person";
quick_test(sql, expected);
}
@@ -2581,8 +2582,8 @@ fn approx_median_window() {
let sql =
"SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders";
let expected = "\
- Projection: orders.order_id, APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\
- \n WindowAggr: windowExpr=[[APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
+ Projection: orders.order_id, approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\
+ \n WindowAggr: windowExpr=[[approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: orders";
quick_test(sql, expected);
}
@@ -2700,7 +2701,8 @@ fn logical_plan_with_dialect_and_options(
DataType::Int32,
))
.with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
- .with_udaf(sum_udaf());
+ .with_udaf(sum_udaf())
+ .with_udaf(approx_median_udaf());
let planner = SqlToRel::new_with_options(&context, options);
let result = DFParser::parse_sql_with_dialect(sql, dialect);
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index 9958f8ac38..a245793ebd 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -518,6 +518,11 @@ SELECT approx_median(c12) FROM aggregate_test_100
----
0.555006541052
+# csv_query_approx_median_4
+# test with string, approx median only supports numeric
+statement error
+SELECT approx_median(c1) FROM aggregate_test_100
+
# csv_query_median_1
query I
SELECT median(c2) FROM aggregate_test_100
@@ -637,6 +642,11 @@ select median(c), arrow_typeof(median(c)) from t;
----
0.0003 Decimal128(10, 4)
+query RT
+select approx_median(c), arrow_typeof(approx_median(c)) from t;
+----
+0.00035 Float64
+
statement ok
drop table t;