This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new cfb655d approx_quantile() aggregation function (#1539)
cfb655d is described below
commit cfb655dc09013d161ef15d9502718998a6c4f86e
Author: Dom <[email protected]>
AuthorDate: Mon Jan 31 20:41:56 2022 +0000
approx_quantile() aggregation function (#1539)
* feat: implement TDigest for approx quantile
Adds a [TDigest] implementation providing approximate quantile
estimations of large inputs using a small amount of (bounded) memory.
A TDigest is most accurate near either "end" of the quantile range (that
is, 0.1, 0.9, 0.95, etc) due to the use of a scaling function that
increases resolution at the tails. The paper claims single digit part
per million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and
in practice I have found accuracy to be more than acceptable for an
approximate function across the entire quantile range.
The implementation is a modified copy of
https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++
implementation]. Both Facebook's implementation, and MnO2's Rust port
are Apache 2.0 licensed.
[TDigest]: https://arxiv.org/abs/1902.04023
[Facebook's C++ implementation]:
https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
* feat: approx_quantile aggregation
Adds the ApproxQuantile physical expression, plumbing & test cases.
The function signature is:
approx_quantile(column, quantile)
Where column can be any numeric type (that can be cast to a float64) and
quantile is a float64 literal between 0 and 1.
* feat: approx_quantile dataframe function
Adds the approx_quantile() dataframe function, and exports it in the
prelude.
* refactor: ballista approx_quantile support
Adds ballista wire encoding for approx_quantile.
Adding support for this required modifying the AggregateExprNode proto
message to support propagating multiple LogicalExprNode aggregate
arguments - all the existing aggregations take a single argument, so
this wasn't needed before.
This commit adds "repeated" to the expr field, which I believe is
backwards compatible as described here:
https://developers.google.com/protocol-buffers/docs/proto3#updating
Specifically, adding "repeated" to an existing message field:
"For ... message fields, optional is compatible with repeated"
No existing tests needed fixing, and a new roundtrip test is included
that covers the change to allow multiple expr.
* refactor: use input type as return type
Casts the calculated quantile value to the same type as the input data.
* fixup! refactor: ballista approx_quantile support
* refactor: rebase onto main
* refactor: validate quantile value
Ensures the quantile value is between 0 and 1, emitting a plan error if
not.
* refactor: rename to approx_percentile_cont
* refactor: clippy lints
---
ballista/rust/core/proto/ballista.proto | 3 +-
.../rust/core/src/serde/logical_plan/from_proto.rs | 6 +-
ballista/rust/core/src/serde/logical_plan/mod.rs | 21 +-
.../rust/core/src/serde/logical_plan/to_proto.rs | 14 +-
ballista/rust/core/src/serde/mod.rs | 3 +
datafusion/src/logical_plan/expr.rs | 9 +
datafusion/src/logical_plan/mod.rs | 14 +-
datafusion/src/physical_plan/aggregates.rs | 87 ++-
.../physical_plan/coercion_rule/aggregate_rule.rs | 118 ++-
.../expressions/approx_percentile_cont.rs | 313 ++++++++
datafusion/src/physical_plan/expressions/mod.rs | 4 +
datafusion/src/physical_plan/mod.rs | 1 +
datafusion/src/physical_plan/tdigest/mod.rs | 818 +++++++++++++++++++++
datafusion/src/prelude.rs | 12 +-
datafusion/tests/dataframe_functions.rs | 20 +
datafusion/tests/sql/aggregates.rs | 89 +++
16 files changed, 1485 insertions(+), 47 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto
b/ballista/rust/core/proto/ballista.proto
index 15a7342..fb006e5 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -176,11 +176,12 @@ enum AggregateFunction {
STDDEV=11;
STDDEV_POP=12;
CORRELATION=13;
+ APPROX_PERCENTILE_CONT = 14;
}
message AggregateExprNode {
AggregateFunction aggr_function = 1;
- LogicalExprNode expr = 2;
+ repeated LogicalExprNode expr = 2;
}
enum BuiltInWindowFunction {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 5684855..044f823 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -1065,7 +1065,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
Ok(Expr::AggregateFunction {
fun,
- args: vec![parse_required_expr(&expr.expr)?],
+ args: expr
+ .expr
+ .iter()
+ .map(|e| e.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
distinct: false, //TODO
})
}
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs
b/ballista/rust/core/src/serde/logical_plan/mod.rs
index c09b8a5..c00e3e4 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -24,16 +24,14 @@ mod roundtrip_tests {
use super::super::{super::error::Result, protobuf};
use crate::error::BallistaError;
use core::panic;
- use datafusion::arrow::datatypes::UnionMode;
- use datafusion::logical_plan::Repartition;
use datafusion::{
- arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
+ arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit,
UnionMode},
datasource::object_store::local::LocalFileSystem,
logical_plan::{
col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder,
- Partitioning, ToDFSchema,
+ Partitioning, Repartition, ToDFSchema,
},
- physical_plan::functions::BuiltinScalarFunction::Sqrt,
+ physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt},
prelude::*,
scalar::ScalarValue,
sql::parser::FileType,
@@ -1001,4 +999,17 @@ mod roundtrip_tests {
Ok(())
}
+
+ #[test]
+ fn roundtrip_approx_percentile_cont() -> Result<()> {
+ let test_expr = Expr::AggregateFunction {
+ fun: aggregates::AggregateFunction::ApproxPercentileCont,
+ args: vec![col("bananas"), lit(0.42)],
+ distinct: false,
+ };
+
+ roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
+
+ Ok(())
+ }
}
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index eb5d810..4b13ce5 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1074,6 +1074,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
+ AggregateFunction::ApproxPercentileCont => {
+ protobuf::AggregateFunction::ApproxPercentileCont
+ }
AggregateFunction::ArrayAgg =>
protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
@@ -1099,11 +1102,13 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
};
- let arg = &args[0];
- let aggregate_expr = Box::new(protobuf::AggregateExprNode {
+ let aggregate_expr = protobuf::AggregateExprNode {
aggr_function: aggr_function.into(),
- expr: Some(Box::new(arg.try_into()?)),
- });
+ expr: args
+ .iter()
+ .map(|v| v.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ };
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
})
@@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
+ AggregateFunction::ApproxPercentileCont =>
Self::ApproxPercentileCont,
}
}
}
diff --git a/ballista/rust/core/src/serde/mod.rs
b/ballista/rust/core/src/serde/mod.rs
index 4026273..64a60dc 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -129,6 +129,9 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop =>
AggregateFunction::StddevPop,
protobuf::AggregateFunction::Correlation =>
AggregateFunction::Correlation,
+ protobuf::AggregateFunction::ApproxPercentileCont => {
+ AggregateFunction::ApproxPercentileCont
+ }
}
}
}
diff --git a/datafusion/src/logical_plan/expr.rs
b/datafusion/src/logical_plan/expr.rs
index 98c2969..a1e51e0 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
}
}
+/// Calculate an approximation of the specified `percentile` for `expr`.
+pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
+ Expr::AggregateFunction {
+ fun: aggregates::AggregateFunction::ApproxPercentileCont,
+ distinct: false,
+ args: vec![expr, percentile],
+ }
+}
+
// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
// varying arity functions
/// Create an convenience function representing a unary scalar function
diff --git a/datafusion/src/logical_plan/mod.rs
b/datafusion/src/logical_plan/mod.rs
index 56fec3c..06c6bf9 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -36,13 +36,13 @@ pub use builder::{
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
- abs, acos, and, approx_distinct, array, ascii, asin, atan, avg,
binary_expr,
- bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
- combine_filters, concat, concat_ws, cos, count, count_distinct,
create_udaf,
- create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor,
in_list,
- initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower,
lpad, ltrim,
- max, md5, min, normalize_col, normalize_cols, now, octet_length, or,
random,
- regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
+ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii,
asin, atan,
+ avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr,
col,
+ columnize_expr, combine_filters, concat, concat_ws, cos, count,
count_distinct,
+ create_udaf, create_udf, date_part, date_trunc, digest, exp,
exprlist_to_fields,
+ floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10,
log2,
+ lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now,
octet_length,
+ or, random, regexp_match, regexp_replace, repeat, replace, replace_col,
reverse,
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256,
sha384, sha512,
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan,
to_hex,
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper,
when,
diff --git a/datafusion/src/physical_plan/aggregates.rs
b/datafusion/src/physical_plan/aggregates.rs
index c40fd71..8fc94d3 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -27,7 +27,7 @@
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min,
([f32]) -> f32, ([f64]) -> f64.
use super::{
- functions::{Signature, Volatility},
+ functions::{Signature, TypeSignature, Volatility},
Accumulator, AggregateExpr, PhysicalExpr,
};
use crate::error::{DataFusionError, Result};
@@ -80,6 +80,8 @@ pub enum AggregateFunction {
CovariancePop,
/// Correlation
Correlation,
+ /// Approximate continuous percentile function
+ ApproxPercentileCont,
}
impl fmt::Display for AggregateFunction {
@@ -110,6 +112,7 @@ impl FromStr for AggregateFunction {
"covar_samp" => AggregateFunction::Covariance,
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
+ "approx_percentile_cont" =>
AggregateFunction::ApproxPercentileCont,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
@@ -157,6 +160,7 @@ pub fn return_type(
coerced_data_types[0].clone(),
true,
)))),
+ AggregateFunction::ApproxPercentileCont =>
Ok(coerced_data_types[0].clone()),
}
}
@@ -331,6 +335,20 @@ pub fn create_aggregate_expr(
"CORR(DISTINCT) aggregations are not available".to_string(),
));
}
+ (AggregateFunction::ApproxPercentileCont, false) => {
+ Arc::new(expressions::ApproxPercentileCont::new(
+ // Pass in the desired percentile expr
+ coerced_phy_exprs,
+ name,
+ return_type,
+ )?)
+ }
+ (AggregateFunction::ApproxPercentileCont, true) => {
+ return Err(DataFusionError::NotImplemented(
+ "approx_percentile_cont(DISTINCT) aggregations are not
available"
+ .to_string(),
+ ));
+ }
})
}
@@ -389,17 +407,25 @@ pub(super) fn signature(fun: &AggregateFunction) ->
Signature {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
+ AggregateFunction::ApproxPercentileCont => Signature::one_of(
+ // Accept any numeric value paired with a float64 percentile
+ NUMERICS
+ .iter()
+ .map(|t| TypeSignature::Exact(vec![t.clone(),
DataType::Float64]))
+ .collect(),
+ Volatility::Immutable,
+ ),
}
}
#[cfg(test)]
mod tests {
use super::*;
- use crate::error::Result;
use crate::physical_plan::expressions::{
- ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance,
DistinctArrayAgg,
- DistinctCount, Max, Min, Stddev, Sum, Variance,
+ ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation,
Count,
+ Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum,
Variance,
};
+ use crate::{error::Result, scalar::ScalarValue};
#[test]
fn test_count_arragg_approx_expr() -> Result<()> {
@@ -514,6 +540,59 @@ mod tests {
}
#[test]
+ fn test_agg_approx_percentile_phy_expr() {
+ for data_type in NUMERICS {
+ let input_schema =
+ Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
+ let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
+ Arc::new(
+ expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
+ ),
+
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
+ ];
+ let result_agg_phy_exprs = create_aggregate_expr(
+ &AggregateFunction::ApproxPercentileCont,
+ false,
+ &input_phy_exprs[..],
+ &input_schema,
+ "c1",
+ )
+ .expect("failed to create aggregate expr");
+
+
assert!(result_agg_phy_exprs.as_any().is::<ApproxPercentileCont>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new("c1", data_type.clone(), false),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
+ }
+
+ #[test]
+ fn test_agg_approx_percentile_invalid_phy_expr() {
+ for data_type in NUMERICS {
+ let input_schema =
+ Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
+ let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
+ Arc::new(
+ expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
+ ),
+
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
+ ];
+ let err = create_aggregate_expr(
+ &AggregateFunction::ApproxPercentileCont,
+ false,
+ &input_phy_exprs[..],
+ &input_schema,
+ "c1",
+ )
+ .expect_err("should fail due to invalid percentile");
+
+ assert!(matches!(err, DataFusionError::Plan(_)));
+ }
+ }
+
+ #[test]
fn test_min_max_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
let data_types = vec![
diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
index c151fb7..bae2de7 100644
--- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
+++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
@@ -17,7 +17,6 @@
//! Support the coercion rule for aggregate function.
-use crate::arrow::datatypes::Schema;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::aggregates::AggregateFunction;
use crate::physical_plan::expressions::{
@@ -27,6 +26,10 @@ use crate::physical_plan::expressions::{
};
use crate::physical_plan::functions::{Signature, TypeSignature};
use crate::physical_plan::PhysicalExpr;
+use crate::{
+ arrow::datatypes::Schema,
+ physical_plan::expressions::is_approx_percentile_cont_supported_arg_type,
+};
use arrow::datatypes::DataType;
use std::ops::Deref;
use std::sync::Arc;
@@ -38,24 +41,9 @@ pub(crate) fn coerce_types(
input_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
- match signature.type_signature {
- TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count)
=> {
- if input_types.len() != agg_count {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} expects {:?} arguments, but {:?} were
provided",
- agg_fun,
- agg_count,
- input_types.len()
- )));
- }
- }
- _ => {
- return Err(DataFusionError::Internal(format!(
- "Aggregate functions do not support this {:?}",
- signature
- )));
- }
- };
+ // Validate input_types matches (at least one of) the func signature.
+ check_arg_count(agg_fun, input_types, &signature.type_signature)?;
+
match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Ok(input_types.to_vec())
@@ -151,7 +139,75 @@ pub(crate) fn coerce_types(
}
Ok(input_types.to_vec())
}
+ AggregateFunction::ApproxPercentileCont => {
+ if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} does not support inputs of type {:?}.",
+ agg_fun, input_types[0]
+ )));
+ }
+ if !matches!(input_types[1], DataType::Float64) {
+ return Err(DataFusionError::Plan(format!(
+ "The percentile argument for {:?} must be Float64, not
{:?}.",
+ agg_fun, input_types[1]
+ )));
+ }
+ Ok(input_types.to_vec())
+ }
+ }
+}
+
+/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
+///
+/// This method DOES NOT validate the argument types - only that (at least one,
+/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
+/// number of input types.
+fn check_arg_count(
+ agg_fun: &AggregateFunction,
+ input_types: &[DataType],
+ signature: &TypeSignature,
+) -> Result<()> {
+ match signature {
+ TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count)
=> {
+ if input_types.len() != *agg_count {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} expects {:?} arguments, but {:?} were
provided",
+ agg_fun,
+ agg_count,
+ input_types.len()
+ )));
+ }
+ }
+ TypeSignature::Exact(types) => {
+ if types.len() != input_types.len() {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} expects {:?} arguments, but {:?} were
provided",
+ agg_fun,
+ types.len(),
+ input_types.len()
+ )));
+ }
+ }
+ TypeSignature::OneOf(variants) => {
+ let ok = variants
+ .iter()
+ .any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
+ if !ok {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} does not accept {:?} function
arguments.",
+ agg_fun,
+ input_types.len()
+ )));
+ }
+ }
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Aggregate functions do not support this {:?}",
+ signature
+ )));
+ }
}
+ Ok(())
}
fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
@@ -267,5 +323,29 @@ mod tests {
assert_eq!(*input_type, result.unwrap());
}
}
+
+ // ApproxPercentileCont input types
+ let input_types = vec![
+ vec![DataType::Int8, DataType::Float64],
+ vec![DataType::Int16, DataType::Float64],
+ vec![DataType::Int32, DataType::Float64],
+ vec![DataType::Int64, DataType::Float64],
+ vec![DataType::UInt8, DataType::Float64],
+ vec![DataType::UInt16, DataType::Float64],
+ vec![DataType::UInt32, DataType::Float64],
+ vec![DataType::UInt64, DataType::Float64],
+ vec![DataType::Float32, DataType::Float64],
+ vec![DataType::Float64, DataType::Float64],
+ ];
+ for input_type in &input_types {
+ let signature =
+
aggregates::signature(&AggregateFunction::ApproxPercentileCont);
+ let result = coerce_types(
+ &AggregateFunction::ApproxPercentileCont,
+ input_type,
+ &signature,
+ );
+ assert_eq!(*input_type, result.unwrap());
+ }
}
}
diff --git a/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs
b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs
new file mode 100644
index 0000000..cba30ee
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/approx_percentile_cont.rs
@@ -0,0 +1,313 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{any::Any, iter, sync::Arc};
+
+use arrow::{
+ array::{
+ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array,
+ Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+ },
+ datatypes::{DataType, Field},
+};
+
+use crate::{
+ error::DataFusionError,
+ physical_plan::{tdigest::TDigest, Accumulator, AggregateExpr,
PhysicalExpr},
+ scalar::ScalarValue,
+};
+
+use crate::error::Result;
+
+use super::{format_state_name, Literal};
+
+/// Return `true` if `arg_type` is of a [`DataType`] that the
+/// [`ApproxPercentileCont`] aggregation can operate on.
+pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) ->
bool {
+ matches!(
+ arg_type,
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::Float32
+ | DataType::Float64
+ )
+}
+
+/// APPROX_PERCENTILE_CONT aggregate expression
+#[derive(Debug)]
+pub struct ApproxPercentileCont {
+ name: String,
+ input_data_type: DataType,
+ expr: Arc<dyn PhysicalExpr>,
+ percentile: f64,
+}
+
+impl ApproxPercentileCont {
+ /// Create a new [`ApproxPercentileCont`] aggregate function.
+ pub fn new(
+ expr: Vec<Arc<dyn PhysicalExpr>>,
+ name: impl Into<String>,
+ input_data_type: DataType,
+ ) -> Result<Self> {
+ // Arguments should be [ColumnExpr, DesiredPercentileLiteral]
+ debug_assert_eq!(expr.len(), 2);
+
+ // Extract the desired percentile literal
+ let lit = expr[1]
+ .as_any()
+ .downcast_ref::<Literal>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(
+ "desired percentile argument must be float
literal".to_string(),
+ )
+ })?
+ .value();
+ let percentile = match lit {
+ ScalarValue::Float32(Some(q)) => *q as f64,
+ ScalarValue::Float64(Some(q)) => *q as f64,
+ got => return Err(DataFusionError::NotImplemented(format!(
+ "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32
or Float64 literal (got data type {})",
+ got
+ )))
+ };
+
+ // Ensure the percentile is between 0 and 1.
+ if !(0.0..=1.0).contains(&percentile) {
+ return Err(DataFusionError::Plan(format!(
+ "Percentile value must be between 0.0 and 1.0 inclusive, {} is
invalid",
+ percentile
+ )));
+ }
+
+ Ok(Self {
+ name: name.into(),
+ input_data_type,
+ // The physical expr to evaluate during accumulation
+ expr: expr[0].clone(),
+ percentile,
+ })
+ }
+}
+
+impl AggregateExpr for ApproxPercentileCont {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, self.input_data_type.clone(), false))
+ }
+
+ /// See [`TDigest::to_scalar_state()`] for a description of the serialised
+ /// state.
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![
+ Field::new(
+ &format_state_name(&self.name, "max_size"),
+ DataType::UInt64,
+ false,
+ ),
+ Field::new(
+ &format_state_name(&self.name, "sum"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ &format_state_name(&self.name, "count"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ &format_state_name(&self.name, "max"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ &format_state_name(&self.name, "min"),
+ DataType::Float64,
+ false,
+ ),
+ Field::new(
+ &format_state_name(&self.name, "centroids"),
+ DataType::List(Box::new(Field::new("item", DataType::Float64,
true))),
+ false,
+ ),
+ ])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
+ t @ (DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::Float32
+ | DataType::Float64) => {
+ Box::new(ApproxPercentileAccumulator::new(self.percentile,
t.clone()))
+ }
+ other => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Support for 'APPROX_PERCENTILE_CONT' for data type {} is
not implemented",
+ other
+ )))
+ }
+ };
+ Ok(accumulator)
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+#[derive(Debug)]
+pub struct ApproxPercentileAccumulator {
+ digest: TDigest,
+ percentile: f64,
+ return_type: DataType,
+}
+
+impl ApproxPercentileAccumulator {
+ pub fn new(percentile: f64, return_type: DataType) -> Self {
+ Self {
+ digest: TDigest::new(100),
+ percentile,
+ return_type,
+ }
+ }
+}
+
+impl Accumulator for ApproxPercentileAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ Ok(self.digest.to_scalar_state())
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ debug_assert_eq!(
+ values.len(),
+ 1,
+ "invalid number of values in batch percentile update"
+ );
+ let values = &values[0];
+
+ self.digest = match values.data_type() {
+ DataType::Float64 => {
+ let array =
values.as_any().downcast_ref::<Float64Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::Float32 => {
+ let array =
values.as_any().downcast_ref::<Float32Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::Int64 => {
+ let array =
values.as_any().downcast_ref::<Int64Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::Int32 => {
+ let array =
values.as_any().downcast_ref::<Int32Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::Int16 => {
+ let array =
values.as_any().downcast_ref::<Int16Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::Int8 => {
+ let array =
values.as_any().downcast_ref::<Int8Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::UInt64 => {
+ let array =
values.as_any().downcast_ref::<UInt64Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::UInt32 => {
+ let array =
values.as_any().downcast_ref::<UInt32Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::UInt16 => {
+ let array =
values.as_any().downcast_ref::<UInt16Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ DataType::UInt8 => {
+ let array =
values.as_any().downcast_ref::<UInt8Array>().unwrap();
+ self.digest.merge_unsorted(array.values().iter().cloned())?
+ }
+ e => {
+ return Err(DataFusionError::Internal(format!(
+ "APPROX_PERCENTILE_CONT is not expected to receive the
type {:?}",
+ e
+ )));
+ }
+ };
+
+ Ok(())
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ let q = self.digest.estimate_quantile(self.percentile);
+
+ // These acceptable return types MUST match the validation in
+ // ApproxPercentile::create_accumulator.
+ Ok(match &self.return_type {
+ DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
+ DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
+ DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
+ DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
+ DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
+ DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
+ DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
+ DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
+ DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
+ DataType::Float64 => ScalarValue::Float64(Some(q as f64)),
+ v => unreachable!("unexpected return type {:?}", v),
+ })
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ };
+
+ let states = (0..states[0].len())
+ .map(|index| {
+ states
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()
+ .map(|state| TDigest::from_scalar_state(&state))
+ })
+ .chain(iter::once(Ok(self.digest.clone())))
+ .collect::<Result<Vec<_>>>()?;
+
+ self.digest = TDigest::merge_digests(&states);
+
+ Ok(())
+ }
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs
b/datafusion/src/physical_plan/expressions/mod.rs
index ca14d7f..9344fbd 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -26,6 +26,7 @@ use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::record_batch::RecordBatch;
mod approx_distinct;
+mod approx_percentile_cont;
mod array_agg;
mod average;
#[macro_use]
@@ -64,6 +65,9 @@ pub mod helpers {
}
pub use approx_distinct::ApproxDistinct;
+pub use approx_percentile_cont::{
+ is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont,
+};
pub use array_agg::ArrayAgg;
pub(crate) use average::is_avg_support_arg_type;
pub use average::{avg_return_type, Avg, AvgAccumulator};
diff --git a/datafusion/src/physical_plan/mod.rs
b/datafusion/src/physical_plan/mod.rs
index 24aa6ad..725e475 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -659,6 +659,7 @@ pub mod repartition;
pub mod sorts;
pub mod stream;
pub mod string_expressions;
+pub(crate) mod tdigest;
pub mod type_coercion;
pub mod udaf;
pub mod udf;
diff --git a/datafusion/src/physical_plan/tdigest/mod.rs
b/datafusion/src/physical_plan/tdigest/mod.rs
new file mode 100644
index 0000000..6780adc
--- /dev/null
+++ b/datafusion/src/physical_plan/tdigest/mod.rs
@@ -0,0 +1,818 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with this
+// work for additional information regarding copyright ownership. The ASF
+// licenses this file to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+//! An implementation of the [TDigest sketch algorithm] providing approximate
+//! quantile calculations.
+//!
+//! The TDigest code in this module is modified from
+//! https://github.com/MnO2/t-digest, itself a rust reimplementation of
+//! [Facebook's Folly TDigest] implementation.
+//!
+//! Alterations include reduction of runtime heap allocations, broader type
+//! support, (de-)serialisation support, reduced type conversions and null value
+//! tolerance.
+//!
+//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023
+//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
+
+use arrow::datatypes::DataType;
+use ordered_float::OrderedFloat;
+use std::cmp::Ordering;
+
+use crate::{
+ error::{DataFusionError, Result},
+ scalar::ScalarValue,
+};
+
+// Unwraps a non-null [`ScalarValue::Float64`] into an [`OrderedFloat<f64>`].
+//
+// Panics for any other variant, including a null `Float64`.
+macro_rules! cast_scalar_f64 {
+    ($scalar:expr) => {
+        match &$scalar {
+            ScalarValue::Float64(Some(inner)) => OrderedFloat::from(*inner),
+            other => panic!("invalid type {:?}", other),
+        }
+    };
+}
+
+/// This trait is implemented for each type a [`TDigest`] can operate on,
+/// allowing it to support both numerical rust types (obtained from
+/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
+pub(crate) trait TryIntoOrderedF64 {
+ /// A fallible conversion of a possibly null `self` into an
+ /// [`OrderedFloat<f64>`].
+ ///
+ /// If `self` is null, this method must return `Ok(None)`.
+ ///
+ /// If `self` cannot be coerced to the desired type, this method must
+ /// return an `Err` variant.
+ fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>>;
+}
+
+/// Implements [`TryIntoOrderedF64`] for a primitive numeric type.
+///
+/// The generated conversion widens the value with an `as f64` cast and never
+/// returns `Ok(None)` or `Err`.
+macro_rules! impl_try_ordered_f64 {
+    ($numeric:ty) => {
+        impl TryIntoOrderedF64 for $numeric {
+            fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>> {
+                let widened = *self as f64;
+                Ok(Some(OrderedFloat::from(widened)))
+            }
+        }
+    };
+}
+
+// Floating point inputs.
+impl_try_ordered_f64!(f64);
+impl_try_ordered_f64!(f32);
+// Signed integer inputs.
+impl_try_ordered_f64!(i64);
+impl_try_ordered_f64!(i32);
+impl_try_ordered_f64!(i16);
+impl_try_ordered_f64!(i8);
+// Unsigned integer inputs.
+impl_try_ordered_f64!(u64);
+impl_try_ordered_f64!(u32);
+impl_try_ordered_f64!(u16);
+impl_try_ordered_f64!(u8);
+
+impl TryIntoOrderedF64 for ScalarValue {
+    /// Converts a numeric scalar into an ordered `f64`.
+    ///
+    /// A null numeric scalar yields `Ok(None)`; any non-numeric variant
+    /// yields a `NotImplemented` error.
+    fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>> {
+        // First widen the payload to a plain `Option<f64>`, then wrap it once.
+        let widened: Option<f64> = match self {
+            ScalarValue::Float64(v) => *v,
+            ScalarValue::Float32(v) => v.map(f64::from),
+            ScalarValue::Int8(v) => v.map(f64::from),
+            ScalarValue::Int16(v) => v.map(f64::from),
+            ScalarValue::Int32(v) => v.map(f64::from),
+            // 64-bit integers have no lossless `From` conversion to `f64`.
+            ScalarValue::Int64(v) => v.map(|n| n as f64),
+            ScalarValue::UInt8(v) => v.map(f64::from),
+            ScalarValue::UInt16(v) => v.map(f64::from),
+            ScalarValue::UInt32(v) => v.map(f64::from),
+            ScalarValue::UInt64(v) => v.map(|n| n as f64),
+
+            got => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Support for 'APPROX_PERCENTILE_CONT' for data type {} is not implemented",
+                    got
+                )))
+            }
+        };
+
+        Ok(widened.map(OrderedFloat::from))
+    }
+}
+
+/// A centroid of the t-digest: a mean value together with the total weight of
+/// the values that have been merged into it.
+///
+/// NOTE(review): the derived `PartialEq`/`Eq` compare both fields, whereas the
+/// hand-written `Ord` compares `mean` only, so two centroids may compare
+/// `Ordering::Equal` without being `==` — confirm this is intended.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub(crate) struct Centroid {
+ // Mean of the values represented by this centroid.
+ mean: OrderedFloat<f64>,
+ // Accumulated weight; a freshly observed value enters with weight 1.0.
+ weight: OrderedFloat<f64>,
+}
+
+// Delegates to the `Ord` implementation, so every pair of centroids is
+// comparable.
+impl PartialOrd for Centroid {
+ fn partial_cmp(&self, other: &Centroid) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+// Centroids are ordered by their mean alone; the weight does not take part in
+// the comparison.
+impl Ord for Centroid {
+ fn cmp(&self, other: &Centroid) -> Ordering {
+ self.mean.cmp(&other.mean)
+ }
+}
+
+impl Centroid {
+    /// Builds a centroid from the given `mean` and `weight`.
+    pub(crate) fn new(
+        mean: impl Into<OrderedFloat<f64>>,
+        weight: impl Into<OrderedFloat<f64>>,
+    ) -> Self {
+        let (mean, weight) = (mean.into(), weight.into());
+        Self { mean, weight }
+    }
+
+    /// The mean of this centroid.
+    #[inline]
+    pub(crate) fn mean(&self) -> OrderedFloat<f64> {
+        self.mean
+    }
+
+    /// The weight of this centroid.
+    #[inline]
+    pub(crate) fn weight(&self) -> OrderedFloat<f64> {
+        self.weight
+    }
+
+    /// Absorbs additional mass into this centroid.
+    ///
+    /// `sum` is the weighted sum (mean × weight) of the mass being added and
+    /// `weight` is its total weight. The centroid's weight and mean are
+    /// updated, and the centroid's new weighted sum is returned.
+    pub(crate) fn add(
+        &mut self,
+        sum: impl Into<OrderedFloat<f64>>,
+        weight: impl Into<OrderedFloat<f64>>,
+    ) -> f64 {
+        // Weighted sum of everything now represented by this centroid.
+        let combined_sum = sum.into() + self.weight * self.mean;
+        let combined_weight = self.weight + weight.into();
+
+        self.mean = combined_sum / combined_weight;
+        self.weight = combined_weight;
+
+        combined_sum.into_inner()
+    }
+}
+
+impl Default for Centroid {
+    /// A centroid with mean `0.0` and weight `1.0`.
+    fn default() -> Self {
+        Centroid::new(0.0, 1.0)
+    }
+}
+
+/// A t-digest sketch: a list of [`Centroid`]s plus summary statistics of the
+/// values merged into it.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub(crate) struct TDigest {
+ // Compressed centroids; the merge routines sort them before storing.
+ centroids: Vec<Centroid>,
+ // Size parameter handed to `k_to_q` when centroids are compressed.
+ max_size: usize,
+ // Running total of the weighted sums returned by `Centroid::add`.
+ sum: OrderedFloat<f64>,
+ // Number of values merged into the digest.
+ count: OrderedFloat<f64>,
+ // Largest value merged so far; NaN while the digest is empty.
+ max: OrderedFloat<f64>,
+ // Smallest value merged so far; NaN while the digest is empty.
+ min: OrderedFloat<f64>,
+}
+
+impl TDigest {
+    /// Creates an empty digest configured with `max_size`.
+    ///
+    /// `sum` and `count` start at zero; `min` and `max` start as NaN.
+    pub(crate) fn new(max_size: usize) -> Self {
+        let zero = OrderedFloat::from(0.0);
+        let unset = OrderedFloat::from(f64::NAN);
+
+        Self {
+            centroids: Vec::new(),
+            max_size,
+            sum: zero,
+            count: zero,
+            max: unset,
+            min: unset,
+        }
+    }
+
+    /// Number of values merged into this digest.
+    #[inline]
+    pub(crate) fn count(&self) -> f64 {
+        self.count.0
+    }
+
+    /// Largest value merged into this digest (NaN when empty).
+    #[inline]
+    pub(crate) fn max(&self) -> f64 {
+        self.max.0
+    }
+
+    /// Smallest value merged into this digest (NaN when empty).
+    #[inline]
+    pub(crate) fn min(&self) -> f64 {
+        self.min.0
+    }
+
+    /// The size parameter this digest was created with.
+    #[inline]
+    pub(crate) fn max_size(&self) -> usize {
+        self.max_size
+    }
+}
+
+impl Default for TDigest {
+    /// An empty digest with a `max_size` of 100.
+    fn default() -> Self {
+        TDigest::new(100)
+    }
+}
+
+impl TDigest {
+    /// Maps the scale index `k` (out of `d`) onto a quantile limit.
+    ///
+    /// With `r = k / d` the result is `2r²` for `r < 0.5` and
+    /// `1 - 2(1 - r)²` otherwise.
+    fn k_to_q(k: f64, d: f64) -> OrderedFloat<f64> {
+        let ratio = k / d;
+
+        let q = if ratio >= 0.5 {
+            let complement = 1.0 - ratio;
+            1.0 - 2.0 * complement * complement
+        } else {
+            2.0 * ratio * ratio
+        };
+
+        OrderedFloat::from(q)
+    }
+
+    /// Restricts `v` to the range `lo..=hi`.
+    ///
+    /// The upper bound is checked first, so `hi` wins if the bounds are
+    /// inverted and `v` lies outside both.
+    fn clamp(
+        v: OrderedFloat<f64>,
+        lo: OrderedFloat<f64>,
+        hi: OrderedFloat<f64>,
+    ) -> OrderedFloat<f64> {
+        if v > hi {
+            return hi;
+        }
+        if v < lo {
+            return lo;
+        }
+        v
+    }
+
+    /// Merges `unsorted_values` into a copy of this digest.
+    ///
+    /// Null inputs (those converting to `Ok(None)`) are skipped. The first
+    /// value that fails to convert aborts the merge and its error is returned.
+    pub(crate) fn merge_unsorted<T: TryIntoOrderedF64>(
+        &self,
+        unsorted_values: impl IntoIterator<Item = T>,
+    ) -> Result<TDigest> {
+        let mut non_null = Vec::new();
+        for candidate in unsorted_values {
+            if let Some(value) = candidate.try_as_f64()? {
+                non_null.push(value);
+            }
+        }
+
+        // The sorted-merge routine requires ascending input.
+        non_null.sort();
+
+        Ok(self.merge_sorted_f64(&non_null))
+    }
+
+ /// Merges a slice of already-sorted values into this digest, returning the
+ /// outcome as a new [`TDigest`]; `self` is not modified.
+ ///
+ /// Every input value is treated as a centroid of weight `1.0`. The existing
+ /// centroids and the new values are consumed together, always taking
+ /// whichever source has the smaller next mean, and are re-compressed
+ /// against the limits produced by `k_to_q`.
+ ///
+ /// An empty `sorted_values` yields a clone of `self`.
+ fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat<f64>]) -> TDigest
{
+ debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest");
+
+ if sorted_values.is_empty() {
+ return self.clone();
+ }
+
+ let mut result = TDigest::new(self.max_size());
+ result.count = OrderedFloat::from(self.count() + (sorted_values.len()
as f64));
+
+ // The input is sorted and non-empty, so its extremes are its first and
+ // last elements.
+ let maybe_min = *sorted_values.first().unwrap();
+ let maybe_max = *sorted_values.last().unwrap();
+
+ // Only consult the existing min/max when this digest already holds data;
+ // an empty digest stores NaN for both.
+ if self.count() > 0.0 {
+ result.min = std::cmp::min(self.min, maybe_min);
+ result.max = std::cmp::max(self.max, maybe_max);
+ } else {
+ result.min = maybe_min;
+ result.max = maybe_max;
+ }
+
+ let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
+
+ // Cumulative-weight limit for the centroid currently being built; the
+ // limit index is advanced straight away, ready for the next centroid.
+ let mut k_limit: f64 = 1.0;
+ let mut q_limit_times_count =
+ Self::k_to_q(k_limit, self.max_size as f64) * result.count();
+ k_limit += 1.0;
+
+ let mut iter_centroids = self.centroids.iter().peekable();
+ let mut iter_sorted_values = sorted_values.iter().peekable();
+
+ // Seed `curr` with the smaller of the first existing centroid and the
+ // first new value (the new value wins ties).
+ let mut curr: Centroid = if let Some(c) = iter_centroids.peek() {
+ let curr = **iter_sorted_values.peek().unwrap();
+ if c.mean() < curr {
+ iter_centroids.next().unwrap().clone()
+ } else {
+ Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+ }
+ } else {
+ Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+ };
+
+ let mut weight_so_far = curr.weight();
+
+ // Mass waiting to be folded into `curr` once the limit is exceeded.
+ let mut sums_to_merge = OrderedFloat::from(0.0);
+ let mut weights_to_merge = OrderedFloat::from(0.0);
+
+ while iter_centroids.peek().is_some() ||
iter_sorted_values.peek().is_some() {
+ // Pick the next item in ascending order from either source.
+ let next: Centroid = if let Some(c) = iter_centroids.peek() {
+ if iter_sorted_values.peek().is_none()
+ || c.mean() < **iter_sorted_values.peek().unwrap()
+ {
+ iter_centroids.next().unwrap().clone()
+ } else {
+ Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+ }
+ } else {
+ Centroid::new(*iter_sorted_values.next().unwrap(), 1.0)
+ };
+
+ let next_sum = next.mean() * next.weight();
+ weight_so_far += next.weight();
+
+ if weight_so_far <= q_limit_times_count {
+ // Still within the current limit: queue `next` for merging.
+ sums_to_merge += next_sum;
+ weights_to_merge += next.weight();
+ } else {
+ // Limit exceeded: flush the queued mass into `curr`, emit it, and
+ // start a new centroid from `next` with the following limit.
+ result.sum = OrderedFloat::from(
+ result.sum.into_inner() + curr.add(sums_to_merge,
weights_to_merge),
+ );
+ sums_to_merge = 0.0.into();
+ weights_to_merge = 0.0.into();
+
+ compressed.push(curr.clone());
+ q_limit_times_count =
+ Self::k_to_q(k_limit, self.max_size as f64) *
result.count();
+ k_limit += 1.0;
+ curr = next;
+ }
+ }
+
+ // Flush whatever is still queued into the final centroid.
+ result.sum = OrderedFloat::from(
+ result.sum.into_inner() + curr.add(sums_to_merge,
weights_to_merge),
+ );
+ compressed.push(curr);
+ compressed.shrink_to_fit();
+ compressed.sort();
+
+ result.centroids = compressed;
+ result
+ }
+
+ fn external_merge(
+ centroids: &mut Vec<Centroid>,
+ first: usize,
+ middle: usize,
+ last: usize,
+ ) {
+ let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len());
+
+ let mut i = first;
+ let mut j = middle;
+
+ while i < middle && j < last {
+ match centroids[i].cmp(¢roids[j]) {
+ Ordering::Less => {
+ result.push(centroids[i].clone());
+ i += 1;
+ }
+ Ordering::Greater => {
+ result.push(centroids[j].clone());
+ j += 1;
+ }
+ Ordering::Equal => {
+ result.push(centroids[i].clone());
+ i += 1;
+ }
+ }
+ }
+
+ while i < middle {
+ result.push(centroids[i].clone());
+ i += 1;
+ }
+
+ while j < last {
+ result.push(centroids[j].clone());
+ j += 1;
+ }
+
+ i = first;
+ for centroid in result.into_iter() {
+ centroids[i] = centroid;
+ i += 1;
+ }
+ }
+
+    /// Merges a collection of digests into a single [`TDigest`].
+    ///
+    /// The centroids of every input with a positive count are gathered,
+    /// merge-sorted by mean and then compressed using the `max_size` of the
+    /// first digest. `count`, `min` and `max` of the result aggregate those of
+    /// the inputs.
+    ///
+    /// When the inputs hold no centroids an empty digest is returned. It keeps
+    /// the `max_size` of the first input; [`TDigest::default`] is used only
+    /// when `digests` itself is empty.
+    pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest {
+        let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
+
+        let max_size = match digests.first() {
+            Some(first) if n_centroids > 0 => first.max_size,
+            // Nothing to merge: stay empty, but keep the configured size.
+            Some(first) => return TDigest::new(first.max_size),
+            None => return TDigest::default(),
+        };
+
+        let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
+        let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
+
+        let mut count: f64 = 0.0;
+        let mut min = OrderedFloat::from(f64::INFINITY);
+        let mut max = OrderedFloat::from(f64::NEG_INFINITY);
+
+        for digest in digests {
+            // Offset at which this digest's (possibly empty) run begins.
+            starts.push(centroids.len());
+
+            let curr_count = digest.count();
+            if curr_count > 0.0 {
+                min = std::cmp::min(min, digest.min);
+                max = std::cmp::max(max, digest.max);
+                count += curr_count;
+                centroids.extend_from_slice(&digest.centroids);
+            }
+        }
+
+        // Bottom-up merge sort: each input run is already sorted, so merge
+        // neighbouring runs pairwise, doubling the block width every pass.
+        let mut digests_per_block: usize = 1;
+        while digests_per_block < starts.len() {
+            for i in (0..starts.len()).step_by(digests_per_block * 2) {
+                if i + digests_per_block < starts.len() {
+                    let first = starts[i];
+                    let middle = starts[i + digests_per_block];
+                    let last = if i + 2 * digests_per_block < starts.len() {
+                        starts[i + 2 * digests_per_block]
+                    } else {
+                        centroids.len()
+                    };
+
+                    debug_assert!(first <= middle && middle <= last);
+                    Self::external_merge(&mut centroids, first, middle, last);
+                }
+            }
+
+            digests_per_block *= 2;
+        }
+
+        let mut result = TDigest::new(max_size);
+        let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
+
+        // NOTE(review): unlike `merge_sorted_f64`, `k_limit` is not advanced
+        // after computing the initial limit, so the `k = 1` limit is applied
+        // to the first two output centroids — confirm this is intended.
+        let mut k_limit: f64 = 1.0;
+        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * count;
+
+        let mut iter_centroids = centroids.iter_mut();
+        let mut curr = match iter_centroids.next() {
+            Some(centroid) => centroid,
+            // Every input that held centroids reported a zero count.
+            None => return result,
+        };
+        let mut weight_so_far = curr.weight();
+        let mut sums_to_merge = OrderedFloat::from(0.0);
+        let mut weights_to_merge = OrderedFloat::from(0.0);
+
+        for centroid in iter_centroids {
+            weight_so_far += centroid.weight();
+
+            if weight_so_far <= q_limit_times_count {
+                // Still within the current limit: queue for merging.
+                sums_to_merge += centroid.mean() * centroid.weight();
+                weights_to_merge += centroid.weight();
+            } else {
+                // Limit exceeded: flush the queued mass into `curr`, emit
+                // it, and continue from `centroid` with the next limit.
+                result.sum = OrderedFloat::from(
+                    result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge),
+                );
+                sums_to_merge = OrderedFloat::from(0.0);
+                weights_to_merge = OrderedFloat::from(0.0);
+                compressed.push(curr.clone());
+                q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * count;
+                k_limit += 1.0;
+                curr = centroid;
+            }
+        }
+
+        // Flush whatever is still queued into the final centroid.
+        result.sum = OrderedFloat::from(
+            result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge),
+        );
+        compressed.push(curr.clone());
+        compressed.shrink_to_fit();
+        compressed.sort();
+
+        result.count = OrderedFloat::from(count);
+        result.min = min;
+        result.max = max;
+        result.centroids = compressed;
+        result
+    }
+
+ /// Estimates the value located at quantile `q`.
+ ///
+ /// Returns `0.0` when the digest holds no centroids. Otherwise `q >= 1.0`
+ /// returns the recorded maximum and `q <= 0.0` the recorded minimum; any
+ /// other `q` is answered by interpolating around the centroid whose
+ /// cumulative weight range contains the rank `q * count`.
+ pub(crate) fn estimate_quantile(&self, q: f64) -> f64 {
+ if self.centroids.is_empty() {
+ return 0.0;
+ }
+
+ let count_ = self.count;
+ let rank = OrderedFloat::from(q) * count_;
+
+ // `pos` is the index of the selected centroid and `t` the cumulative
+ // weight of all centroids before it.
+ let mut pos: usize;
+ let mut t;
+ if q > 0.5 {
+ if q >= 1.0 {
+ return self.max();
+ }
+
+ pos = 0;
+ t = count_;
+
+ // Upper half: walk backwards from the last centroid, peeling weights
+ // off the total until the rank is reached.
+ for (k, centroid) in self.centroids.iter().enumerate().rev() {
+ t -= centroid.weight();
+
+ if rank >= t {
+ pos = k;
+ break;
+ }
+ }
+ } else {
+ if q <= 0.0 {
+ return self.min();
+ }
+
+ pos = self.centroids.len() - 1;
+ t = OrderedFloat::from(0.0);
+
+ // Lower half: walk forwards, accumulating weights from zero.
+ for (k, centroid) in self.centroids.iter().enumerate() {
+ if rank < t + centroid.weight() {
+ pos = k;
+ break;
+ }
+
+ t += centroid.weight();
+ }
+ }
+
+ // `delta` is the spread of means used for interpolation; `min`/`max`
+ // bound the estimate and are tightened to the neighbouring centroid
+ // means where a neighbour exists.
+ let mut delta = OrderedFloat::from(0.0);
+ let mut min = self.min;
+ let mut max = self.max;
+
+ if self.centroids.len() > 1 {
+ if pos == 0 {
+ delta = self.centroids[pos + 1].mean() -
self.centroids[pos].mean();
+ max = self.centroids[pos + 1].mean();
+ } else if pos == (self.centroids.len() - 1) {
+ delta = self.centroids[pos].mean() - self.centroids[pos -
1].mean();
+ min = self.centroids[pos - 1].mean();
+ } else {
+ delta = (self.centroids[pos + 1].mean() - self.centroids[pos -
1].mean())
+ / 2.0;
+ min = self.centroids[pos - 1].mean();
+ max = self.centroids[pos + 1].mean();
+ }
+ }
+
+ // Offset from the centroid's mean by how far the rank sits into the
+ // centroid's weight, measured from its midpoint, then clamp.
+ let value = self.centroids[pos].mean()
+ + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta;
+ Self::clamp(value, min, max).into_inner()
+ }
+
+    /// Serialises this [`TDigest`], including its [`Centroid`]s, as a flat
+    /// sequence of scalar values.
+    ///
+    /// The scalar fields of the digest come first, followed by a single
+    /// [`ScalarValue::List`] of [`ScalarValue::Float64`] holding the mean and
+    /// weight of each centroid in turn:
+    ///
+    /// ```text
+    ///
+    ///    ┌────────┬────────┬────────┬───────┬────────┬────────┐
+    ///    │max_size│  sum   │ count  │  max  │  min   │centroid│
+    ///    └────────┴────────┴────────┴───────┴────────┴────────┘
+    ///                                                     │
+    ///                               ┌─────────────────────┘
+    ///                               ▼
+    ///                          ┌ List ───┐
+    ///                          │┌ ─ ─ ─ ┐│
+    ///                          │  mean   │
+    ///                          │├ ─ ─ ─ ┼│─ ─ Centroid 1
+    ///                          │ weight  │
+    ///                          │└ ─ ─ ─ ┘│
+    ///                          │         │
+    ///                          │┌ ─ ─ ─ ┐│
+    ///                          │  mean   │
+    ///                          │├ ─ ─ ─ ┼│─ ─ Centroid 2
+    ///                          │ weight  │
+    ///                          │└ ─ ─ ─ ┘│
+    ///                          │         │
+    ///                              ...
+    ///
+    /// ```
+    ///
+    /// [`TDigest::from_scalar_state()`] reverses this process, consuming the
+    /// output of this method and returning the unpacked [`TDigest`].
+    pub(crate) fn to_scalar_state(&self) -> Vec<ScalarValue> {
+        // Flatten the centroids into [mean, weight, mean, weight, ...].
+        let mut packed = Vec::with_capacity(self.centroids.len() * 2);
+        for centroid in &self.centroids {
+            packed.push(ScalarValue::Float64(Some(centroid.mean().into_inner())));
+            packed.push(ScalarValue::Float64(Some(centroid.weight().into_inner())));
+        }
+
+        let float = |value: OrderedFloat<f64>| ScalarValue::Float64(Some(value.into_inner()));
+
+        vec![
+            ScalarValue::UInt64(Some(self.max_size as u64)),
+            float(self.sum),
+            float(self.count),
+            float(self.max),
+            float(self.min),
+            ScalarValue::List(Some(Box::new(packed)), Box::new(DataType::Float64)),
+        ]
+    }
+
+    /// Rebuilds a [`TDigest`] from the state emitted by
+    /// [`Self::to_scalar_state()`].
+    ///
+    /// # Correctness
+    ///
+    /// Input that was not produced by [`Self::to_scalar_state()`] results in
+    /// undefined behaviour and may panic.
+    pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self {
+        assert_eq!(state.len(), 6, "invalid TDigest state");
+
+        let max_size = match &state[0] {
+            ScalarValue::UInt64(Some(size)) => *size as usize,
+            other => panic!("invalid max_size type {:?}", other),
+        };
+
+        let centroids = match &state[5] {
+            ScalarValue::List(Some(packed), dtype) if **dtype == DataType::Float64 => {
+                // The list alternates mean, weight, mean, weight, ...
+                let mut unpacked = Vec::new();
+                for pair in packed.chunks(2) {
+                    let mean = cast_scalar_f64!(pair[0]);
+                    let weight = cast_scalar_f64!(pair[1]);
+                    unpacked.push(Centroid::new(mean, weight));
+                }
+                unpacked
+            }
+            other => panic!("invalid centroids type {:?}", other),
+        };
+
+        let max = cast_scalar_f64!(state[3]);
+        let min = cast_scalar_f64!(state[4]);
+        assert!(max >= min);
+
+        let sum = cast_scalar_f64!(state[1]);
+        let count = cast_scalar_f64!(state[2]);
+
+        Self {
+            max_size,
+            sum,
+            count,
+            max,
+            min,
+            centroids,
+        }
+    }
+}
+
+/// Returns `true` when `values` is in non-decreasing order.
+///
+/// Deliberately not gated on `cfg(debug_assertions)`: `debug_assert!` still
+/// compiles its arguments in release builds (only their evaluation is
+/// skipped), so this function has to exist in every build profile for callers
+/// that use it inside `debug_assert!` to compile.
+fn is_sorted(values: &[OrderedFloat<f64>]) -> bool {
+    values.windows(2).all(|pair| pair[0] <= pair[1])
+}
+
+#[cfg(test)]
+mod tests {
+ use std::iter;
+
+ use super::*;
+
+ // A macro to assert the specified `quantile` estimated by `t` is within the
+ // allowable relative error bound.
+ //
+ // The short form defaults `allowable_error` to 0.01 and forwards to the
+ // long form.
+ //
+ // NOTE(review): `percentage` and `$re` are fractions (0.01 == 1%), although
+ // the failure message prints them followed by a `%` sign.
+ macro_rules! assert_error_bounds {
+ ($t:ident, quantile = $quantile:literal, want = $want:literal) => {
+ assert_error_bounds!(
+ $t,
+ quantile = $quantile,
+ want = $want,
+ allowable_error = 0.01
+ )
+ };
+ ($t:ident, quantile = $quantile:literal, want = $want:literal,
allowable_error = $re:literal) => {
+ let ans = $t.estimate_quantile($quantile);
+ let expected: f64 = $want;
+ // Relative error of the estimate with respect to the expected value.
+ let percentage: f64 = (expected - ans).abs() / expected;
+ assert!(
+ percentage < $re,
+ "relative error {} is more than {}% (got quantile {}, want
{})",
+ percentage,
+ $re,
+ ans,
+ expected
+ );
+ };
+ }
+
+ // Asserts that serialising `$t` with `to_scalar_state()` and unpacking the
+ // result with `from_scalar_state()` yields a digest equal to `$t`.
+ macro_rules! assert_state_roundtrip {
+ ($t:ident) => {
+ let state = $t.to_scalar_state();
+ let other = TDigest::from_scalar_state(&state);
+ assert_eq!($t, other);
+ };
+ }
+
+    // Quantiles of the integers 1..=1000 supplied as Int64 scalars.
+    #[test]
+    fn test_int64_uniform() {
+        let inputs = (1i64..=1000).map(|n| ScalarValue::Int64(Some(n)));
+        let t = TDigest::new(100).merge_unsorted(inputs).unwrap();
+
+        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
+        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
+        assert_state_roundtrip!(t);
+    }
+
+    // Same data as `test_int64_uniform`, with ten NULLs before and ten NULLs
+    // after the values; the expected quantiles are unchanged.
+    #[test]
+    fn test_int64_uniform_with_nulls() {
+        let nulls = || iter::repeat(ScalarValue::Int64(None)).take(10);
+        let inputs = nulls()
+            .chain((1i64..=1000).map(|n| ScalarValue::Int64(Some(n))))
+            .chain(nulls());
+
+        let t = TDigest::new(100).merge_unsorted(inputs).unwrap();
+
+        assert_error_bounds!(t, quantile = 0.1, want = 100.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
+        assert_error_bounds!(t, quantile = 0.9, want = 900.0);
+        assert_state_roundtrip!(t);
+    }
+
+    // Regression test for https://github.com/MnO2/t-digest/pull/1: values are
+    // merged one at a time into a small digest.
+    #[test]
+    fn test_centroid_addition_regression() {
+        let t = [1.0, 1.0, 1.0, 2.0, 1.0, 1.0]
+            .iter()
+            .fold(TDigest::new(10), |digest, &v| {
+                digest
+                    .merge_unsorted([ScalarValue::Float64(Some(v))])
+                    .unwrap()
+            });
+
+        assert_error_bounds!(t, quantile = 0.5, want = 1.0);
+        assert_error_bounds!(t, quantile = 0.95, want = 2.0);
+        assert_state_roundtrip!(t);
+    }
+
+    // One million uniformly spaced values merged in a single call.
+    #[test]
+    fn test_merge_unsorted_against_uniform_distro() {
+        let inputs: Vec<_> = (1..=1_000_000)
+            .map(f64::from)
+            .map(|v| ScalarValue::Float64(Some(v)))
+            .collect();
+
+        let t = TDigest::new(100).merge_unsorted(inputs).unwrap();
+
+        assert_error_bounds!(t, quantile = 1.0, want = 1_000_000.0);
+        assert_error_bounds!(t, quantile = 0.99, want = 990_000.0);
+        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
+        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
+        assert_state_roundtrip!(t);
+    }
+
+    // 600,000 uniformly spaced values followed by 400,000 copies of 1,000,000.
+    #[test]
+    fn test_merge_unsorted_against_skewed_distro() {
+        let inputs: Vec<_> = (1..=600_000)
+            .map(f64::from)
+            .chain(iter::repeat(1_000_000.0).take(400_000))
+            .map(|v| ScalarValue::Float64(Some(v)))
+            .collect();
+
+        let t = TDigest::new(100).merge_unsorted(inputs).unwrap();
+
+        assert_error_bounds!(t, quantile = 0.99, want = 1_000_000.0);
+        assert_error_bounds!(t, quantile = 0.01, want = 10_000.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500_000.0);
+        assert_state_roundtrip!(t);
+    }
+
+    // Merges one hundred digests, each built from the values 1..=1000.
+    #[test]
+    fn test_merge_digests() {
+        let digests: Vec<TDigest> = (1..=100)
+            .map(|_| {
+                let values: Vec<_> = (1..=1_000)
+                    .map(f64::from)
+                    .map(|v| ScalarValue::Float64(Some(v)))
+                    .collect();
+                TDigest::new(100).merge_unsorted(values).unwrap()
+            })
+            .collect();
+
+        let t = TDigest::merge_digests(&digests);
+
+        assert_error_bounds!(t, quantile = 1.0, want = 1000.0);
+        assert_error_bounds!(t, quantile = 0.99, want = 990.0);
+        // The low tail is given a wider tolerance than the default.
+        assert_error_bounds!(t, quantile = 0.01, want = 10.0, allowable_error = 0.2);
+        assert_error_bounds!(t, quantile = 0.0, want = 1.0);
+        assert_error_bounds!(t, quantile = 0.5, want = 500.0);
+        assert_state_roundtrip!(t);
+    }
+}
diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs
index abc7582..0aff006 100644
--- a/datafusion/src/prelude.rs
+++ b/datafusion/src/prelude.rs
@@ -30,10 +30,10 @@ pub use crate::execution::context::{ExecutionConfig,
ExecutionContext};
pub use crate::execution::options::AvroReadOptions;
pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions};
pub use crate::logical_plan::{
- array, ascii, avg, bit_length, btrim, character_length, chr, col, concat,
concat_ws,
- count, create_udf, date_part, date_trunc, digest, in_list, initcap, left,
length,
- lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random,
regexp_match,
- regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224,
sha256, sha384,
- sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate,
trim, upper,
- Column, JoinType, Partitioning,
+ approx_percentile_cont, array, ascii, avg, bit_length, btrim,
character_length, chr,
+ col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest,
in_list,
+ initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now,
octet_length,
+ random, regexp_match, regexp_replace, repeat, replace, reverse, right,
rpad, rtrim,
+ sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr,
sum, to_hex,
+ translate, trim, upper, Column, JoinType, Partitioning,
};
diff --git a/datafusion/tests/dataframe_functions.rs
b/datafusion/tests/dataframe_functions.rs
index b8efc98..d5118b3 100644
--- a/datafusion/tests/dataframe_functions.rs
+++ b/datafusion/tests/dataframe_functions.rs
@@ -154,6 +154,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> {
}
#[tokio::test]
+async fn test_fn_approx_percentile_cont() -> Result<()> {
+    // Exercises the `approx_percentile_cont()` DataFrame function by
+    // aggregating column `b` of the test table at the 0.5 quantile.
+    let expr = approx_percentile_cont(col("b"), lit(0.5));
+
+    // Every row is padded to the width of the header so the rendered table
+    // can be compared line for line.
+    let expected = vec![
+        "+-------------------------------------------+",
+        "| APPROXPERCENTILECONT(test.b,Float64(0.5)) |",
+        "+-------------------------------------------+",
+        "| 10                                        |",
+        "+-------------------------------------------+",
+    ];
+
+    let df = create_test_table()?;
+    let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?;
+
+    assert_batches_eq!(expected, &batches);
+
+    Ok(())
+}
+
+#[tokio::test]
async fn test_fn_character_length() -> Result<()> {
let expr = character_length(col("a"));
diff --git a/datafusion/tests/sql/aggregates.rs
b/datafusion/tests/sql/aggregates.rs
index 2d42870..a025d4e 100644
--- a/datafusion/tests/sql/aggregates.rs
+++ b/datafusion/tests/sql/aggregates.rs
@@ -354,6 +354,95 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}
+// This test executes the APPROX_PERCENTILE_CONT aggregation against the test
+// data, asserting the estimated quantiles are ±5% their actual values.
+//
+// Actual quantiles calculated with:
+//
+// ```r
+// read_csv("./testing/data/csv/aggregate_test_100.csv") |>
+// select_if(is.numeric) |>
+// summarise_all(~ quantile(., c(0.1, 0.5, 0.9)))
+// ```
+//
+// Giving:
+//
+// ```text
+//      c2    c3      c4           c5       c6    c7     c8          c9     c10   c11    c12
+//   <dbl> <dbl>   <dbl>        <dbl>    <dbl> <dbl>  <dbl>       <dbl>   <dbl> <dbl>  <dbl>
+// 1     1 -95.3 -22925. -1882606710  -7.25e18  18.9  2671.  472608672. 1.83e18 0.109 0.0714
+// 2     3  15.5   4599    377164262   1.13e18 134.  30634  2365817608. 9.30e18 0.491 0.551
+// 3     5 102.   25334.  1991374996.  7.37e18 231   57518. 3776538487. 1.61e19 0.834 0.946
+// ```
+//
+// Column `c12` is omitted because its small float values produce a large
+// relative error (~10%).
+#[tokio::test]
+async fn csv_query_approx_percentile_cont() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+
+    // Generate an assertion that the estimated $percentile value for $column
+    // is within 5% of the $actual percentile value.
+    //
+    // The comparison happens inside the query, which yields a single boolean
+    // column `q`; the rendered result table is then compared to `want`.
+    macro_rules! percentile_test {
+        ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => {
+            let sql = format!(
+                "SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100",
+                $column, $percentile, $actual
+            );
+            let actual = execute_to_batches(&mut $ctx, &sql).await;
+            //
+            //   "+------+",
+            //   "| q    |",
+            //   "+------+",
+            //   "| true |",
+            //   "+------+",
+            //
+            let want = ["+------+", "| q    |", "+------+", "| true |", "+------+"];
+            assert_batches_eq!(want, &actual);
+        };
+    }
+
+    percentile_test!(ctx, column = "c2", percentile = 0.1, actual = 1.0);
+    percentile_test!(ctx, column = "c2", percentile = 0.5, actual = 3.0);
+    percentile_test!(ctx, column = "c2", percentile = 0.9, actual = 5.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c3", percentile = 0.1, actual = -95.3);
+    percentile_test!(ctx, column = "c3", percentile = 0.5, actual = 15.5);
+    percentile_test!(ctx, column = "c3", percentile = 0.9, actual = 102.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c4", percentile = 0.1, actual = -22925.0);
+    percentile_test!(ctx, column = "c4", percentile = 0.5, actual = 4599.0);
+    percentile_test!(ctx, column = "c4", percentile = 0.9, actual = 25334.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c5", percentile = 0.1, actual = -1882606710.0);
+    percentile_test!(ctx, column = "c5", percentile = 0.5, actual = 377164262.0);
+    percentile_test!(ctx, column = "c5", percentile = 0.9, actual = 1991374996.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c6", percentile = 0.1, actual = -7.25e18);
+    percentile_test!(ctx, column = "c6", percentile = 0.5, actual = 1.13e18);
+    percentile_test!(ctx, column = "c6", percentile = 0.9, actual = 7.37e18);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c7", percentile = 0.1, actual = 18.9);
+    percentile_test!(ctx, column = "c7", percentile = 0.5, actual = 134.0);
+    percentile_test!(ctx, column = "c7", percentile = 0.9, actual = 231.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c8", percentile = 0.1, actual = 2671.0);
+    percentile_test!(ctx, column = "c8", percentile = 0.5, actual = 30634.0);
+    percentile_test!(ctx, column = "c8", percentile = 0.9, actual = 57518.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c9", percentile = 0.1, actual = 472608672.0);
+    percentile_test!(ctx, column = "c9", percentile = 0.5, actual = 2365817608.0);
+    percentile_test!(ctx, column = "c9", percentile = 0.9, actual = 3776538487.0);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c10", percentile = 0.1, actual = 1.83e18);
+    percentile_test!(ctx, column = "c10", percentile = 0.5, actual = 9.30e18);
+    percentile_test!(ctx, column = "c10", percentile = 0.9, actual = 1.61e19);
+    ////////////////////////////////////
+    percentile_test!(ctx, column = "c11", percentile = 0.1, actual = 0.109);
+    percentile_test!(ctx, column = "c11", percentile = 0.5, actual = 0.491);
+    percentile_test!(ctx, column = "c11", percentile = 0.9, actual = 0.834);
+
+    Ok(())
+}
+
#[tokio::test]
async fn query_count_without_from() -> Result<()> {
let mut ctx = ExecutionContext::new();