This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 6b4bbd0fe `sum(distinct)` support (#2405)
6b4bbd0fe is described below
commit 6b4bbd0fe33bf840553346c0a2db554e013cd00d
Author: DuRipeng <[email protected]>
AuthorDate: Wed May 4 09:11:23 2022 +0800
`sum(distinct)` support (#2405)
* sum(distinct) support
* fix clippy
* merge state() code logic
* revise annotation
* remove u64->i63 coercion
---
datafusion/core/tests/sql/aggregates.rs | 57 ++++
datafusion/physical-expr/src/aggregate/build_in.rs | 10 +-
datafusion/physical-expr/src/aggregate/mod.rs | 1 +
datafusion/physical-expr/src/aggregate/sum.rs | 9 +
.../physical-expr/src/aggregate/sum_distinct.rs | 294 +++++++++++++++++++++
datafusion/physical-expr/src/expressions/mod.rs | 1 +
6 files changed, 367 insertions(+), 5 deletions(-)
diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs
index 9a6d2d64e..b488e880d 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -1235,6 +1235,63 @@ async fn simple_avg() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn query_sum_distinct() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("c1", DataType::Int64, true),
+ Field::new("c2", DataType::Int64, true),
+ ]));
+
+ let data = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int64Array::from(vec![
+ Some(0),
+ Some(1),
+ None,
+ Some(3),
+ Some(3),
+ ])),
+ Arc::new(Int64Array::from(vec![
+ None,
+ Some(1),
+ Some(1),
+ Some(2),
+ Some(2),
+ ])),
+ ],
+ )?;
+
+ let table = MemTable::try_new(schema, vec![vec![data]])?;
+ let ctx = SessionContext::new();
+ ctx.register_table("test", Arc::new(table))?;
+
+ // 2 different aggregate functions: avg and sum(distinct)
+ let sql = "SELECT AVG(c1), SUM(DISTINCT c2) FROM test";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+--------------+-----------------------+",
+ "| AVG(test.c1) | SUM(DISTINCT test.c2) |",
+ "+--------------+-----------------------+",
+ "| 1.75 | 3 |",
+ "+--------------+-----------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ // 2 sum(distinct) functions
+ let sql = "SELECT SUM(DISTINCT c1), SUM(DISTINCT c2) FROM test";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------------------+-----------------------+",
+ "| SUM(DISTINCT test.c1) | SUM(DISTINCT test.c2) |",
+ "+-----------------------+-----------------------+",
+ "| 4 | 3 |",
+ "+-----------------------+-----------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
#[tokio::test]
async fn query_count_distinct() -> Result<()> {
let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index f91e01336..784cac81b 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -87,11 +87,11 @@ pub fn create_aggregate_expr(
name,
return_type,
)),
- (AggregateFunction::Sum, true) => {
- return Err(DataFusionError::NotImplemented(
- "SUM(DISTINCT) aggregations are not available".to_string(),
- ));
- }
+ (AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new(
+ vec![coerced_phy_exprs[0].clone()],
+ name,
+ return_type,
+ )),
(AggregateFunction::ApproxDistinct, _) => {
Arc::new(expressions::ApproxDistinct::new(
coerced_phy_exprs[0].clone(),
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 106087db5..ae8de14e2 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -42,6 +42,7 @@ mod hyperloglog;
pub(crate) mod stats;
pub(crate) mod stddev;
pub(crate) mod sum;
+pub(crate) mod sum_distinct;
mod tdigest;
pub(crate) mod variance;
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index 00404ee99..cca54733d 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -297,6 +297,15 @@ pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
(ScalarValue::Int64(lhs), ScalarValue::Int8(rhs)) => {
typed_sum!(lhs, rhs, Int64, i64)
}
+ (ScalarValue::Int64(lhs), ScalarValue::UInt32(rhs)) => {
+ typed_sum!(lhs, rhs, Int64, i64)
+ }
+ (ScalarValue::Int64(lhs), ScalarValue::UInt16(rhs)) => {
+ typed_sum!(lhs, rhs, Int64, i64)
+ }
+ (ScalarValue::Int64(lhs), ScalarValue::UInt8(rhs)) => {
+ typed_sum!(lhs, rhs, Int64, i64)
+ }
e => {
return Err(DataFusionError::Internal(format!(
"Sum is not expected to receive a scalar {:?}",
diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
new file mode 100644
index 000000000..238722726
--- /dev/null
+++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -0,0 +1,294 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::aggregate::sum;
+use crate::expressions::format_state_name;
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+use ahash::RandomState;
+use arrow::array::{Array, ArrayRef};
+use std::collections::HashSet;
+
+use crate::{AggregateExpr, PhysicalExpr};
+use datafusion_common::ScalarValue;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::Accumulator;
+
+/// Expression for a SUM(DISTINCT) aggregation.
+#[derive(Debug)]
+pub struct DistinctSum {
+ /// Column name
+ name: String,
+ /// The DataType for the final sum
+ data_type: DataType,
+ /// The input arguments, only contains 1 item for sum
+ exprs: Vec<Arc<dyn PhysicalExpr>>,
+}
+
+impl DistinctSum {
+ /// Create a SUM(DISTINCT) aggregate function.
+ pub fn new(
+ exprs: Vec<Arc<dyn PhysicalExpr>>,
+ name: String,
+ data_type: DataType,
+ ) -> Self {
+ Self {
+ name,
+ data_type,
+ exprs,
+ }
+ }
+}
+
+impl AggregateExpr for DistinctSum {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, self.data_type.clone(), true))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ // State field is a List which stores items to rebuild hash set.
+ Ok(vec![Field::new(
+ &format_state_name(&self.name, "sum distinct"),
+ DataType::List(Box::new(Field::new("item", self.data_type.clone(), true))),
+ false,
+ )])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ self.exprs.clone()
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ Ok(Box::new(DistinctSumAccumulator::try_new(&self.data_type)?))
+ }
+}
+
+#[derive(Debug)]
+struct DistinctSumAccumulator {
+ hash_values: HashSet<ScalarValue, RandomState>,
+ data_type: DataType,
+}
+impl DistinctSumAccumulator {
+ pub fn try_new(data_type: &DataType) -> Result<Self> {
+ Ok(Self {
+ hash_values: HashSet::default(),
+ data_type: data_type.clone(),
+ })
+ }
+
+ fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+ values.iter().for_each(|v| {
+ // If the value is NULL, it is not included in the final sum.
+ if !v.is_null() {
+ self.hash_values.insert(v.clone());
+ }
+ });
+
+ Ok(())
+ }
+
+ fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+
+ states.iter().try_for_each(|state| match state {
+ ScalarValue::List(Some(values), _) => self.update(values.as_ref()),
+ _ => Err(DataFusionError::Internal(format!(
+ "Unexpected accumulator state {:?}",
+ state
+ ))),
+ })
+ }
+}
+
+impl Accumulator for DistinctSumAccumulator {
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ // 1. Stores aggregate state in `ScalarValue::List`
+ // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
+ let state_out = {
+ let mut distinct_values = Box::new(Vec::new());
+ let data_type = Box::new(self.data_type.clone());
+ self.hash_values
+ .iter()
+ .for_each(|distinct_value| distinct_values.push(distinct_value.clone()));
+ vec![ScalarValue::List(Some(distinct_values), data_type)]
+ };
+ Ok(state_out)
+ }
+
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ if values.is_empty() {
+ return Ok(());
+ }
+
+ let scalar_values = (0..values[0].len())
+ .map(|index| ScalarValue::try_from_array(&values[0], index))
+ .collect::<Result<Vec<_>>>()?;
+ self.update(&scalar_values)
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ if states.is_empty() {
+ return Ok(());
+ }
+
+ (0..states[0].len()).try_for_each(|index| {
+ let v = states
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()?;
+ self.merge(&v)
+ })
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ let mut sum_value = ScalarValue::try_from(&self.data_type)?;
+ self.hash_values.iter().for_each(|distinct_value| {
+ sum_value = sum::sum(&sum_value, distinct_value).unwrap()
+ });
+ Ok(sum_value)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::expressions::col;
+ use crate::expressions::tests::aggregate;
+ use arrow::record_batch::RecordBatch;
+ use arrow::{array::*, datatypes::*};
+ use datafusion_common::Result;
+
+ fn run_update_batch(
+ return_type: DataType,
+ arrays: &[ArrayRef],
+ ) -> Result<(Vec<ScalarValue>, ScalarValue)> {
+ let agg = DistinctSum::new(vec![], String::from("__col_name__"), return_type);
+
+ let mut accum = agg.create_accumulator()?;
+ accum.update_batch(arrays)?;
+
+ Ok((accum.state()?, accum.evaluate()?))
+ }
+
+ macro_rules! generic_test_sum_distinct {
+ ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
+ let schema = Schema::new(vec![Field::new("a", $DATATYPE, false)]);
+
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
+
+ let agg = Arc::new(DistinctSum::new(
+ vec![col("a", &schema)?],
+ "count_distinct_a".to_string(),
+ $EXPECTED_DATATYPE,
+ ));
+ let actual = aggregate(&batch, agg)?;
+ let expected = ScalarValue::from($EXPECTED);
+
+ assert_eq!(expected, actual);
+
+ Ok(())
+ }};
+ }
+
+ #[test]
+ fn sum_distinct_update_batch() -> Result<()> {
+ let array_int64: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 3]));
+ let arrays = vec![array_int64];
+ let (states, result) = run_update_batch(DataType::Int64, &arrays)?;
+
+ assert_eq!(states.len(), 1);
+ assert_eq!(result, ScalarValue::Int64(Some(4)));
+
+ Ok(())
+ }
+
+ #[test]
+ fn sum_distinct_i32_with_nulls() -> Result<()> {
+ let array = Arc::new(Int32Array::from(vec![
+ Some(1),
+ Some(1),
+ None,
+ Some(2),
+ Some(2),
+ Some(3),
+ ]));
+ generic_test_sum_distinct!(
+ array,
+ DataType::Int32,
+ ScalarValue::from(6i64),
+ DataType::Int64
+ )
+ }
+
+ #[test]
+ fn sum_distinct_u32_with_nulls() -> Result<()> {
+ let array: ArrayRef = Arc::new(UInt32Array::from(vec![
+ Some(1_u32),
+ Some(1_u32),
+ Some(3_u32),
+ Some(3_u32),
+ None,
+ ]));
+ generic_test_sum_distinct!(
+ array,
+ DataType::UInt32,
+ ScalarValue::from(4i64),
+ DataType::Int64
+ )
+ }
+
+ #[test]
+ fn sum_distinct_f64() -> Result<()> {
+ let array: ArrayRef =
+ Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64]));
+ generic_test_sum_distinct!(
+ array,
+ DataType::Float64,
+ ScalarValue::from(4_f64),
+ DataType::Float64
+ )
+ }
+
+ #[test]
+ fn sum_distinct_decimal_with_nulls() -> Result<()> {
+ let array: ArrayRef = Arc::new(
+ (1..6)
+ .map(|i| if i == 2 { None } else { Some(i % 2) })
+ .collect::<DecimalArray>()
+ .with_precision_and_scale(35, 0)?,
+ );
+ generic_test_sum_distinct!(
+ array,
+ DataType::Decimal(35, 0),
+ ScalarValue::Decimal128(Some(1), 38, 0),
+ DataType::Decimal(38, 0)
+ )
+ }
+}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 2cdceab6c..d081720b8 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -55,6 +55,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
pub use crate::aggregate::stats::StatsType;
pub use crate::aggregate::stddev::{Stddev, StddevPop};
pub use crate::aggregate::sum::Sum;
+pub use crate::aggregate::sum_distinct::DistinctSum;
pub use crate::aggregate::variance::{Variance, VariancePop};
pub use crate::window::cume_dist::cume_dist;